feat(Sora): 直连生成并移除sora2api依赖
实现直连 Sora 客户端、媒体落地与清理策略\n更新网关与前端配置以支持 Sora 平台\n补齐单元测试与契约测试,新增 curl 测试脚本\n\n测试: go test ./... -tags=unit
This commit is contained in:
@@ -4,10 +4,12 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -39,23 +41,23 @@ type soraStreamingResult struct {
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
// SoraGatewayService handles forwarding requests to sora2api.
|
||||
// SoraGatewayService handles forwarding requests to Sora upstream.
|
||||
type SoraGatewayService struct {
|
||||
sora2api *Sora2APIService
|
||||
httpUpstream HTTPUpstream
|
||||
soraClient SoraClient
|
||||
mediaStorage *SoraMediaStorage
|
||||
rateLimitService *RateLimitService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewSoraGatewayService(
|
||||
sora2api *Sora2APIService,
|
||||
httpUpstream HTTPUpstream,
|
||||
soraClient SoraClient,
|
||||
mediaStorage *SoraMediaStorage,
|
||||
rateLimitService *RateLimitService,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayService {
|
||||
return &SoraGatewayService{
|
||||
sora2api: sora2api,
|
||||
httpUpstream: httpUpstream,
|
||||
soraClient: soraClient,
|
||||
mediaStorage: mediaStorage,
|
||||
rateLimitService: rateLimitService,
|
||||
cfg: cfg,
|
||||
}
|
||||
@@ -64,31 +66,53 @@ func NewSoraGatewayService(
|
||||
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
if s.sora2api == nil || !s.sora2api.Enabled() {
|
||||
if s.soraClient == nil || !s.soraClient.Enabled() {
|
||||
if c != nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "sora2api 未配置",
|
||||
"message": "Sora 上游未配置",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, errors.New("sora2api not configured")
|
||||
return nil, errors.New("sora upstream not configured")
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream)
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
if strings.TrimSpace(reqModel) == "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel && mappedModel != "" {
|
||||
reqBody["model"] = mappedModel
|
||||
if updated, err := json.Marshal(reqBody); err == nil {
|
||||
body = updated
|
||||
}
|
||||
if mappedModel != "" && mappedModel != reqModel {
|
||||
reqModel = mappedModel
|
||||
}
|
||||
|
||||
modelCfg, ok := GetSoraModelConfig(reqModel)
|
||||
if !ok {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
|
||||
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
||||
}
|
||||
if modelCfg.Type == "prompt_enhance" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
|
||||
return nil, fmt.Errorf("prompt-enhance not supported")
|
||||
}
|
||||
|
||||
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
if strings.TrimSpace(videoInput) != "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
|
||||
return nil, errors.New("video input not supported")
|
||||
}
|
||||
|
||||
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
||||
@@ -96,81 +120,122 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c != nil {
|
||||
if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
|
||||
upstreamReq.Header.Set("User-Agent", ua)
|
||||
var imageData []byte
|
||||
imageFilename := ""
|
||||
if strings.TrimSpace(imageInput) != "" {
|
||||
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
|
||||
if err != nil {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if reqStream {
|
||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||
imageData = decoded
|
||||
imageFilename = filename
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
mediaID := ""
|
||||
if len(imageData) > 0 {
|
||||
uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename)
|
||||
if err != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||
}
|
||||
mediaID = uploadID
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
taskID := ""
|
||||
var err error
|
||||
switch modelCfg.Type {
|
||||
case "image":
|
||||
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
|
||||
Prompt: prompt,
|
||||
Width: modelCfg.Width,
|
||||
Height: modelCfg.Height,
|
||||
MediaID: mediaID,
|
||||
})
|
||||
case "video":
|
||||
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||
Prompt: prompt,
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
})
|
||||
default:
|
||||
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
if s.httpUpstream != nil {
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if clientStream && c != nil {
|
||||
s.prepareSoraStream(c, taskID)
|
||||
}
|
||||
|
||||
var mediaURLs []string
|
||||
mediaType := modelCfg.Type
|
||||
imageCount := 0
|
||||
imageSize := ""
|
||||
if modelCfg.Type == "image" {
|
||||
urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream)
|
||||
if pollErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||
}
|
||||
mediaURLs = urls
|
||||
imageCount = len(urls)
|
||||
imageSize = soraImageSizeFromModel(reqModel)
|
||||
} else if modelCfg.Type == "video" {
|
||||
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
|
||||
if pollErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||
}
|
||||
mediaURLs = urls
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(upstreamReq)
|
||||
mediaType = "prompt"
|
||||
}
|
||||
if err != nil {
|
||||
s.setUpstreamRequestError(c, account, err)
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
finalURLs := mediaURLs
|
||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||
if storeErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream)
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
|
||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||
} else {
|
||||
finalURLs = s.normalizeSoraMediaURLs(mediaURLs)
|
||||
}
|
||||
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
content := buildSoraContent(mediaType, finalURLs)
|
||||
var firstTokenMs *int
|
||||
if clientStream {
|
||||
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||
if streamErr != nil {
|
||||
return nil, streamErr
|
||||
}
|
||||
firstTokenMs = ms
|
||||
} else if c != nil {
|
||||
response := buildSoraNonStreamResponse(content, reqModel)
|
||||
if len(finalURLs) > 0 {
|
||||
response["media_url"] = finalURLs[0]
|
||||
if len(finalURLs) > 1 {
|
||||
response["media_urls"] = finalURLs
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
result := &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
return &ForwardResult{
|
||||
RequestID: taskID,
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: streamResult.firstTokenMs,
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: streamResult.mediaType,
|
||||
MediaURL: firstMediaURL(streamResult.mediaURLs),
|
||||
ImageCount: streamResult.imageCount,
|
||||
ImageSize: streamResult.imageSize,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
MediaType: mediaType,
|
||||
MediaURL: firstMediaURL(finalURLs),
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
||||
@@ -780,3 +845,414 @@ func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) str
|
||||
}
|
||||
return prefix + path + "?" + encoded
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
if strings.TrimSpace(requestID) != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
writer := c.Writer
|
||||
flusher, _ := writer.(http.Flusher)
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"content": content,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
encoded, _ := json.Marshal(chunk)
|
||||
if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
finalChunk := map[string]any{
|
||||
"id": chunk["id"],
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
finalEncoded, _ := json.Marshal(finalChunk)
|
||||
if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil {
|
||||
return &ms, err
|
||||
}
|
||||
if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil {
|
||||
return &ms, err
|
||||
}
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
return &ms, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if stream {
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
_, _ = fmt.Fprint(c.Writer, errorEvent)
|
||||
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var upstreamErr *SoraUpstreamError
|
||||
if errors.As(err, &upstreamErr) {
|
||||
if s.rateLimitService != nil && account != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||
}
|
||||
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
||||
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
|
||||
}
|
||||
msg := upstreamErr.Message
|
||||
if override := soraProErrorMessage(model, msg); override != "" {
|
||||
msg = override
|
||||
}
|
||||
s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream)
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream)
|
||||
return err
|
||||
}
|
||||
s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
||||
interval := s.pollInterval()
|
||||
maxAttempts := s.pollMaxAttempts()
|
||||
lastPing := time.Now()
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
status, err := s.soraClient.GetImageTask(ctx, account, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch strings.ToLower(status.Status) {
|
||||
case "succeeded", "completed":
|
||||
return status.URLs, nil
|
||||
case "failed":
|
||||
if status.ErrorMsg != "" {
|
||||
return nil, errors.New(status.ErrorMsg)
|
||||
}
|
||||
return nil, errors.New("Sora image generation failed")
|
||||
}
|
||||
if stream {
|
||||
s.maybeSendPing(c, &lastPing)
|
||||
}
|
||||
if err := sleepWithContext(ctx, interval); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, errors.New("Sora image generation timeout")
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
||||
interval := s.pollInterval()
|
||||
maxAttempts := s.pollMaxAttempts()
|
||||
lastPing := time.Now()
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
status, err := s.soraClient.GetVideoTask(ctx, account, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch strings.ToLower(status.Status) {
|
||||
case "completed", "succeeded":
|
||||
return status.URLs, nil
|
||||
case "failed":
|
||||
if status.ErrorMsg != "" {
|
||||
return nil, errors.New(status.ErrorMsg)
|
||||
}
|
||||
return nil, errors.New("Sora video generation failed")
|
||||
}
|
||||
if stream {
|
||||
s.maybeSendPing(c, &lastPing)
|
||||
}
|
||||
if err := sleepWithContext(ctx, interval); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, errors.New("Sora video generation timeout")
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollInterval() time.Duration {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 2 * time.Second
|
||||
}
|
||||
interval := s.cfg.Sora.Client.PollIntervalSeconds
|
||||
if interval <= 0 {
|
||||
interval = 2
|
||||
}
|
||||
return time.Duration(interval) * time.Second
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollMaxAttempts() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 600
|
||||
}
|
||||
maxAttempts := s.cfg.Sora.Client.MaxPollAttempts
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = 600
|
||||
}
|
||||
return maxAttempts
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
interval := 10 * time.Second
|
||||
if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 {
|
||||
interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second
|
||||
}
|
||||
if time.Since(*lastPing) < interval {
|
||||
return
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil {
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
*lastPing = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
|
||||
if len(urls) == 0 {
|
||||
return urls
|
||||
}
|
||||
output := make([]string, 0, len(urls))
|
||||
for _, raw := range urls {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
output = append(output, raw)
|
||||
continue
|
||||
}
|
||||
pathVal := raw
|
||||
if !strings.HasPrefix(pathVal, "/") {
|
||||
pathVal = "/" + pathVal
|
||||
}
|
||||
output = append(output, s.buildSoraMediaURL(pathVal, ""))
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
// buildSoraContent renders generated media URLs as chat message content.
//
// For mediaType "image" every URL becomes a markdown image, one per line.
// For "video" only the first URL is used, wrapped in a fenced HTML <video>
// snippet; an empty URL list yields "". Any other mediaType yields "".
func buildSoraContent(mediaType string, urls []string) string {
	switch mediaType {
	case "image":
		parts := make([]string, 0, len(urls))
		for _, u := range urls {
			// The format string must contain a verb for u; an empty format
			// would render as "%!(EXTRA string=...)" instead of the URL.
			parts = append(parts, fmt.Sprintf("![image](%s)", u))
		}
		return strings.Join(parts, "\n")
	case "video":
		if len(urls) == 0 {
			return ""
		}
		return fmt.Sprintf("```html\n<video src='%s' controls></video>\n```", urls[0])
	default:
		return ""
	}
}
|
||||
|
||||
func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) {
|
||||
if body == nil {
|
||||
return "", "", "", ""
|
||||
}
|
||||
if v, ok := body["remix_target_id"].(string); ok {
|
||||
remixTargetID = v
|
||||
}
|
||||
if v, ok := body["image"].(string); ok {
|
||||
imageInput = v
|
||||
}
|
||||
if v, ok := body["video"].(string); ok {
|
||||
videoInput = v
|
||||
}
|
||||
if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
prompt = v
|
||||
}
|
||||
if messages, ok := body["messages"].([]any); ok {
|
||||
builder := strings.Builder{}
|
||||
for _, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := msg["role"].(string)
|
||||
if role != "" && role != "user" {
|
||||
continue
|
||||
}
|
||||
content := msg["content"]
|
||||
text, img, vid := parseSoraMessageContent(content)
|
||||
if text != "" {
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString(text)
|
||||
}
|
||||
if imageInput == "" && img != "" {
|
||||
imageInput = img
|
||||
}
|
||||
if videoInput == "" && vid != "" {
|
||||
videoInput = vid
|
||||
}
|
||||
}
|
||||
if prompt == "" {
|
||||
prompt = builder.String()
|
||||
}
|
||||
}
|
||||
return prompt, imageInput, videoInput, remixTargetID
|
||||
}
|
||||
|
||||
// parseSoraMessageContent extracts text and media references from a single
// message "content" value.
//
// A string is returned verbatim as text. A list is scanned for parts keyed by
// "type": non-blank "text" parts are joined with newlines; the first
// "image_url" / "video_url" part that carries a usable URL supplies
// imageInput / videoInput. A URL may be given either directly as a string or
// as an object with a string "url" field. Any other content type yields three
// empty strings.
//
// A part whose URL is absent or not a string is ignored instead of being
// stringified, so a malformed part never produces a bogus input such as
// "<nil>" and does not prevent a later well-formed part from being used.
func parseSoraMessageContent(content any) (text, imageInput, videoInput string) {
	// urlOf returns the URL held by a *_url field, or "" when there is none.
	urlOf := func(field any) string {
		switch v := field.(type) {
		case string:
			return v
		case map[string]any:
			u, _ := v["url"].(string)
			return u
		default:
			return ""
		}
	}

	switch val := content.(type) {
	case string:
		return val, "", ""
	case []any:
		var builder strings.Builder
		for _, item := range val {
			part, ok := item.(map[string]any)
			if !ok {
				continue
			}
			kind, _ := part["type"].(string)
			switch kind {
			case "text":
				if txt, ok := part["text"].(string); ok && strings.TrimSpace(txt) != "" {
					if builder.Len() > 0 {
						builder.WriteString("\n")
					}
					builder.WriteString(txt)
				}
			case "image_url":
				if imageInput == "" {
					imageInput = urlOf(part["image_url"])
				}
			case "video_url":
				if videoInput == "" {
					videoInput = urlOf(part["video_url"])
				}
			}
		}
		return builder.String(), imageInput, videoInput
	default:
		return "", "", ""
	}
}
|
||||
|
||||
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
|
||||
raw := strings.TrimSpace(input)
|
||||
if raw == "" {
|
||||
return nil, "", errors.New("empty image input")
|
||||
}
|
||||
if strings.HasPrefix(raw, "data:") {
|
||||
parts := strings.SplitN(raw, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, "", errors.New("invalid data url")
|
||||
}
|
||||
meta := parts[0]
|
||||
payload := parts[1]
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ext := ""
|
||||
if strings.HasPrefix(meta, "data:") {
|
||||
metaParts := strings.SplitN(meta[5:], ";", 2)
|
||||
if len(metaParts) > 0 {
|
||||
if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 {
|
||||
ext = exts[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
filename := "image" + ext
|
||||
return decoded, filename, nil
|
||||
}
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
return downloadSoraImageInput(ctx, raw)
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("invalid base64 image")
|
||||
}
|
||||
return decoded, "image.png", nil
|
||||
}
|
||||
|
||||
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ext := fileExtFromURL(rawURL)
|
||||
if ext == "" {
|
||||
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
filename := "image" + ext
|
||||
return data, filename, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user