feat(sora): 对齐sora2api分镜角色去水印与挑战错误治理

This commit is contained in:
yangjianbo
2026-02-19 20:04:10 +08:00
parent 440b87094a
commit 40498aac9d
12 changed files with 1994 additions and 202 deletions

View File

@@ -12,7 +12,6 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
@@ -22,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
@@ -29,8 +29,6 @@ import (
"go.uber.org/zap"
)
var soraCloudflareRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
@@ -385,6 +383,10 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
}
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if strings.EqualFold(upstreamCode, "cf_shield_429") {
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
@@ -416,27 +418,7 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
}
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
return false
}
if strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
return true
}
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
if strings.Contains(preview, "window._cf_chl_opt") ||
strings.Contains(preview, "just a moment") ||
strings.Contains(preview, "enable javascript and cookies to continue") ||
strings.Contains(preview, "__cf_chl_") ||
strings.Contains(preview, "challenge-platform") {
return true
}
contentType := strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
if strings.Contains(contentType, "text/html") &&
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
return true
}
return false
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
@@ -454,76 +436,11 @@ func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
}
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := extractSoraCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
func extractSoraCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := truncateSoraErrorBody(body, 8192)
matches := soraCloudflareRayPattern.FindStringSubmatch(preview)
if len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !gjson.Valid(trimmed) {
return "", truncateSoraErrorMessage(trimmed, 256)
}
code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
if code == "" {
code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
}
message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
}
return code, truncateSoraErrorMessage(message, 512)
}
func truncateSoraErrorMessage(s string, maxLen int) string {
if maxLen <= 0 {
return ""
}
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...(truncated)"
}
func truncateSoraErrorBody(body []byte, maxLen int) string {
if maxLen <= 0 {
maxLen = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= maxLen {
return raw
}
return raw[:maxLen] + "...(truncated)"
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {

View File

@@ -43,6 +43,45 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
return &service.SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
return nil
}
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
return nil
}
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
return "s_post", nil
}
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
return nil
}
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
@@ -607,3 +646,31 @@ func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "rate_limit_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare shield")
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
}