feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径
- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力 - 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程 - 强化 Sora token 恢复、转发日志与网关路由隔离行为 - 补充后端服务层与路由层相关测试覆盖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,12 +17,15 @@ import (
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/crypto/sha3"
|
||||
@@ -34,6 +37,11 @@ const (
|
||||
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
|
||||
)
|
||||
|
||||
var (
|
||||
soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||
soraOAuthTokenURL = "https://auth.openai.com/oauth/token"
|
||||
)
|
||||
|
||||
const (
|
||||
soraPowMaxIteration = 500000
|
||||
)
|
||||
@@ -96,6 +104,7 @@ type SoraClient interface {
|
||||
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
|
||||
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
|
||||
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
|
||||
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
|
||||
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
|
||||
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
|
||||
}
|
||||
@@ -157,26 +166,94 @@ func (e *SoraUpstreamError) Error() string {
|
||||
|
||||
// SoraDirectClient 直连 Sora 实现
|
||||
type SoraDirectClient struct {
|
||||
cfg *config.Config
|
||||
httpUpstream HTTPUpstream
|
||||
tokenProvider *OpenAITokenProvider
|
||||
cfg *config.Config
|
||||
httpUpstream HTTPUpstream
|
||||
tokenProvider *OpenAITokenProvider
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewSoraDirectClient 创建 Sora 直连客户端
|
||||
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
|
||||
baseURL := ""
|
||||
if cfg != nil {
|
||||
rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/")
|
||||
baseURL = normalizeSoraBaseURL(rawBaseURL)
|
||||
if rawBaseURL != "" && baseURL != rawBaseURL {
|
||||
log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL))
|
||||
}
|
||||
}
|
||||
return &SoraDirectClient{
|
||||
cfg: cfg,
|
||||
httpUpstream: httpUpstream,
|
||||
tokenProvider: tokenProvider,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.accountRepo = accountRepo
|
||||
c.soraAccountRepo = soraAccountRepo
|
||||
}
|
||||
|
||||
// Enabled 判断是否启用 Sora 直连
|
||||
func (c *SoraDirectClient) Enabled() bool {
|
||||
if c == nil || c.cfg == nil {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
|
||||
if strings.TrimSpace(c.baseURL) != "" {
|
||||
return true
|
||||
}
|
||||
if c.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != ""
|
||||
}
|
||||
|
||||
// PreflightCheck 在创建任务前执行账号能力预检。
|
||||
// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。
|
||||
func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
|
||||
if modelCfg.Type != "video" {
|
||||
return nil
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||
headers.Set("Accept", "application/json")
|
||||
body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
|
||||
if err != nil {
|
||||
var upstreamErr *SoraUpstreamError
|
||||
if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "当前账号未开通 Sora2 能力或无可用配额",
|
||||
Headers: upstreamErr.Headers,
|
||||
Body: upstreamErr.Body,
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool()
|
||||
remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining")
|
||||
if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) {
|
||||
msg := "当前账号 Sora2 可用配额不足"
|
||||
if requestedModel != "" {
|
||||
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
|
||||
}
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Message: msg,
|
||||
Headers: http.Header{},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
||||
@@ -347,6 +424,45 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(expansionLevel) == "" {
|
||||
expansionLevel = "medium"
|
||||
}
|
||||
if durationS <= 0 {
|
||||
durationS = 10
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"prompt": prompt,
|
||||
"expansion_level": expansionLevel,
|
||||
"duration_s": durationS,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("Accept", "application/json")
|
||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||
|
||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String())
|
||||
if enhancedPrompt == "" {
|
||||
return "", errors.New("enhance_prompt response missing enhanced_prompt")
|
||||
}
|
||||
return enhancedPrompt, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
|
||||
if err != nil {
|
||||
@@ -512,9 +628,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) buildURL(endpoint string) string {
|
||||
base := ""
|
||||
if c != nil && c.cfg != nil {
|
||||
base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/")
|
||||
base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/")
|
||||
if base == "" && c != nil && c.cfg != nil {
|
||||
base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)
|
||||
c.baseURL = base
|
||||
}
|
||||
if base == "" {
|
||||
return endpoint
|
||||
@@ -540,14 +657,257 @@ func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account)
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if c.tokenProvider != nil {
|
||||
return c.tokenProvider.GetAccessToken(ctx, account)
|
||||
|
||||
allowProvider := c.allowOpenAITokenProvider(account)
|
||||
var providerErr error
|
||||
if allowProvider && c.tokenProvider != nil {
|
||||
token, err := c.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err == nil && strings.TrimSpace(token) != "" {
|
||||
c.logTokenSource(account, "openai_token_provider")
|
||||
return token, nil
|
||||
}
|
||||
providerErr = err
|
||||
if err != nil && c.debugEnabled() {
|
||||
c.debugLogf(
|
||||
"token_provider_failed account_id=%d platform=%s err=%s",
|
||||
account.ID,
|
||||
account.Platform,
|
||||
logredact.RedactText(err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if token == "" {
|
||||
return "", errors.New("access_token not found")
|
||||
if token != "" {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
|
||||
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
|
||||
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
|
||||
c.logTokenSource(account, "refresh_token_recovered")
|
||||
return refreshed, nil
|
||||
}
|
||||
if refreshErr != nil && c.debugEnabled() {
|
||||
c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error()))
|
||||
}
|
||||
}
|
||||
c.logTokenSource(account, "account_credentials")
|
||||
return token, nil
|
||||
}
|
||||
return token, nil
|
||||
|
||||
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
|
||||
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
|
||||
c.logTokenSource(account, "session_or_refresh_recovered")
|
||||
return recovered, nil
|
||||
}
|
||||
if recoverErr != nil && c.debugEnabled() {
|
||||
c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error()))
|
||||
}
|
||||
if providerErr != nil {
|
||||
return "", providerErr
|
||||
}
|
||||
if c.tokenProvider != nil && !allowProvider {
|
||||
c.logTokenSource(account, "account_credentials(provider_disabled)")
|
||||
}
|
||||
return "", errors.New("access_token not found")
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
|
||||
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
|
||||
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
|
||||
if err == nil && strings.TrimSpace(accessToken) != "" {
|
||||
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
|
||||
c.logTokenRecover(account, "session_token", reason, true, nil)
|
||||
return accessToken, nil
|
||||
}
|
||||
c.logTokenRecover(account, "session_token", reason, false, err)
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
||||
if refreshToken == "" {
|
||||
return "", errors.New("session_token/refresh_token not found")
|
||||
}
|
||||
accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken)
|
||||
if err != nil {
|
||||
c.logTokenRecover(account, "refresh_token", reason, false, err)
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("refreshed access_token is empty")
|
||||
}
|
||||
c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "")
|
||||
c.logTokenRecover(account, "refresh_token", reason, true, nil)
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
|
||||
headers := http.Header{}
|
||||
headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
|
||||
headers.Set("Accept", "application/json")
|
||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||
headers.Set("User-Agent", c.defaultUserAgent())
|
||||
body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("session exchange missing accessToken")
|
||||
}
|
||||
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
|
||||
return accessToken, expiresAt, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) {
|
||||
clientIDs := []string{
|
||||
strings.TrimSpace(account.GetCredential("client_id")),
|
||||
openaioauth.SoraClientID,
|
||||
openaioauth.ClientID,
|
||||
}
|
||||
tried := make(map[string]struct{}, len(clientIDs))
|
||||
var lastErr error
|
||||
|
||||
for _, clientID := range clientIDs {
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := tried[clientID]; ok {
|
||||
continue
|
||||
}
|
||||
tried[clientID] = struct{}{}
|
||||
|
||||
payload := map[string]any{
|
||||
"client_id": clientID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
|
||||
}
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
headers := http.Header{}
|
||||
headers.Set("Accept", "application/json")
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("User-Agent", c.defaultUserAgent())
|
||||
|
||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
continue
|
||||
}
|
||||
accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String())
|
||||
if accessToken == "" {
|
||||
lastErr = errors.New("oauth refresh response missing access_token")
|
||||
continue
|
||||
}
|
||||
newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String())
|
||||
expiresIn := gjson.GetBytes(respBody, "expires_in").Int()
|
||||
expiresAt := ""
|
||||
if expiresIn > 0 {
|
||||
expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
|
||||
}
|
||||
return accessToken, newRefreshToken, expiresAt, nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return "", "", "", lastErr
|
||||
}
|
||||
return "", "", "", errors.New("no available client_id for refresh_token exchange")
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
if strings.TrimSpace(accessToken) != "" {
|
||||
account.Credentials["access_token"] = accessToken
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) != "" {
|
||||
account.Credentials["refresh_token"] = refreshToken
|
||||
}
|
||||
if strings.TrimSpace(expiresAt) != "" {
|
||||
account.Credentials["expires_at"] = expiresAt
|
||||
}
|
||||
if strings.TrimSpace(sessionToken) != "" {
|
||||
account.Credentials["session_token"] = sessionToken
|
||||
}
|
||||
|
||||
if c.accountRepo != nil {
|
||||
if err := c.accountRepo.Update(ctx, account); err != nil {
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
|
||||
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
updates := make(map[string]any)
|
||||
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
|
||||
updates["access_token"] = accessToken
|
||||
updates["refresh_token"] = refreshToken
|
||||
}
|
||||
if strings.TrimSpace(sessionToken) != "" {
|
||||
updates["session_token"] = sessionToken
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
|
||||
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) {
|
||||
if !c.debugEnabled() || account == nil {
|
||||
return
|
||||
}
|
||||
if success {
|
||||
c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
|
||||
return
|
||||
}
|
||||
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error()))
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool {
|
||||
if c == nil || c.tokenProvider == nil {
|
||||
return false
|
||||
}
|
||||
if account != nil && account.Platform == PlatformSora {
|
||||
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) logTokenSource(account *Account, source string) {
|
||||
if !c.debugEnabled() || account == nil {
|
||||
return
|
||||
}
|
||||
c.debugLogf(
|
||||
"token_selected account_id=%d platform=%s account_type=%s source=%s",
|
||||
account.ID,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
source,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
|
||||
@@ -600,7 +960,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
||||
}
|
||||
|
||||
attempts := maxRetries + 1
|
||||
authRecovered := false
|
||||
authRecoverExtraAttemptGranted := false
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= attempts; attempt++ {
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf(
|
||||
"request_start method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t headers=%s",
|
||||
method,
|
||||
sanitizeSoraLogURL(urlStr),
|
||||
attempt,
|
||||
attempts,
|
||||
timeout,
|
||||
len(bodyBytes),
|
||||
account != nil && account.ProxyID != nil && account.Proxy != nil,
|
||||
formatSoraHeaders(headers),
|
||||
)
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
if bodyBytes != nil {
|
||||
reader = bytes.NewReader(bodyBytes)
|
||||
@@ -618,7 +995,21 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
||||
}
|
||||
resp, err := c.doHTTP(req, proxyURL, account)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf(
|
||||
"request_transport_error method=%s url=%s attempt=%d/%d err=%s",
|
||||
method,
|
||||
sanitizeSoraLogURL(urlStr),
|
||||
attempt,
|
||||
attempts,
|
||||
logredact.RedactText(err.Error()),
|
||||
)
|
||||
}
|
||||
if attempt < attempts && allowRetry {
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("request_retry_scheduled method=%s url=%s reason=transport_error next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), attempt+1, attempts)
|
||||
}
|
||||
c.sleepRetry(attempt)
|
||||
continue
|
||||
}
|
||||
@@ -632,12 +1023,53 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
||||
}
|
||||
|
||||
if c.cfg != nil && c.cfg.Sora.Client.Debug {
|
||||
log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start))
|
||||
c.debugLogf(
|
||||
"response_received method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s",
|
||||
method,
|
||||
sanitizeSoraLogURL(urlStr),
|
||||
attempt,
|
||||
attempts,
|
||||
resp.StatusCode,
|
||||
time.Since(start),
|
||||
len(respBody),
|
||||
formatSoraHeaders(resp.Header),
|
||||
)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody)
|
||||
if !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil {
|
||||
if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" {
|
||||
headers.Set("Authorization", "Bearer "+recovered)
|
||||
authRecovered = true
|
||||
if attempt == attempts && !authRecoverExtraAttemptGranted {
|
||||
attempts++
|
||||
authRecoverExtraAttemptGranted = true
|
||||
}
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("request_retry_with_recovered_token method=%s url=%s status=%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode)
|
||||
}
|
||||
continue
|
||||
} else if recoverErr != nil && c.debugEnabled() {
|
||||
c.debugLogf("request_recover_token_failed method=%s url=%s status=%d err=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error()))
|
||||
}
|
||||
}
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf(
|
||||
"response_non_success method=%s url=%s attempt=%d/%d status=%d body=%s",
|
||||
method,
|
||||
sanitizeSoraLogURL(urlStr),
|
||||
attempt,
|
||||
attempts,
|
||||
resp.StatusCode,
|
||||
summarizeSoraResponseBody(respBody, 512),
|
||||
)
|
||||
}
|
||||
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr)
|
||||
lastErr = upstreamErr
|
||||
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("request_retry_scheduled method=%s url=%s reason=status_%d next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
|
||||
}
|
||||
c.sleepRetry(attempt)
|
||||
continue
|
||||
}
|
||||
@@ -645,9 +1077,34 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
||||
}
|
||||
return respBody, resp.Header, nil
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
return nil, nil, errors.New("upstream retries exhausted")
|
||||
}
|
||||
|
||||
func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
|
||||
switch statusCode {
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
parsed, err := url.Parse(strings.TrimSpace(rawURL))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||
return false
|
||||
}
|
||||
// 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。
|
||||
path := strings.ToLower(strings.TrimSpace(parsed.Path))
|
||||
if path == "/api/auth/session" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||
enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
|
||||
if c.httpUpstream != nil {
|
||||
@@ -670,9 +1127,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) {
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error {
|
||||
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error {
|
||||
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
msg = sanitizeUpstreamErrorMessage(msg)
|
||||
if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") {
|
||||
if hint := soraBaseURLNotFoundHint(requestURL); hint != "" {
|
||||
msg = strings.TrimSpace(msg + " " + hint)
|
||||
}
|
||||
}
|
||||
if msg == "" {
|
||||
msg = truncateForLog(body, 256)
|
||||
}
|
||||
@@ -684,6 +1146,45 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSoraBaseURL(raw string) string {
|
||||
trimmed := strings.TrimRight(strings.TrimSpace(raw), "/")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return trimmed
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||
return trimmed
|
||||
}
|
||||
pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/")
|
||||
switch pathVal {
|
||||
case "", "/":
|
||||
parsed.Path = "/backend"
|
||||
case "/backend-api":
|
||||
parsed.Path = "/backend"
|
||||
}
|
||||
return strings.TrimRight(parsed.String(), "/")
|
||||
}
|
||||
|
||||
func soraBaseURLNotFoundHint(requestURL string) string {
|
||||
parsed, err := url.Parse(strings.TrimSpace(requestURL))
|
||||
if err != nil || parsed.Host == "" {
|
||||
return ""
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||
return ""
|
||||
}
|
||||
pathVal := strings.TrimSpace(parsed.Path)
|
||||
if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" {
|
||||
return ""
|
||||
}
|
||||
return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||
reqID := uuid.NewString()
|
||||
userAgent := soraRandChoice(soraDesktopUserAgents)
|
||||
@@ -901,3 +1402,70 @@ func sanitizeSoraLogURL(raw string) string {
|
||||
parsed.RawQuery = q.Encode()
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) debugEnabled() bool {
|
||||
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) debugLogf(format string, args ...any) {
|
||||
if !c.debugEnabled() {
|
||||
return
|
||||
}
|
||||
log.Printf("[SoraClient] "+format, args...)
|
||||
}
|
||||
|
||||
func formatSoraHeaders(headers http.Header) string {
|
||||
if len(headers) == 0 {
|
||||
return "{}"
|
||||
}
|
||||
keys := make([]string, 0, len(headers))
|
||||
for key := range headers {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
values := headers.Values(key)
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
val := strings.Join(values, ",")
|
||||
if isSensitiveHeader(key) {
|
||||
out[key] = "***"
|
||||
continue
|
||||
}
|
||||
out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160)
|
||||
}
|
||||
encoded, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
func isSensitiveHeader(key string) bool {
|
||||
k := strings.ToLower(strings.TrimSpace(key))
|
||||
switch k {
|
||||
case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeSoraResponseBody(body []byte, maxLen int) string {
|
||||
if len(body) == 0 {
|
||||
return ""
|
||||
}
|
||||
var text string
|
||||
if json.Valid(body) {
|
||||
text = logredact.RedactJSON(body)
|
||||
} else {
|
||||
text = logredact.RedactText(string(body))
|
||||
}
|
||||
text = strings.TrimSpace(text)
|
||||
if maxLen <= 0 || len(text) <= maxLen {
|
||||
return text
|
||||
}
|
||||
return text[:maxLen] + "...(truncated)"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user