fix(sora): 增强 Cloudflare 挑战识别并收敛 Sora 请求链路
- 在 failover 场景透传上游响应头并识别 Cloudflare challenge/cf-ray - 统一 Sora 任务请求的 UA 与代理使用,sentinel 与业务请求保持一致 - 修复流式错误事件 JSON 转义问题并补充相关单元测试
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -28,6 +29,8 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var soraCloudflareRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
|
||||||
|
|
||||||
// SoraGatewayHandler handles Sora chat completions requests
|
// SoraGatewayHandler handles Sora chat completions requests
|
||||||
type SoraGatewayHandler struct {
|
type SoraGatewayHandler struct {
|
||||||
gatewayService *service.GatewayService
|
gatewayService *service.GatewayService
|
||||||
@@ -214,6 +217,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
var lastFailoverBody []byte
|
var lastFailoverBody []byte
|
||||||
|
var lastFailoverHeaders http.Header
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||||
@@ -226,7 +230,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
@@ -289,11 +293,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
lastFailoverHeaders = failoverErr.ResponseHeaders
|
||||||
lastFailoverBody = failoverErr.ResponseBody
|
lastFailoverBody = failoverErr.ResponseBody
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
lastFailoverHeaders = failoverErr.ResponseHeaders
|
||||||
lastFailoverBody = failoverErr.ResponseBody
|
lastFailoverBody = failoverErr.ResponseBody
|
||||||
switchCount++
|
switchCount++
|
||||||
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
||||||
@@ -367,14 +373,19 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
|||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseBody []byte, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseBody)
|
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byte) (int, string, string) {
|
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
|
||||||
|
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
|
||||||
|
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
|
||||||
|
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||||
|
}
|
||||||
|
|
||||||
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||||
if upstreamMessage != "" {
|
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401, 403, 404, 500, 502, 503, 504:
|
case 401, 403, 404, 500, 502, 503, 504:
|
||||||
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
||||||
@@ -404,6 +415,71 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byt
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||||
|
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
|
||||||
|
if strings.Contains(preview, "window._cf_chl_opt") ||
|
||||||
|
strings.Contains(preview, "just a moment") ||
|
||||||
|
strings.Contains(preview, "enable javascript and cookies to continue") ||
|
||||||
|
strings.Contains(preview, "__cf_chl_") ||
|
||||||
|
strings.Contains(preview, "challenge-platform") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
contentType := strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
|
||||||
|
if strings.Contains(contentType, "text/html") &&
|
||||||
|
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
|
||||||
|
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||||
|
message = strings.TrimSpace(message)
|
||||||
|
if message == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
|
||||||
|
lower := strings.ToLower(message)
|
||||||
|
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||||
|
rayID := extractSoraCloudflareRayID(headers, body)
|
||||||
|
if rayID == "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractSoraCloudflareRayID(headers http.Header, body []byte) string {
|
||||||
|
if headers != nil {
|
||||||
|
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||||
|
if rayID != "" {
|
||||||
|
return rayID
|
||||||
|
}
|
||||||
|
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||||
|
if rayID != "" {
|
||||||
|
return rayID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
preview := truncateSoraErrorBody(body, 8192)
|
||||||
|
matches := soraCloudflareRayPattern.FindStringSubmatch(preview)
|
||||||
|
if len(matches) >= 2 {
|
||||||
|
return strings.TrimSpace(matches[1])
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||||
trimmed := strings.TrimSpace(string(body))
|
trimmed := strings.TrimSpace(string(body))
|
||||||
if trimmed == "" {
|
if trimmed == "" {
|
||||||
@@ -439,6 +515,17 @@ func truncateSoraErrorMessage(s string, maxLen int) string {
|
|||||||
return s[:maxLen] + "...(truncated)"
|
return s[:maxLen] + "...(truncated)"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func truncateSoraErrorBody(body []byte, maxLen int) string {
|
||||||
|
if maxLen <= 0 {
|
||||||
|
maxLen = 512
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(string(body))
|
||||||
|
if len(raw) <= maxLen {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return raw[:maxLen] + "...(truncated)"
|
||||||
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||||
if streamStarted {
|
if streamStarted {
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
|||||||
@@ -561,7 +561,7 @@ func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
|||||||
|
|
||||||
h := &SoraGatewayHandler{}
|
h := &SoraGatewayHandler{}
|
||||||
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
||||||
h.handleFailoverExhausted(c, http.StatusBadGateway, resp, true)
|
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
|
||||||
|
|
||||||
body := w.Body.String()
|
body := w.Body.String()
|
||||||
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
||||||
@@ -579,3 +579,31 @@ func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
|||||||
require.Equal(t, "upstream_error", errorObj["type"])
|
require.Equal(t, "upstream_error", errorObj["type"])
|
||||||
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||||
|
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
|
||||||
|
|
||||||
|
h := &SoraGatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||||
|
require.Len(t, lines, 2)
|
||||||
|
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
msg, _ := errorObj["message"].(string)
|
||||||
|
require.Contains(t, msg, "Cloudflare challenge")
|
||||||
|
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -522,6 +523,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
if isCloudflareChallengeResponse(resp.StatusCode, body) {
|
if isCloudflareChallengeResponse(resp.StatusCode, body) {
|
||||||
|
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
|
||||||
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
|
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
|
||||||
}
|
}
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
||||||
@@ -567,6 +569,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
|
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
|
||||||
|
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
|
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
|
||||||
} else {
|
} else {
|
||||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
||||||
@@ -824,6 +827,75 @@ func extractCloudflareRayID(headers http.Header, body []byte) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractSoraEgressIPHint(headers http.Header) string {
|
||||||
|
if headers == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
candidates := []string{
|
||||||
|
"x-openai-public-ip",
|
||||||
|
"x-envoy-external-address",
|
||||||
|
"cf-connecting-ip",
|
||||||
|
"x-forwarded-for",
|
||||||
|
}
|
||||||
|
for _, key := range candidates {
|
||||||
|
if value := strings.TrimSpace(headers.Get(key)); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeProxyURLForLog(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return "<invalid_proxy_url>"
|
||||||
|
}
|
||||||
|
if u.User != nil {
|
||||||
|
u.User = nil
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func endpointPathForLog(endpoint string) string {
|
||||||
|
parsed, err := url.Parse(strings.TrimSpace(endpoint))
|
||||||
|
if err != nil || parsed.Path == "" {
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
|
return parsed.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
|
||||||
|
accountID := int64(0)
|
||||||
|
platform := ""
|
||||||
|
proxyID := "none"
|
||||||
|
if account != nil {
|
||||||
|
accountID = account.ID
|
||||||
|
platform = account.Platform
|
||||||
|
if account.ProxyID != nil {
|
||||||
|
proxyID = fmt.Sprintf("%d", *account.ProxyID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfRay := extractCloudflareRayID(headers, body)
|
||||||
|
if cfRay == "" {
|
||||||
|
cfRay = "unknown"
|
||||||
|
}
|
||||||
|
log.Printf(
|
||||||
|
"[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
|
||||||
|
accountID,
|
||||||
|
platform,
|
||||||
|
endpoint,
|
||||||
|
endpointPathForLog(endpoint),
|
||||||
|
proxyID,
|
||||||
|
sanitizeProxyURLForLog(proxyURL),
|
||||||
|
cfRay,
|
||||||
|
extractSoraEgressIPHint(headers),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func truncateSoraErrorBody(body []byte, max int) string {
|
func truncateSoraErrorBody(body []byte, max int) string {
|
||||||
if max <= 0 {
|
if max <= 0 {
|
||||||
max = 512
|
max = 512
|
||||||
|
|||||||
@@ -202,3 +202,22 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChal
|
|||||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeProxyURLForLog(t *testing.T) {
|
||||||
|
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
|
||||||
|
require.Equal(t, "", sanitizeProxyURLForLog(""))
|
||||||
|
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSoraEgressIPHint(t *testing.T) {
|
||||||
|
h := make(http.Header)
|
||||||
|
h.Set("x-openai-public-ip", "203.0.113.10")
|
||||||
|
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
|
||||||
|
|
||||||
|
h2 := make(http.Header)
|
||||||
|
h2.Set("x-envoy-external-address", "198.51.100.9")
|
||||||
|
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
|
||||||
|
|
||||||
|
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
|
||||||
|
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
|
||||||
|
}
|
||||||
|
|||||||
@@ -376,8 +376,9 @@ type ForwardResult struct {
|
|||||||
type UpstreamFailoverError struct {
|
type UpstreamFailoverError struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
ResponseHeaders http.Header
|
||||||
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||||
|
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *UpstreamFailoverError) Error() string {
|
func (e *UpstreamFailoverError) Error() string {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@@ -97,6 +98,7 @@ var soraDesktopUserAgents = []string{
|
|||||||
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
var soraRandMu sync.Mutex
|
var soraRandMu sync.Mutex
|
||||||
var soraPerfStart = time.Now()
|
var soraPerfStart = time.Now()
|
||||||
|
var soraPowTokenGenerator = soraGetPowToken
|
||||||
|
|
||||||
// SoraClient 定义直连 Sora 的任务操作接口。
|
// SoraClient 定义直连 Sora 的任务操作接口。
|
||||||
type SoraClient interface {
|
type SoraClient interface {
|
||||||
@@ -224,9 +226,11 @@ func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
userAgent := c.taskUserAgent()
|
||||||
|
proxyURL := c.resolveProxyURL(account)
|
||||||
|
headers := c.buildBaseHeaders(token, userAgent)
|
||||||
headers.Set("Accept", "application/json")
|
headers.Set("Accept", "application/json")
|
||||||
body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
|
body, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var upstreamErr *SoraUpstreamError
|
var upstreamErr *SoraUpstreamError
|
||||||
if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
|
if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
|
||||||
@@ -264,6 +268,8 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
userAgent := c.taskUserAgent()
|
||||||
|
proxyURL := c.resolveProxyURL(account)
|
||||||
if filename == "" {
|
if filename == "" {
|
||||||
filename = "image.png"
|
filename = "image.png"
|
||||||
}
|
}
|
||||||
@@ -290,10 +296,10 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
headers := c.buildBaseHeaders(token, userAgent)
|
||||||
headers.Set("Content-Type", writer.FormDataContentType())
|
headers.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
|
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -309,6 +315,8 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
userAgent := c.taskUserAgent()
|
||||||
|
proxyURL := c.resolveProxyURL(account)
|
||||||
operation := "simple_compose"
|
operation := "simple_compose"
|
||||||
inpaintItems := []map[string]any{}
|
inpaintItems := []map[string]any{}
|
||||||
if strings.TrimSpace(req.MediaID) != "" {
|
if strings.TrimSpace(req.MediaID) != "" {
|
||||||
@@ -329,7 +337,7 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
|
|||||||
"n_frames": 1,
|
"n_frames": 1,
|
||||||
"inpaint_items": inpaintItems,
|
"inpaint_items": inpaintItems,
|
||||||
}
|
}
|
||||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
headers := c.buildBaseHeaders(token, userAgent)
|
||||||
headers.Set("Content-Type", "application/json")
|
headers.Set("Content-Type", "application/json")
|
||||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
@@ -338,13 +346,13 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
sentinel, err := c.generateSentinelToken(ctx, account, token)
|
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
headers.Set("openai-sentinel-token", sentinel)
|
headers.Set("openai-sentinel-token", sentinel)
|
||||||
|
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
|
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -360,6 +368,8 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
userAgent := c.taskUserAgent()
|
||||||
|
proxyURL := c.resolveProxyURL(account)
|
||||||
orientation := req.Orientation
|
orientation := req.Orientation
|
||||||
if orientation == "" {
|
if orientation == "" {
|
||||||
orientation = "landscape"
|
orientation = "landscape"
|
||||||
@@ -399,7 +409,7 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
|||||||
payload["cameo_replacements"] = map[string]any{}
|
payload["cameo_replacements"] = map[string]any{}
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
headers := c.buildBaseHeaders(token, userAgent)
|
||||||
headers.Set("Content-Type", "application/json")
|
headers.Set("Content-Type", "application/json")
|
||||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
@@ -407,13 +417,13 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
sentinel, err := c.generateSentinelToken(ctx, account, token)
|
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
headers.Set("openai-sentinel-token", sentinel)
|
headers.Set("openai-sentinel-token", sentinel)
|
||||||
|
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
|
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -429,6 +439,8 @@ func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
userAgent := c.taskUserAgent()
|
||||||
|
proxyURL := c.resolveProxyURL(account)
|
||||||
if strings.TrimSpace(expansionLevel) == "" {
|
if strings.TrimSpace(expansionLevel) == "" {
|
||||||
expansionLevel = "medium"
|
expansionLevel = "medium"
|
||||||
}
|
}
|
||||||
@@ -446,13 +458,13 @@ func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account,
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
headers := c.buildBaseHeaders(token, userAgent)
|
||||||
headers.Set("Content-Type", "application/json")
|
headers.Set("Content-Type", "application/json")
|
||||||
headers.Set("Accept", "application/json")
|
headers.Set("Accept", "application/json")
|
||||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
|
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -489,12 +501,14 @@ func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Ac
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
userAgent := c.taskUserAgent()
|
||||||
|
proxyURL := c.resolveProxyURL(account)
|
||||||
|
headers := c.buildBaseHeaders(token, userAgent)
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 20
|
limit = 20
|
||||||
}
|
}
|
||||||
endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit)
|
endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit)
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
|
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
@@ -551,9 +565,11 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
userAgent := c.taskUserAgent()
|
||||||
|
proxyURL := c.resolveProxyURL(account)
|
||||||
|
headers := c.buildBaseHeaders(token, userAgent)
|
||||||
|
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
|
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -582,7 +598,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
|
respBody, _, err = c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -653,6 +669,25 @@ func (c *SoraDirectClient) defaultUserAgent() string {
|
|||||||
return ua
|
return ua
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) taskUserAgent() string {
|
||||||
|
if c != nil && c.cfg != nil {
|
||||||
|
if ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent); ua != "" {
|
||||||
|
return ua
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(soraDesktopUserAgents) > 0 {
|
||||||
|
return soraDesktopUserAgents[0]
|
||||||
|
}
|
||||||
|
return soraDefaultUserAgent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) resolveProxyURL(account *Account) string {
|
||||||
|
if account == nil || account.ProxyID == nil || account.Proxy == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(account.Proxy.URL())
|
||||||
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
@@ -925,9 +960,26 @@ func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) {
|
func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) {
|
||||||
|
return c.doRequestWithProxy(ctx, account, c.resolveProxyURL(account), method, urlStr, headers, body, allowRetry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) doRequestWithProxy(
|
||||||
|
ctx context.Context,
|
||||||
|
account *Account,
|
||||||
|
proxyURL string,
|
||||||
|
method,
|
||||||
|
urlStr string,
|
||||||
|
headers http.Header,
|
||||||
|
body io.Reader,
|
||||||
|
allowRetry bool,
|
||||||
|
) ([]byte, http.Header, error) {
|
||||||
if strings.TrimSpace(urlStr) == "" {
|
if strings.TrimSpace(urlStr) == "" {
|
||||||
return nil, nil, errors.New("empty upstream url")
|
return nil, nil, errors.New("empty upstream url")
|
||||||
}
|
}
|
||||||
|
proxyURL = strings.TrimSpace(proxyURL)
|
||||||
|
if proxyURL == "" {
|
||||||
|
proxyURL = c.resolveProxyURL(account)
|
||||||
|
}
|
||||||
timeout := 0
|
timeout := 0
|
||||||
if c != nil && c.cfg != nil {
|
if c != nil && c.cfg != nil {
|
||||||
timeout = c.cfg.Sora.Client.TimeoutSeconds
|
timeout = c.cfg.Sora.Client.TimeoutSeconds
|
||||||
@@ -968,7 +1020,7 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
|||||||
attempts,
|
attempts,
|
||||||
timeout,
|
timeout,
|
||||||
len(bodyBytes),
|
len(bodyBytes),
|
||||||
account != nil && account.ProxyID != nil && account.Proxy != nil,
|
proxyURL != "",
|
||||||
formatSoraHeaders(headers),
|
formatSoraHeaders(headers),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -984,10 +1036,6 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
|||||||
req.Header = headers.Clone()
|
req.Header = headers.Clone()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
proxyURL := ""
|
|
||||||
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
resp, err := c.doHTTP(req, proxyURL, account)
|
resp, err := c.doHTTP(req, proxyURL, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
@@ -1183,10 +1231,13 @@ func soraBaseURLNotFoundHint(requestURL string) string {
|
|||||||
return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
|
return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
|
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken, userAgent, proxyURL string) (string, error) {
|
||||||
reqID := uuid.NewString()
|
reqID := uuid.NewString()
|
||||||
userAgent := soraRandChoice(soraDesktopUserAgents)
|
userAgent = strings.TrimSpace(userAgent)
|
||||||
powToken := soraGetPowToken(userAgent)
|
if userAgent == "" {
|
||||||
|
userAgent = c.taskUserAgent()
|
||||||
|
}
|
||||||
|
powToken := soraPowTokenGenerator(userAgent)
|
||||||
payload := map[string]any{
|
payload := map[string]any{
|
||||||
"p": powToken,
|
"p": powToken,
|
||||||
"flow": soraSentinelFlow,
|
"flow": soraSentinelFlow,
|
||||||
@@ -1207,7 +1258,7 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A
|
|||||||
}
|
}
|
||||||
|
|
||||||
urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req"
|
urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req"
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
|
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -1223,16 +1274,6 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A
|
|||||||
return sentinel, nil
|
return sentinel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func soraRandChoice(items []string) string {
|
|
||||||
if len(items) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
soraRandMu.Lock()
|
|
||||||
idx := soraRand.Intn(len(items))
|
|
||||||
soraRandMu.Unlock()
|
|
||||||
return items[idx]
|
|
||||||
}
|
|
||||||
|
|
||||||
func soraGetPowToken(userAgent string) string {
|
func soraGetPowToken(userAgent string) string {
|
||||||
configList := soraBuildPowConfig(userAgent)
|
configList := soraBuildPowConfig(userAgent)
|
||||||
seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64)
|
seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64)
|
||||||
@@ -1248,13 +1289,16 @@ func soraRandFloat() float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func soraBuildPowConfig(userAgent string) []any {
|
func soraBuildPowConfig(userAgent string) []any {
|
||||||
screen := soraRandChoice([]string{
|
userAgent = strings.TrimSpace(userAgent)
|
||||||
strconv.Itoa(1920 + 1080),
|
if userAgent == "" && len(soraDesktopUserAgents) > 0 {
|
||||||
strconv.Itoa(2560 + 1440),
|
userAgent = soraDesktopUserAgents[0]
|
||||||
strconv.Itoa(1920 + 1200),
|
}
|
||||||
strconv.Itoa(2560 + 1600),
|
screenVal := soraStableChoiceInt([]int{
|
||||||
})
|
1920 + 1080,
|
||||||
screenVal, _ := strconv.Atoi(screen)
|
2560 + 1440,
|
||||||
|
1920 + 1200,
|
||||||
|
2560 + 1600,
|
||||||
|
}, userAgent+"|screen")
|
||||||
perfMs := float64(time.Since(soraPerfStart).Milliseconds())
|
perfMs := float64(time.Since(soraPerfStart).Milliseconds())
|
||||||
wallMs := float64(time.Now().UnixNano()) / 1e6
|
wallMs := float64(time.Now().UnixNano()) / 1e6
|
||||||
diff := wallMs - perfMs
|
diff := wallMs - perfMs
|
||||||
@@ -1264,32 +1308,47 @@ func soraBuildPowConfig(userAgent string) []any {
|
|||||||
4294705152,
|
4294705152,
|
||||||
0,
|
0,
|
||||||
userAgent,
|
userAgent,
|
||||||
soraRandChoice(soraPowScripts),
|
soraStableChoice(soraPowScripts, userAgent+"|script"),
|
||||||
soraRandChoice(soraPowDPL),
|
soraStableChoice(soraPowDPL, userAgent+"|dpl"),
|
||||||
"en-US",
|
"en-US",
|
||||||
"en-US,es-US,en,es",
|
"en-US,es-US,en,es",
|
||||||
0,
|
0,
|
||||||
soraRandChoice(soraPowNavigatorKeys),
|
soraStableChoice(soraPowNavigatorKeys, userAgent+"|navigator"),
|
||||||
soraRandChoice(soraPowDocumentKeys),
|
soraStableChoice(soraPowDocumentKeys, userAgent+"|document"),
|
||||||
soraRandChoice(soraPowWindowKeys),
|
soraStableChoice(soraPowWindowKeys, userAgent+"|window"),
|
||||||
perfMs,
|
perfMs,
|
||||||
uuid.NewString(),
|
uuid.NewString(),
|
||||||
"",
|
"",
|
||||||
soraRandChoiceInt(soraPowCores),
|
soraStableChoiceInt(soraPowCores, userAgent+"|cores"),
|
||||||
diff,
|
diff,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func soraRandChoiceInt(items []int) int {
|
func soraStableChoice(items []string, seed string) string {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
idx := soraStableIndex(seed, len(items))
|
||||||
|
return items[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraStableChoiceInt(items []int, seed string) int {
|
||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
soraRandMu.Lock()
|
idx := soraStableIndex(seed, len(items))
|
||||||
idx := soraRand.Intn(len(items))
|
|
||||||
soraRandMu.Unlock()
|
|
||||||
return items[idx]
|
return items[idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func soraStableIndex(seed string, size int) int {
|
||||||
|
if size <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
h := fnv.New32a()
|
||||||
|
_, _ = h.Write([]byte(seed))
|
||||||
|
return int(h.Sum32() % uint32(size))
|
||||||
|
}
|
||||||
|
|
||||||
func soraPowParseTime() string {
|
func soraPowParseTime() string {
|
||||||
loc := time.FixedZone("EST", -5*3600)
|
loc := time.FixedZone("EST", -5*3600)
|
||||||
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)")
|
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)")
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -365,3 +367,166 @@ func TestShouldAttemptSoraTokenRecover(t *testing.T) {
|
|||||||
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
|
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
|
||||||
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
|
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type soraClientRequestCall struct {
|
||||||
|
Path string
|
||||||
|
UserAgent string
|
||||||
|
ProxyURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type soraClientRecordingUpstream struct {
|
||||||
|
calls []soraClientRequestCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *soraClientRecordingUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||||
|
return nil, errors.New("unexpected Do call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL string, _ int64, _ int, _ bool) (*http.Response, error) {
|
||||||
|
u.calls = append(u.calls, soraClientRequestCall{
|
||||||
|
Path: req.URL.Path,
|
||||||
|
UserAgent: req.Header.Get("User-Agent"),
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
})
|
||||||
|
switch req.URL.Path {
|
||||||
|
case "/backend-api/sentinel/req":
|
||||||
|
return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil
|
||||||
|
case "/backend/nf/create":
|
||||||
|
return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil
|
||||||
|
case "/backend/uploads":
|
||||||
|
return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil
|
||||||
|
case "/backend/nf/check":
|
||||||
|
return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil
|
||||||
|
default:
|
||||||
|
return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSoraClientMockResponse(statusCode int, body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_TaskUserAgent_DefaultDesktopFallback(t *testing.T) {
|
||||||
|
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||||
|
require.Equal(t, soraDesktopUserAgents[0], client.taskUserAgent())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) {
|
||||||
|
originPowTokenGenerator := soraPowTokenGenerator
|
||||||
|
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
|
||||||
|
defer func() {
|
||||||
|
soraPowTokenGenerator = originPowTokenGenerator
|
||||||
|
}()
|
||||||
|
|
||||||
|
upstream := &soraClientRecordingUpstream{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: "https://sora.chatgpt.com/backend",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, upstream, nil)
|
||||||
|
proxyID := int64(9)
|
||||||
|
account := &Account{
|
||||||
|
ID: 21,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
ProxyID: &proxyID,
|
||||||
|
Proxy: &Proxy{
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 8080,
|
||||||
|
},
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID, err := client.CreateVideoTask(context.Background(), account, SoraVideoRequest{Prompt: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "task-123", taskID)
|
||||||
|
require.Len(t, upstream.calls, 2)
|
||||||
|
|
||||||
|
sentinelCall := upstream.calls[0]
|
||||||
|
createCall := upstream.calls[1]
|
||||||
|
require.Equal(t, "/backend-api/sentinel/req", sentinelCall.Path)
|
||||||
|
require.Equal(t, "/backend/nf/create", createCall.Path)
|
||||||
|
require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL)
|
||||||
|
require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL)
|
||||||
|
require.Equal(t, soraDesktopUserAgents[0], sentinelCall.UserAgent)
|
||||||
|
require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) {
|
||||||
|
upstream := &soraClientRecordingUpstream{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: "https://sora.chatgpt.com/backend",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, upstream, nil)
|
||||||
|
proxyID := int64(3)
|
||||||
|
account := &Account{
|
||||||
|
ID: 31,
|
||||||
|
ProxyID: &proxyID,
|
||||||
|
Proxy: &Proxy{
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 8080,
|
||||||
|
},
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadID, err := client.UploadImage(context.Background(), account, []byte("mock-image"), "a.png")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "upload-123", uploadID)
|
||||||
|
require.Len(t, upstream.calls, 1)
|
||||||
|
require.Equal(t, "/backend/uploads", upstream.calls[0].Path)
|
||||||
|
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
|
||||||
|
require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) {
|
||||||
|
upstream := &soraClientRecordingUpstream{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: "https://sora.chatgpt.com/backend",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, upstream, nil)
|
||||||
|
proxyID := int64(7)
|
||||||
|
account := &Account{
|
||||||
|
ID: 41,
|
||||||
|
ProxyID: &proxyID,
|
||||||
|
Proxy: &Proxy{
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 8080,
|
||||||
|
},
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := client.PreflightCheck(context.Background(), account, "sora2", SoraModelConfig{Type: "video"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, upstream.calls, 1)
|
||||||
|
require.Equal(t, "/backend/nf/check", upstream.calls[0].Path)
|
||||||
|
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
|
||||||
|
require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
|
||||||
|
}
|
||||||
|
|||||||
@@ -468,7 +468,18 @@ func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType,
|
|||||||
}
|
}
|
||||||
if stream {
|
if stream {
|
||||||
flusher, _ := c.Writer.(http.Flusher)
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
errorData := map[string]any{
|
||||||
|
"error": map[string]string{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(errorData)
|
||||||
|
if err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||||
_, _ = fmt.Fprint(c.Writer, errorEvent)
|
_, _ = fmt.Fprint(c.Writer, errorEvent)
|
||||||
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||||||
if flusher != nil {
|
if flusher != nil {
|
||||||
@@ -494,7 +505,11 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
|
|||||||
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||||
}
|
}
|
||||||
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
||||||
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body}
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: upstreamErr.StatusCode,
|
||||||
|
ResponseBody: upstreamErr.Body,
|
||||||
|
ResponseHeaders: upstreamErr.Headers,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
msg := upstreamErr.Message
|
msg := upstreamErr.Message
|
||||||
if override := soraProErrorMessage(model, msg); override != "" {
|
if override := soraProErrorMessage(model, msg); override != "" {
|
||||||
|
|||||||
@@ -4,10 +4,15 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -210,6 +215,33 @@ func TestSoraProErrorMessage(t *testing.T) {
|
|||||||
require.Empty(t, soraProErrorMessage("sora-basic", ""))
|
require.Empty(t, soraProErrorMessage("sora-basic", ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||||
|
svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "event: error\n")
|
||||||
|
require.Contains(t, body, "data: [DONE]\n\n")
|
||||||
|
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
require.GreaterOrEqual(t, len(lines), 2)
|
||||||
|
require.Equal(t, "event: error", lines[0])
|
||||||
|
require.True(t, strings.HasPrefix(lines[1], "data: "))
|
||||||
|
|
||||||
|
data := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(data), &parsed))
|
||||||
|
errObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "upstream_error", errObj["type"])
|
||||||
|
require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestShouldFailoverUpstreamError(t *testing.T) {
|
func TestShouldFailoverUpstreamError(t *testing.T) {
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(401))
|
require.True(t, svc.shouldFailoverUpstreamError(401))
|
||||||
|
|||||||
Reference in New Issue
Block a user