fix(sora): 修复流式重写与计费问题

This commit is contained in:
yangjianbo
2026-01-31 21:46:28 +08:00
parent 618a614cbf
commit 78d0ca3775
19 changed files with 325 additions and 210 deletions

View File

@@ -23,6 +23,8 @@ var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
const soraRewriteBufferLimit = 2048
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
"gpt-image-landscape": "540",
@@ -30,7 +32,6 @@ var soraImageSizeMap = map[string]string{
}
type soraStreamingResult struct {
content string
mediaType string
mediaURLs []string
imageCount int
@@ -307,6 +308,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
contentBuilder := strings.Builder{}
var firstTokenMs *int
var upstreamError error
rewriteBuffer := ""
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -333,12 +335,29 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
if soraSSEDataRe.MatchString(line) {
data := soraSSEDataRe.ReplaceAllString(line, "")
if data == "[DONE]" {
if rewriteBuffer != "" {
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
if err != nil {
return nil, err
}
if flushLine != "" {
if flushContent != "" {
if _, err := contentBuilder.WriteString(flushContent); err != nil {
return nil, err
}
}
if err := sendLine(flushLine); err != nil {
return nil, err
}
}
rewriteBuffer = ""
}
if err := sendLine("data: [DONE]"); err != nil {
return nil, err
}
break
}
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel)
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
if errEvent != nil && upstreamError == nil {
upstreamError = errEvent
}
@@ -347,7 +366,9 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
contentBuilder.WriteString(contentDelta)
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
return nil, err
}
}
if err := sendLine(updatedLine); err != nil {
return nil, err
@@ -417,7 +438,6 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
}
return &soraStreamingResult{
content: content,
mediaType: mediaType,
mediaURLs: mediaURLs,
imageCount: imageCount,
@@ -426,7 +446,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
}, nil
}
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) {
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
if strings.TrimSpace(data) == "" {
return "data: ", "", nil
}
@@ -448,7 +468,12 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin
contentDelta, updated := extractSoraContent(payload)
if updated {
rewritten := s.rewriteSoraContent(contentDelta)
var rewritten string
if rewriteBuffer != nil {
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
} else {
rewritten = s.rewriteSoraContent(contentDelta)
}
if rewritten != contentDelta {
applySoraContent(payload, rewritten)
contentDelta = rewritten
@@ -504,6 +529,78 @@ func applySoraContent(payload map[string]any, content string) {
}
}
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
if buffer == nil {
return s.rewriteSoraContent(contentDelta)
}
if contentDelta == "" && *buffer == "" {
return ""
}
combined := *buffer + contentDelta
rewritten := s.rewriteSoraContent(combined)
bufferStart := s.findSoraRewriteBufferStart(rewritten)
if bufferStart < 0 {
*buffer = ""
return rewritten
}
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
bufferStart = len(rewritten) - soraRewriteBufferLimit
}
output := rewritten[:bufferStart]
*buffer = rewritten[bufferStart:]
return output
}
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
minIndex := -1
start := 0
for {
idx := strings.Index(content[start:], "![")
if idx < 0 {
break
}
idx += start
if !hasSoraImageMatchAt(content, idx) {
if minIndex == -1 || idx < minIndex {
minIndex = idx
}
}
start = idx + 2
}
lower := strings.ToLower(content)
start = 0
for {
idx := strings.Index(lower[start:], "<video")
if idx < 0 {
break
}
idx += start
if !hasSoraVideoMatchAt(content, idx) {
if minIndex == -1 || idx < minIndex {
minIndex = idx
}
}
start = idx + len("<video")
}
return minIndex
}
func hasSoraImageMatchAt(content string, idx int) bool {
if idx < 0 || idx >= len(content) {
return false
}
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
return loc != nil && loc[0] == 0
}
func hasSoraVideoMatchAt(content string, idx int) bool {
if idx < 0 || idx >= len(content) {
return false
}
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
return loc != nil && loc[0] == 0
}
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
if content == "" {
return content
@@ -533,6 +630,31 @@ func (s *SoraGatewayService) rewriteSoraContent(content string) string {
return content
}
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
if buffer == "" {
return "", "", nil
}
rewritten := s.rewriteSoraContent(buffer)
payload := map[string]any{
"choices": []any{
map[string]any{
"delta": map[string]any{
"content": rewritten,
},
"index": 0,
},
},
}
if originalModel != "" {
payload["model"] = originalModel
}
updatedData, err := json.Marshal(payload)
if err != nil {
return "", "", err
}
return "data: " + string(updatedData), rewritten, nil
}
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {