fix(sora): 修复流式重写与计费问题
This commit is contained in:
@@ -23,6 +23,8 @@ var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||
|
||||
const soraRewriteBufferLimit = 2048
|
||||
|
||||
var soraImageSizeMap = map[string]string{
|
||||
"gpt-image": "360",
|
||||
"gpt-image-landscape": "540",
|
||||
@@ -30,7 +32,6 @@ var soraImageSizeMap = map[string]string{
|
||||
}
|
||||
|
||||
type soraStreamingResult struct {
|
||||
content string
|
||||
mediaType string
|
||||
mediaURLs []string
|
||||
imageCount int
|
||||
@@ -307,6 +308,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
||||
contentBuilder := strings.Builder{}
|
||||
var firstTokenMs *int
|
||||
var upstreamError error
|
||||
rewriteBuffer := ""
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -333,12 +335,29 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
||||
if soraSSEDataRe.MatchString(line) {
|
||||
data := soraSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "[DONE]" {
|
||||
if rewriteBuffer != "" {
|
||||
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if flushLine != "" {
|
||||
if flushContent != "" {
|
||||
if _, err := contentBuilder.WriteString(flushContent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := sendLine(flushLine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rewriteBuffer = ""
|
||||
}
|
||||
if err := sendLine("data: [DONE]"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
break
|
||||
}
|
||||
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel)
|
||||
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
|
||||
if errEvent != nil && upstreamError == nil {
|
||||
upstreamError = errEvent
|
||||
}
|
||||
@@ -347,7 +366,9 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
contentBuilder.WriteString(contentDelta)
|
||||
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := sendLine(updatedLine); err != nil {
|
||||
return nil, err
|
||||
@@ -417,7 +438,6 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
||||
}
|
||||
|
||||
return &soraStreamingResult{
|
||||
content: content,
|
||||
mediaType: mediaType,
|
||||
mediaURLs: mediaURLs,
|
||||
imageCount: imageCount,
|
||||
@@ -426,7 +446,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) {
|
||||
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
|
||||
if strings.TrimSpace(data) == "" {
|
||||
return "data: ", "", nil
|
||||
}
|
||||
@@ -448,7 +468,12 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin
|
||||
|
||||
contentDelta, updated := extractSoraContent(payload)
|
||||
if updated {
|
||||
rewritten := s.rewriteSoraContent(contentDelta)
|
||||
var rewritten string
|
||||
if rewriteBuffer != nil {
|
||||
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
|
||||
} else {
|
||||
rewritten = s.rewriteSoraContent(contentDelta)
|
||||
}
|
||||
if rewritten != contentDelta {
|
||||
applySoraContent(payload, rewritten)
|
||||
contentDelta = rewritten
|
||||
@@ -504,6 +529,78 @@ func applySoraContent(payload map[string]any, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
|
||||
if buffer == nil {
|
||||
return s.rewriteSoraContent(contentDelta)
|
||||
}
|
||||
if contentDelta == "" && *buffer == "" {
|
||||
return ""
|
||||
}
|
||||
combined := *buffer + contentDelta
|
||||
rewritten := s.rewriteSoraContent(combined)
|
||||
bufferStart := s.findSoraRewriteBufferStart(rewritten)
|
||||
if bufferStart < 0 {
|
||||
*buffer = ""
|
||||
return rewritten
|
||||
}
|
||||
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
|
||||
bufferStart = len(rewritten) - soraRewriteBufferLimit
|
||||
}
|
||||
output := rewritten[:bufferStart]
|
||||
*buffer = rewritten[bufferStart:]
|
||||
return output
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
|
||||
minIndex := -1
|
||||
start := 0
|
||||
for {
|
||||
idx := strings.Index(content[start:], "![")
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
idx += start
|
||||
if !hasSoraImageMatchAt(content, idx) {
|
||||
if minIndex == -1 || idx < minIndex {
|
||||
minIndex = idx
|
||||
}
|
||||
}
|
||||
start = idx + 2
|
||||
}
|
||||
lower := strings.ToLower(content)
|
||||
start = 0
|
||||
for {
|
||||
idx := strings.Index(lower[start:], "<video")
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
idx += start
|
||||
if !hasSoraVideoMatchAt(content, idx) {
|
||||
if minIndex == -1 || idx < minIndex {
|
||||
minIndex = idx
|
||||
}
|
||||
}
|
||||
start = idx + len("<video")
|
||||
}
|
||||
return minIndex
|
||||
}
|
||||
|
||||
func hasSoraImageMatchAt(content string, idx int) bool {
|
||||
if idx < 0 || idx >= len(content) {
|
||||
return false
|
||||
}
|
||||
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
|
||||
return loc != nil && loc[0] == 0
|
||||
}
|
||||
|
||||
func hasSoraVideoMatchAt(content string, idx int) bool {
|
||||
if idx < 0 || idx >= len(content) {
|
||||
return false
|
||||
}
|
||||
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
|
||||
return loc != nil && loc[0] == 0
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
||||
if content == "" {
|
||||
return content
|
||||
@@ -533,6 +630,31 @@ func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
||||
return content
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
|
||||
if buffer == "" {
|
||||
return "", "", nil
|
||||
}
|
||||
rewritten := s.rewriteSoraContent(buffer)
|
||||
payload := map[string]any{
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"delta": map[string]any{
|
||||
"content": rewritten,
|
||||
},
|
||||
"index": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
if originalModel != "" {
|
||||
payload["model"] = originalModel
|
||||
}
|
||||
updatedData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return "data: " + string(updatedData), rewritten, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
|
||||
Reference in New Issue
Block a user