diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index b9f31ba9..8efcb550 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -162,6 +162,8 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
+ // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
+ SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
}
type PricingConfig struct {
@@ -269,17 +271,18 @@ type SoraConfig struct {
// SoraClientConfig 直连 Sora 客户端配置
type SoraClientConfig struct {
- BaseURL string `mapstructure:"base_url"`
- TimeoutSeconds int `mapstructure:"timeout_seconds"`
- MaxRetries int `mapstructure:"max_retries"`
- PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
- MaxPollAttempts int `mapstructure:"max_poll_attempts"`
- RecentTaskLimit int `mapstructure:"recent_task_limit"`
- RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
- Debug bool `mapstructure:"debug"`
- Headers map[string]string `mapstructure:"headers"`
- UserAgent string `mapstructure:"user_agent"`
- DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
+ BaseURL string `mapstructure:"base_url"`
+ TimeoutSeconds int `mapstructure:"timeout_seconds"`
+ MaxRetries int `mapstructure:"max_retries"`
+ PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
+ MaxPollAttempts int `mapstructure:"max_poll_attempts"`
+ RecentTaskLimit int `mapstructure:"recent_task_limit"`
+ RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
+ Debug bool `mapstructure:"debug"`
+ UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
+ Headers map[string]string `mapstructure:"headers"`
+ UserAgent string `mapstructure:"user_agent"`
+ DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
}
// SoraStorageConfig 媒体存储配置
@@ -1116,6 +1119,7 @@ func setDefaults() {
viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false)
+ viper.SetDefault("sora.client.use_openai_token_provider", false)
viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
@@ -1137,6 +1141,7 @@ func setDefaults() {
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
+ viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 7533c70e..79f90b8e 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -1333,6 +1333,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
+ // Handle Sora accounts
+ if account.Platform == service.PlatformSora {
+ response.Success(c, service.DefaultSoraModels(nil))
+ return
+ }
+
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go
index ed86fea9..cf43f89e 100644
--- a/backend/internal/handler/admin/openai_oauth_handler.go
+++ b/backend/internal/handler/admin/openai_oauth_handler.go
@@ -2,6 +2,7 @@ package admin
import (
"strconv"
+ "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService
}
+func oauthPlatformFromPath(c *gin.Context) string {
+ if strings.Contains(c.FullPath(), "/admin/sora/") {
+ return service.PlatformSora
+ }
+ return service.PlatformOpenAI
+}
+
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
+ State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
+ State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
- RefreshToken string `json:"refresh_token" binding:"required"`
+ RefreshToken string `json:"refresh_token"`
+ RT string `json:"rt"`
+ ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
+// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ refreshToken := strings.TrimSpace(req.RefreshToken)
+ if refreshToken == "" {
+ refreshToken = strings.TrimSpace(req.RT)
+ }
+ if refreshToken == "" {
+ response.BadRequest(c, "refresh_token is required")
+ return
+ }
var proxyURL string
if req.ProxyID != nil {
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
}
}
- tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
+ tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil {
response.ErrorFrom(c, err)
return
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo)
}
-// RefreshAccountToken refreshes token for a specific OpenAI account
+// ExchangeSoraSessionToken exchanges Sora session token to access token
+// POST /api/v1/admin/sora/st2at
+func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
+ var req struct {
+ SessionToken string `json:"session_token"`
+ ST string `json:"st"`
+ ProxyID *int64 `json:"proxy_id"`
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ sessionToken := strings.TrimSpace(req.SessionToken)
+ if sessionToken == "" {
+ sessionToken = strings.TrimSpace(req.ST)
+ }
+ if sessionToken == "" {
+ response.BadRequest(c, "session_token is required")
+ return
+ }
+
+ tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, tokenInfo)
+}
+
+// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh
+// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return
}
- // Ensure account is OpenAI platform
- if !account.IsOpenAI() {
- response.BadRequest(c, "Account is not an OpenAI account")
+ platform := oauthPlatformFromPath(c)
+ if account.Platform != platform {
+ response.BadRequest(c, "Account platform does not match OAuth endpoint")
return
}
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount))
}
-// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
+// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
+// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
+ State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
+ State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
+ platform := oauthPlatformFromPath(c)
+
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
- name = "OpenAI OAuth Account"
+ if platform == service.PlatformSora {
+ name = "Sora OAuth Account"
+ } else {
+ name = "OpenAI OAuth Account"
+ }
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
- Platform: "openai",
+ Platform: platform,
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
index 80932899..9c9f53b1 100644
--- a/backend/internal/handler/sora_gateway_handler.go
+++ b/backend/internal/handler/sora_gateway_handler.go
@@ -212,6 +212,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
+ var lastFailoverBody []byte
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
@@ -224,7 +225,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
- h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
return
}
account := selection.Account
@@ -287,14 +288,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
- h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
+ lastFailoverBody = failoverErr.ResponseBody
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
+ lastFailoverBody = failoverErr.ResponseBody
switchCount++
+ upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
reqLog.Warn("sora.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.String("upstream_error_code", upstreamErrCode),
+ zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
@@ -360,17 +366,32 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
-func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
- status, errType, errMsg := h.mapUpstreamError(statusCode)
+func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseBody []byte, streamStarted bool) {
+ status, errType, errMsg := h.mapUpstreamError(statusCode, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
-func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
+func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byte) (int, string, string) {
+ upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
+ if upstreamMessage != "" {
+ switch statusCode {
+ case 401, 403, 404, 500, 502, 503, 504:
+ return http.StatusBadGateway, "upstream_error", upstreamMessage
+ case 429:
+ return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
+ }
+ }
+
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
+ case 404:
+ if strings.EqualFold(upstreamCode, "unsupported_country_code") {
+ return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
+ }
+ return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
@@ -382,6 +403,41 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
}
}
+func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
+ trimmed := strings.TrimSpace(string(body))
+ if trimmed == "" {
+ return "", ""
+ }
+ if !gjson.Valid(trimmed) {
+ return "", truncateSoraErrorMessage(trimmed, 256)
+ }
+ code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
+ if code == "" {
+ code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
+ }
+ message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
+ if message == "" {
+ message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
+ }
+ if message == "" {
+ message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
+ }
+ if message == "" {
+ message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
+ }
+ return code, truncateSoraErrorMessage(message, 512)
+}
+
+func truncateSoraErrorMessage(s string, maxLen int) string {
+ if maxLen <= 0 {
+ return ""
+ }
+ if len(s) <= maxLen {
+ return s
+ }
+ return s[:maxLen] + "...(truncated)"
+}
+
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
flusher, ok := c.Writer.(http.Flusher)
diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go
index 04a58e49..39e2eed6 100644
--- a/backend/internal/handler/sora_gateway_handler_test.go
+++ b/backend/internal/handler/sora_gateway_handler_test.go
@@ -43,6 +43,9 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
+func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
+ return "enhanced prompt", nil
+}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}
diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go
index bb120b57..e3b931be 100644
--- a/backend/internal/pkg/openai/oauth.go
+++ b/backend/internal/pkg/openai/oauth.go
@@ -17,6 +17,8 @@ import (
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
+ // OAuth Client ID for Sora mobile flow (aligned with sora2api)
+ SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index 394d3a1a..088e7d7f 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/url"
+ "strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
+}
+
+func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ if strings.TrimSpace(clientID) != "" {
+ return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
+ }
+
+ clientIDs := []string{
+ openai.ClientID,
+ openai.SoraClientID,
+ }
+ seen := make(map[string]struct{}, len(clientIDs))
+ var lastErr error
+ for _, clientID := range clientIDs {
+ clientID = strings.TrimSpace(clientID)
+ if clientID == "" {
+ continue
+ }
+ if _, ok := seen[clientID]; ok {
+ continue
+ }
+ seen[clientID] = struct{}{}
+
+ tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
+ if err == nil {
+ return tokenResp, nil
+ }
+ lastErr = err
+ }
+ if lastErr != nil {
+ return nil, lastErr
+ }
+ return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
+}
+
+func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
- formData.Set("client_id", openai.ClientID)
+ formData.Set("client_id", clientID)
formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse
diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go
index f9df08c8..5938272a 100644
--- a/backend/internal/repository/openai_oauth_service_test.go
+++ b/backend/internal/repository/openai_oauth_service_test.go
@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
+func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
+ var seenClientIDs []string
+ s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+ clientID := r.PostForm.Get("client_id")
+ seenClientIDs = append(seenClientIDs, clientID)
+ if clientID == openai.ClientID {
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = io.WriteString(w, "invalid_grant")
+ return
+ }
+ if clientID == openai.SoraClientID {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
+ return
+ }
+ w.WriteHeader(http.StatusBadRequest)
+ }))
+
+ resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
+ require.NoError(s.T(), err, "RefreshToken")
+ require.Equal(s.T(), "at-sora", resp.AccessToken)
+ require.Equal(s.T(), "rt-sora", resp.RefreshToken)
+ require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
+}
+
+func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
+ const customClientID = "custom-client-id"
+ var seenClientIDs []string
+ s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+ clientID := r.PostForm.Get("client_id")
+ seenClientIDs = append(seenClientIDs, clientID)
+ if clientID != customClientID {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
+ }))
+
+ resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
+ require.NoError(s.T(), err, "RefreshTokenWithClientID")
+ require.Equal(s.T(), "at-custom", resp.AccessToken)
+ require.Equal(s.T(), "rt-custom", resp.RefreshToken)
+ require.Equal(s.T(), []string{customClientID}, seenClientIDs)
+}
+
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 57d54a54..7341f85b 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
+ // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
+ registerSoraOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
+func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ sora := admin.Group("/sora")
+ {
+ sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
+ sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
+ sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
+ sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
+ sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
+ sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
+ sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
+ }
+}
+
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go
index 32f34e0c..69881e70 100644
--- a/backend/internal/server/routes/gateway.go
+++ b/backend/internal/server/routes/gateway.go
@@ -1,6 +1,8 @@
package routes
import (
+ "net/http"
+
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
- }
-
- // Sora Chat Completions
- soraGateway := r.Group("/v1")
- soraGateway.Use(soraBodyLimit)
- soraGateway.Use(clientRequestID)
- soraGateway.Use(opsErrorLogger)
- soraGateway.Use(gin.HandlerFunc(apiKeyAuth))
- {
- soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
+ // 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口
+ gateway.POST("/chat/completions", func(c *gin.Context) {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.",
+ },
+ })
+ })
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index 093f7d4d..67c9ef0c 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -27,11 +27,13 @@ import (
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
+var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
+ soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
)
// TestEvent represents a SSE event for account testing
@@ -502,8 +504,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
+ enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
- resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
@@ -512,7 +515,10 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
- return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body)))
+ if isCloudflareChallengeResponse(resp.StatusCode, body) {
+ return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
+ }
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
// 解析 /me 响应,提取用户信息
@@ -531,10 +537,129 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
s.sendEvent(c, TestEvent{Type: "content", Text: info})
}
+ // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
+ subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
+ if err == nil {
+ subReq.Header.Set("Authorization", "Bearer "+authToken)
+ subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
+ subReq.Header.Set("Accept", "application/json")
+
+ subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
+ if subErr != nil {
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
+ } else {
+ subBody, _ := io.ReadAll(subResp.Body)
+ _ = subResp.Body.Close()
+ if subResp.StatusCode == http.StatusOK {
+ if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: summary})
+ } else {
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
+ }
+ } else {
+ if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
+ s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
+ } else {
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
+ }
+ }
+ }
+ }
+
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
+func parseSoraSubscriptionSummary(body []byte) string {
+ var subResp struct {
+ Data []struct {
+ Plan struct {
+ ID string `json:"id"`
+ Title string `json:"title"`
+ } `json:"plan"`
+ EndTS string `json:"end_ts"`
+ } `json:"data"`
+ }
+ if err := json.Unmarshal(body, &subResp); err != nil {
+ return ""
+ }
+ if len(subResp.Data) == 0 {
+ return ""
+ }
+
+ first := subResp.Data[0]
+ parts := make([]string, 0, 3)
+ if first.Plan.Title != "" {
+ parts = append(parts, first.Plan.Title)
+ }
+ if first.Plan.ID != "" {
+ parts = append(parts, first.Plan.ID)
+ }
+ if first.EndTS != "" {
+ parts = append(parts, "end="+first.EndTS)
+ }
+ if len(parts) == 0 {
+ return ""
+ }
+ return "Subscription: " + strings.Join(parts, " | ")
+}
+
+func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
+ if s == nil || s.cfg == nil {
+ return false
+ }
+ return s.cfg.Gateway.TLSFingerprint.Enabled && !s.cfg.Sora.Client.DisableTLSFingerprint
+}
+
+func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
+ if statusCode != http.StatusForbidden {
+ return false
+ }
+ preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
+ return strings.Contains(preview, "window._cf_chl_opt") ||
+ strings.Contains(preview, "just a moment") ||
+ strings.Contains(preview, "enable javascript and cookies to continue")
+}
+
+func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
+ rayID := extractCloudflareRayID(headers, body)
+ if rayID == "" {
+ return base
+ }
+ return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
+}
+
+func extractCloudflareRayID(headers http.Header, body []byte) string {
+ if headers != nil {
+ rayID := strings.TrimSpace(headers.Get("cf-ray"))
+ if rayID != "" {
+ return rayID
+ }
+ rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
+ if rayID != "" {
+ return rayID
+ }
+ }
+
+ preview := truncateSoraErrorBody(body, 8192)
+ matches := cloudflareRayPattern.FindStringSubmatch(preview)
+ if len(matches) >= 2 {
+ return strings.TrimSpace(matches[1])
+ }
+ return ""
+}
+
+func truncateSoraErrorBody(body []byte, max int) string {
+ if max <= 0 {
+ max = 512
+ }
+ raw := strings.TrimSpace(string(body))
+ if len(raw) <= max {
+ return raw
+ }
+ return raw[:max] + "...(truncated)"
+}
+
// testAntigravityAccountConnection tests an Antigravity account's connection
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go
new file mode 100644
index 00000000..fbbc8ff1
--- /dev/null
+++ b/backend/internal/service/account_test_service_sora_test.go
@@ -0,0 +1,193 @@
+package service
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type queuedHTTPUpstream struct {
+ responses []*http.Response
+ requests []*http.Request
+ tlsFlags []bool
+}
+
+func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
+ return nil, fmt.Errorf("unexpected Do call")
+}
+
+func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
+ u.requests = append(u.requests, req)
+ u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
+ if len(u.responses) == 0 {
+ return nil, fmt.Errorf("no mocked response")
+ }
+ resp := u.responses[0]
+ u.responses = u.responses[1:]
+ return resp, nil
+}
+
+func newJSONResponse(status int, body string) *http.Response {
+ return &http.Response{
+ StatusCode: status,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(body)),
+ }
+}
+
+func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
+ resp := newJSONResponse(status, body)
+ resp.Header.Set(key, value)
+ return resp
+}
+
+func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+ return c, rec
+}
+
+func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
+ newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
+ },
+ }
+ svc := &AccountTestService{
+ httpUpstream: upstream,
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ TLSFingerprint: config.TLSFingerprintConfig{
+ Enabled: true,
+ },
+ },
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ DisableTLSFingerprint: false,
+ },
+ },
+ },
+ }
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.NoError(t, err)
+ require.Len(t, upstream.requests, 2)
+ require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
+ require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
+ require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
+ require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
+ require.Equal(t, []bool{true, true}, upstream.tlsFlags)
+
+ body := rec.Body.String()
+ require.Contains(t, body, `"type":"test_start"`)
+ require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
+ require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
+ require.Contains(t, body, `"type":"test_complete","success":true`)
+}
+
+func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
+ newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.NoError(t, err)
+ require.Len(t, upstream.requests, 2)
+ body := rec.Body.String()
+ require.Contains(t, body, "Sora connection OK - User: demo-user")
+ require.Contains(t, body, "Subscription check returned 403")
+ require.Contains(t, body, `"type":"test_complete","success":true`)
+}
+
+func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponseWithHeader(http.StatusForbidden, `
Just a moment...`, "cf-ray", "9cff2d62d83bb98d"),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "Cloudflare challenge")
+ require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
+ body := rec.Body.String()
+ require.Contains(t, body, `"type":"error"`)
+ require.Contains(t, body, "Cloudflare challenge")
+ require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
+}
+
+func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
+ newJSONResponse(http.StatusForbidden, `Just a moment...`),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.NoError(t, err)
+ body := rec.Body.String()
+ require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
+ require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
+ require.Contains(t, body, `"type":"test_complete","success":true`)
+}
diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go
index e247e654..6f6261d8 100644
--- a/backend/internal/service/oauth_service.go
+++ b/backend/internal/service/oauth_service.go
@@ -14,6 +14,7 @@ import (
type OpenAIOAuthClient interface {
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
+ RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
}
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go
index ca7470b9..087ad4ec 100644
--- a/backend/internal/service/openai_oauth_service.go
+++ b/backend/internal/service/openai_oauth_service.go
@@ -2,13 +2,20 @@ package service
import (
"context"
+ "crypto/subtle"
+ "encoding/json"
+ "io"
"net/http"
+ "net/url"
+ "strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
+var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
+
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
type OpenAIExchangeCodeInput struct {
SessionID string
Code string
+ State string
RedirectURI string
ProxyID *int64
}
@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
if !ok {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
}
+ if input.State == "" {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
+ }
+ if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
+ }
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL
@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// RefreshToken refreshes an OpenAI OAuth token
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
- tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
+ return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
+}
+
+// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
+func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
+ tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err != nil {
return nil, err
}
@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return tokenInfo, nil
}
-// RefreshAccountToken refreshes token for an OpenAI account
-func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
- if !account.IsOpenAI() {
- return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
+// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
+func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
+ if strings.TrimSpace(sessionToken) == "" {
+ return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
}
- refreshToken := account.GetOpenAIRefreshToken()
+ proxyURL, err := s.resolveProxyURL(ctx, proxyID)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
+ if err != nil {
+ return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
+ }
+ req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Origin", "https://sora.chatgpt.com")
+ req.Header.Set("Referer", "https://sora.chatgpt.com/")
+ req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
+
+ client := newOpenAIOAuthHTTPClient(proxyURL)
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ if resp.StatusCode != http.StatusOK {
+ return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
+ }
+
+ var sessionResp struct {
+ AccessToken string `json:"accessToken"`
+ Expires string `json:"expires"`
+ User struct {
+ Email string `json:"email"`
+ Name string `json:"name"`
+ } `json:"user"`
+ }
+ if err := json.Unmarshal(body, &sessionResp); err != nil {
+ return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
+ }
+ if strings.TrimSpace(sessionResp.AccessToken) == "" {
+ return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
+ }
+
+ expiresAt := time.Now().Add(time.Hour).Unix()
+ if strings.TrimSpace(sessionResp.Expires) != "" {
+ if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
+ expiresAt = parsed.Unix()
+ }
+ }
+ expiresIn := expiresAt - time.Now().Unix()
+ if expiresIn < 0 {
+ expiresIn = 0
+ }
+
+ return &OpenAITokenInfo{
+ AccessToken: strings.TrimSpace(sessionResp.AccessToken),
+ ExpiresIn: expiresIn,
+ ExpiresAt: expiresAt,
+ Email: strings.TrimSpace(sessionResp.User.Email),
+ }, nil
+}
+
+// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
+func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
+ if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
+ }
+ if account.Type != AccountTypeOAuth {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
+ }
+
+ refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
}
@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
}
}
- return s.RefreshToken(ctx, refreshToken, proxyURL)
+ clientID := account.GetCredential("client_id")
+ return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
}
// BuildAccountCredentials builds credentials map from token info
@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
func (s *OpenAIOAuthService) Stop() {
s.sessionStore.Stop()
}
+
+func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
+ if proxyID == nil {
+ return "", nil
+ }
+ proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
+ if err != nil {
+ return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
+ }
+ if proxy == nil {
+ return "", nil
+ }
+ return proxy.URL(), nil
+}
+
+func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
+ transport := &http.Transport{}
+ if strings.TrimSpace(proxyURL) != "" {
+ if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
+ transport.Proxy = http.ProxyURL(parsed)
+ }
+ }
+ return &http.Client{
+ Timeout: 120 * time.Second,
+ Transport: transport,
+ }
+}
diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go
new file mode 100644
index 00000000..fb76f6c1
--- /dev/null
+++ b/backend/internal/service/openai_oauth_service_sora_session_test.go
@@ -0,0 +1,69 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/stretchr/testify/require"
+)
+
+type openaiOAuthClientNoopStub struct{}
+
+func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
+ }))
+ defer server.Close()
+
+ origin := openAISoraSessionAuthURL
+ openAISoraSessionAuthURL = server.URL
+ defer func() { openAISoraSessionAuthURL = origin }()
+
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
+ defer svc.Stop()
+
+ info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
+ require.NoError(t, err)
+ require.NotNil(t, info)
+ require.Equal(t, "at-token", info.AccessToken)
+ require.Equal(t, "demo@example.com", info.Email)
+ require.Greater(t, info.ExpiresAt, int64(0))
+}
+
+func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
+ }))
+ defer server.Close()
+
+ origin := openAISoraSessionAuthURL
+ openAISoraSessionAuthURL = server.URL
+ defer func() { openAISoraSessionAuthURL = origin }()
+
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
+ defer svc.Stop()
+
+ _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "missing access token")
+}
diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go
new file mode 100644
index 00000000..0a2a195f
--- /dev/null
+++ b/backend/internal/service/openai_oauth_service_state_test.go
@@ -0,0 +1,102 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/stretchr/testify/require"
+)
+
+type openaiOAuthClientStateStub struct {
+ exchangeCalled int32
+}
+
+func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
+ atomic.AddInt32(&s.exchangeCalled, 1)
+ return &openai.TokenResponse{
+ AccessToken: "at",
+ RefreshToken: "rt",
+ ExpiresIn: 3600,
+ }, nil
+}
+
+func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ return s.RefreshToken(ctx, refreshToken, proxyURL)
+}
+
+func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
+ client := &openaiOAuthClientStateStub{}
+ svc := NewOpenAIOAuthService(nil, client)
+ defer svc.Stop()
+
+ svc.sessionStore.Set("sid", &openai.OAuthSession{
+ State: "expected-state",
+ CodeVerifier: "verifier",
+ RedirectURI: openai.DefaultRedirectURI,
+ CreatedAt: time.Now(),
+ })
+
+ _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
+ SessionID: "sid",
+ Code: "auth-code",
+ })
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "oauth state is required")
+ require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
+}
+
+func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
+ client := &openaiOAuthClientStateStub{}
+ svc := NewOpenAIOAuthService(nil, client)
+ defer svc.Stop()
+
+ svc.sessionStore.Set("sid", &openai.OAuthSession{
+ State: "expected-state",
+ CodeVerifier: "verifier",
+ RedirectURI: openai.DefaultRedirectURI,
+ CreatedAt: time.Now(),
+ })
+
+ _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
+ SessionID: "sid",
+ Code: "auth-code",
+ State: "wrong-state",
+ })
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "invalid oauth state")
+ require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
+}
+
+func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
+ client := &openaiOAuthClientStateStub{}
+ svc := NewOpenAIOAuthService(nil, client)
+ defer svc.Stop()
+
+ svc.sessionStore.Set("sid", &openai.OAuthSession{
+ State: "expected-state",
+ CodeVerifier: "verifier",
+ RedirectURI: openai.DefaultRedirectURI,
+ CreatedAt: time.Now(),
+ })
+
+ info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
+ SessionID: "sid",
+ Code: "auth-code",
+ State: "expected-state",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, info)
+ require.Equal(t, "at", info.AccessToken)
+ require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
+
+ _, ok := svc.sessionStore.Get("sid")
+ require.False(t, ok)
+}
diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go
index 3842f0a4..a8a6b96c 100644
--- a/backend/internal/service/openai_token_provider.go
+++ b/backend/internal/service/openai_token_provider.go
@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
- if p.openAIOAuthService == nil {
+ if account.Platform == PlatformSora {
+ slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
+ // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
+ refreshFailed = true
+ } else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true // 无法刷新,标记失败
@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
- if p.openAIOAuthService == nil {
+ if account.Platform == PlatformSora {
+ slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
+ // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
+ refreshFailed = true
+ } else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go
index de097d5e..38be7a04 100644
--- a/backend/internal/service/sora_client.go
+++ b/backend/internal/service/sora_client.go
@@ -17,12 +17,15 @@ import (
"net/textproto"
"net/url"
"path"
+ "sort"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"golang.org/x/crypto/sha3"
@@ -34,6 +37,11 @@ const (
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
+var (
+ soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
+ soraOAuthTokenURL = "https://auth.openai.com/oauth/token"
+)
+
const (
soraPowMaxIteration = 500000
)
@@ -96,6 +104,7 @@ type SoraClient interface {
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
+ EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
}
@@ -157,26 +166,94 @@ func (e *SoraUpstreamError) Error() string {
// SoraDirectClient 直连 Sora 实现
type SoraDirectClient struct {
- cfg *config.Config
- httpUpstream HTTPUpstream
- tokenProvider *OpenAITokenProvider
+ cfg *config.Config
+ httpUpstream HTTPUpstream
+ tokenProvider *OpenAITokenProvider
+ accountRepo AccountRepository
+ soraAccountRepo SoraAccountRepository
+ baseURL string
}
// NewSoraDirectClient 创建 Sora 直连客户端
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
+ baseURL := ""
+ if cfg != nil {
+ rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/")
+ baseURL = normalizeSoraBaseURL(rawBaseURL)
+ if rawBaseURL != "" && baseURL != rawBaseURL {
+ log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL))
+ }
+ }
return &SoraDirectClient{
cfg: cfg,
httpUpstream: httpUpstream,
tokenProvider: tokenProvider,
+ baseURL: baseURL,
}
}
+func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
+ if c == nil {
+ return
+ }
+ c.accountRepo = accountRepo
+ c.soraAccountRepo = soraAccountRepo
+}
+
// Enabled 判断是否启用 Sora 直连
func (c *SoraDirectClient) Enabled() bool {
- if c == nil || c.cfg == nil {
+ if c == nil {
return false
}
- return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
+ if strings.TrimSpace(c.baseURL) != "" {
+ return true
+ }
+ if c.cfg == nil {
+ return false
+ }
+ return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != ""
+}
+
+// PreflightCheck 在创建任务前执行账号能力预检。
+// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。
+func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
+ if modelCfg.Type != "video" {
+ return nil
+ }
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return err
+ }
+ headers := c.buildBaseHeaders(token, c.defaultUserAgent())
+ headers.Set("Accept", "application/json")
+ body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
+ if err != nil {
+ var upstreamErr *SoraUpstreamError
+ if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
+ return &SoraUpstreamError{
+ StatusCode: http.StatusForbidden,
+ Message: "当前账号未开通 Sora2 能力或无可用配额",
+ Headers: upstreamErr.Headers,
+ Body: upstreamErr.Body,
+ }
+ }
+ return err
+ }
+
+ rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool()
+ remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining")
+ if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) {
+ msg := "当前账号 Sora2 可用配额不足"
+ if requestedModel != "" {
+ msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
+ }
+ return &SoraUpstreamError{
+ StatusCode: http.StatusTooManyRequests,
+ Message: msg,
+ Headers: http.Header{},
+ }
+ }
+ return nil
}
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
@@ -347,6 +424,45 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
return taskID, nil
}
+func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ if strings.TrimSpace(expansionLevel) == "" {
+ expansionLevel = "medium"
+ }
+ if durationS <= 0 {
+ durationS = 10
+ }
+
+ payload := map[string]any{
+ "prompt": prompt,
+ "expansion_level": expansionLevel,
+ "duration_s": durationS,
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+
+ headers := c.buildBaseHeaders(token, c.defaultUserAgent())
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Accept", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+
+ respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
+ if err != nil {
+ return "", err
+ }
+ enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String())
+ if enhancedPrompt == "" {
+ return "", errors.New("enhance_prompt response missing enhanced_prompt")
+ }
+ return enhancedPrompt, nil
+}
+
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
if err != nil {
@@ -512,9 +628,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
}
func (c *SoraDirectClient) buildURL(endpoint string) string {
- base := ""
- if c != nil && c.cfg != nil {
- base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/")
+ base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/")
+ if base == "" && c != nil && c.cfg != nil {
+ base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)
+ c.baseURL = base
}
if base == "" {
return endpoint
@@ -540,14 +657,257 @@ func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account)
if account == nil {
return "", errors.New("account is nil")
}
- if c.tokenProvider != nil {
- return c.tokenProvider.GetAccessToken(ctx, account)
+
+ allowProvider := c.allowOpenAITokenProvider(account)
+ var providerErr error
+ if allowProvider && c.tokenProvider != nil {
+ token, err := c.tokenProvider.GetAccessToken(ctx, account)
+ if err == nil && strings.TrimSpace(token) != "" {
+ c.logTokenSource(account, "openai_token_provider")
+ return token, nil
+ }
+ providerErr = err
+ if err != nil && c.debugEnabled() {
+ c.debugLogf(
+ "token_provider_failed account_id=%d platform=%s err=%s",
+ account.ID,
+ account.Platform,
+ logredact.RedactText(err.Error()),
+ )
+ }
}
token := strings.TrimSpace(account.GetCredential("access_token"))
- if token == "" {
- return "", errors.New("access_token not found")
+ if token != "" {
+ expiresAt := account.GetCredentialAsTime("expires_at")
+ if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
+ refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
+ if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
+ c.logTokenSource(account, "refresh_token_recovered")
+ return refreshed, nil
+ }
+ if refreshErr != nil && c.debugEnabled() {
+ c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error()))
+ }
+ }
+ c.logTokenSource(account, "account_credentials")
+ return token, nil
}
- return token, nil
+
+ recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
+ if recoverErr == nil && strings.TrimSpace(recovered) != "" {
+ c.logTokenSource(account, "session_or_refresh_recovered")
+ return recovered, nil
+ }
+ if recoverErr != nil && c.debugEnabled() {
+ c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error()))
+ }
+ if providerErr != nil {
+ return "", providerErr
+ }
+ if c.tokenProvider != nil && !allowProvider {
+ c.logTokenSource(account, "account_credentials(provider_disabled)")
+ }
+ return "", errors.New("access_token not found")
+}
+
+func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
+ if account == nil {
+ return "", errors.New("account is nil")
+ }
+
+ if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
+ accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
+ if err == nil && strings.TrimSpace(accessToken) != "" {
+ c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
+ c.logTokenRecover(account, "session_token", reason, true, nil)
+ return accessToken, nil
+ }
+ c.logTokenRecover(account, "session_token", reason, false, err)
+ }
+
+ refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
+ if refreshToken == "" {
+ return "", errors.New("session_token/refresh_token not found")
+ }
+ accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken)
+ if err != nil {
+ c.logTokenRecover(account, "refresh_token", reason, false, err)
+ return "", err
+ }
+ if strings.TrimSpace(accessToken) == "" {
+ return "", errors.New("refreshed access_token is empty")
+ }
+ c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "")
+ c.logTokenRecover(account, "refresh_token", reason, true, nil)
+ return accessToken, nil
+}
+
+func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
+ headers := http.Header{}
+ headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
+ headers.Set("Accept", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ headers.Set("User-Agent", c.defaultUserAgent())
+ body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false)
+ if err != nil {
+ return "", "", err
+ }
+ accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
+ if accessToken == "" {
+ return "", "", errors.New("session exchange missing accessToken")
+ }
+ expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
+ return accessToken, expiresAt, nil
+}
+
+func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) {
+ clientIDs := []string{
+ strings.TrimSpace(account.GetCredential("client_id")),
+ openaioauth.SoraClientID,
+ openaioauth.ClientID,
+ }
+ tried := make(map[string]struct{}, len(clientIDs))
+ var lastErr error
+
+ for _, clientID := range clientIDs {
+ if clientID == "" {
+ continue
+ }
+ if _, ok := tried[clientID]; ok {
+ continue
+ }
+ tried[clientID] = struct{}{}
+
+ payload := map[string]any{
+ "client_id": clientID,
+ "grant_type": "refresh_token",
+ "refresh_token": refreshToken,
+ "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
+ }
+ bodyBytes, err := json.Marshal(payload)
+ if err != nil {
+ return "", "", "", err
+ }
+ headers := http.Header{}
+ headers.Set("Accept", "application/json")
+ headers.Set("Content-Type", "application/json")
+ headers.Set("User-Agent", c.defaultUserAgent())
+
+ respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false)
+ if err != nil {
+ lastErr = err
+ if c.debugEnabled() {
+ c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error()))
+ }
+ continue
+ }
+ accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String())
+ if accessToken == "" {
+ lastErr = errors.New("oauth refresh response missing access_token")
+ continue
+ }
+ newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String())
+ expiresIn := gjson.GetBytes(respBody, "expires_in").Int()
+ expiresAt := ""
+ if expiresIn > 0 {
+ expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
+ }
+ return accessToken, newRefreshToken, expiresAt, nil
+ }
+
+ if lastErr != nil {
+ return "", "", "", lastErr
+ }
+ return "", "", "", errors.New("no available client_id for refresh_token exchange")
+}
+
+func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
+ if account == nil {
+ return
+ }
+ if account.Credentials == nil {
+ account.Credentials = make(map[string]any)
+ }
+ if strings.TrimSpace(accessToken) != "" {
+ account.Credentials["access_token"] = accessToken
+ }
+ if strings.TrimSpace(refreshToken) != "" {
+ account.Credentials["refresh_token"] = refreshToken
+ }
+ if strings.TrimSpace(expiresAt) != "" {
+ account.Credentials["expires_at"] = expiresAt
+ }
+ if strings.TrimSpace(sessionToken) != "" {
+ account.Credentials["session_token"] = sessionToken
+ }
+
+ if c.accountRepo != nil {
+ if err := c.accountRepo.Update(ctx, account); err != nil {
+ if c.debugEnabled() {
+ c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
+ }
+ }
+ }
+ c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
+}
+
+func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
+ if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
+ return
+ }
+ updates := make(map[string]any)
+ if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
+ updates["access_token"] = accessToken
+ updates["refresh_token"] = refreshToken
+ }
+ if strings.TrimSpace(sessionToken) != "" {
+ updates["session_token"] = sessionToken
+ }
+ if len(updates) == 0 {
+ return
+ }
+ if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
+ c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
+ }
+}
+
+func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) {
+ if !c.debugEnabled() || account == nil {
+ return
+ }
+ if success {
+ c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
+ return
+ }
+ if err == nil {
+ c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
+ return
+ }
+ c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error()))
+}
+
+func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool {
+ if c == nil || c.tokenProvider == nil {
+ return false
+ }
+ if account != nil && account.Platform == PlatformSora {
+ return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
+ }
+ return true
+}
+
+func (c *SoraDirectClient) logTokenSource(account *Account, source string) {
+ if !c.debugEnabled() || account == nil {
+ return
+ }
+ c.debugLogf(
+ "token_selected account_id=%d platform=%s account_type=%s source=%s",
+ account.ID,
+ account.Platform,
+ account.Type,
+ source,
+ )
}
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
@@ -600,7 +960,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
}
attempts := maxRetries + 1
+ authRecovered := false
+ authRecoverExtraAttemptGranted := false
+ var lastErr error
for attempt := 1; attempt <= attempts; attempt++ {
+ if c.debugEnabled() {
+ c.debugLogf(
+ "request_start method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t headers=%s",
+ method,
+ sanitizeSoraLogURL(urlStr),
+ attempt,
+ attempts,
+ timeout,
+ len(bodyBytes),
+ account != nil && account.ProxyID != nil && account.Proxy != nil,
+ formatSoraHeaders(headers),
+ )
+ }
+
var reader io.Reader
if bodyBytes != nil {
reader = bytes.NewReader(bodyBytes)
@@ -618,7 +995,21 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
}
resp, err := c.doHTTP(req, proxyURL, account)
if err != nil {
+ lastErr = err
+ if c.debugEnabled() {
+ c.debugLogf(
+ "request_transport_error method=%s url=%s attempt=%d/%d err=%s",
+ method,
+ sanitizeSoraLogURL(urlStr),
+ attempt,
+ attempts,
+ logredact.RedactText(err.Error()),
+ )
+ }
if attempt < attempts && allowRetry {
+ if c.debugEnabled() {
+ c.debugLogf("request_retry_scheduled method=%s url=%s reason=transport_error next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), attempt+1, attempts)
+ }
c.sleepRetry(attempt)
continue
}
@@ -632,12 +1023,53 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
}
if c.cfg != nil && c.cfg.Sora.Client.Debug {
- log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start))
+ c.debugLogf(
+ "response_received method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s",
+ method,
+ sanitizeSoraLogURL(urlStr),
+ attempt,
+ attempts,
+ resp.StatusCode,
+ time.Since(start),
+ len(respBody),
+ formatSoraHeaders(resp.Header),
+ )
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
- upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody)
+ if !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil {
+ if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" {
+ headers.Set("Authorization", "Bearer "+recovered)
+ authRecovered = true
+ if attempt == attempts && !authRecoverExtraAttemptGranted {
+ attempts++
+ authRecoverExtraAttemptGranted = true
+ }
+ if c.debugEnabled() {
+ c.debugLogf("request_retry_with_recovered_token method=%s url=%s status=%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode)
+ }
+ continue
+ } else if recoverErr != nil && c.debugEnabled() {
+ c.debugLogf("request_recover_token_failed method=%s url=%s status=%d err=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error()))
+ }
+ }
+ if c.debugEnabled() {
+ c.debugLogf(
+ "response_non_success method=%s url=%s attempt=%d/%d status=%d body=%s",
+ method,
+ sanitizeSoraLogURL(urlStr),
+ attempt,
+ attempts,
+ resp.StatusCode,
+ summarizeSoraResponseBody(respBody, 512),
+ )
+ }
+ upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr)
+ lastErr = upstreamErr
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
+ if c.debugEnabled() {
+ c.debugLogf("request_retry_scheduled method=%s url=%s reason=status_%d next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
+ }
c.sleepRetry(attempt)
continue
}
@@ -645,9 +1077,34 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
}
return respBody, resp.Header, nil
}
+ if lastErr != nil {
+ return nil, nil, lastErr
+ }
return nil, nil, errors.New("upstream retries exhausted")
}
+func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
+ switch statusCode {
+ case http.StatusUnauthorized, http.StatusForbidden:
+ parsed, err := url.Parse(strings.TrimSpace(rawURL))
+ if err != nil {
+ return false
+ }
+ host := strings.ToLower(parsed.Hostname())
+ if host != "sora.chatgpt.com" && host != "chatgpt.com" {
+ return false
+ }
+ // 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。
+ path := strings.ToLower(strings.TrimSpace(parsed.Path))
+ if path == "/api/auth/session" {
+ return false
+ }
+ return true
+ default:
+ return false
+ }
+}
+
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
if c.httpUpstream != nil {
@@ -670,9 +1127,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) {
time.Sleep(backoff)
}
-func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error {
+func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error {
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
msg = sanitizeUpstreamErrorMessage(msg)
+ if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") {
+ if hint := soraBaseURLNotFoundHint(requestURL); hint != "" {
+ msg = strings.TrimSpace(msg + " " + hint)
+ }
+ }
if msg == "" {
msg = truncateForLog(body, 256)
}
@@ -684,6 +1146,45 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b
}
}
+func normalizeSoraBaseURL(raw string) string {
+ trimmed := strings.TrimRight(strings.TrimSpace(raw), "/")
+ if trimmed == "" {
+ return ""
+ }
+ parsed, err := url.Parse(trimmed)
+ if err != nil || parsed.Scheme == "" || parsed.Host == "" {
+ return trimmed
+ }
+ host := strings.ToLower(parsed.Hostname())
+ if host != "sora.chatgpt.com" && host != "chatgpt.com" {
+ return trimmed
+ }
+ pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/")
+ switch pathVal {
+ case "", "/":
+ parsed.Path = "/backend"
+ case "/backend-api":
+ parsed.Path = "/backend"
+ }
+ return strings.TrimRight(parsed.String(), "/")
+}
+
+func soraBaseURLNotFoundHint(requestURL string) string {
+ parsed, err := url.Parse(strings.TrimSpace(requestURL))
+ if err != nil || parsed.Host == "" {
+ return ""
+ }
+ host := strings.ToLower(parsed.Hostname())
+ if host != "sora.chatgpt.com" && host != "chatgpt.com" {
+ return ""
+ }
+ pathVal := strings.TrimSpace(parsed.Path)
+ if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" {
+ return ""
+ }
+ return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
+}
+
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
reqID := uuid.NewString()
userAgent := soraRandChoice(soraDesktopUserAgents)
@@ -901,3 +1402,70 @@ func sanitizeSoraLogURL(raw string) string {
parsed.RawQuery = q.Encode()
return parsed.String()
}
+
+func (c *SoraDirectClient) debugEnabled() bool {
+ return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
+}
+
+func (c *SoraDirectClient) debugLogf(format string, args ...any) {
+ if !c.debugEnabled() {
+ return
+ }
+ log.Printf("[SoraClient] "+format, args...)
+}
+
+func formatSoraHeaders(headers http.Header) string {
+ if len(headers) == 0 {
+ return "{}"
+ }
+ keys := make([]string, 0, len(headers))
+ for key := range headers {
+ keys = append(keys, key)
+ }
+ sort.Strings(keys)
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ values := headers.Values(key)
+ if len(values) == 0 {
+ continue
+ }
+ val := strings.Join(values, ",")
+ if isSensitiveHeader(key) {
+ out[key] = "***"
+ continue
+ }
+ out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160)
+ }
+ encoded, err := json.Marshal(out)
+ if err != nil {
+ return "{}"
+ }
+ return string(encoded)
+}
+
+func isSensitiveHeader(key string) bool {
+ k := strings.ToLower(strings.TrimSpace(key))
+ switch k {
+ case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key":
+ return true
+ default:
+ return false
+ }
+}
+
+func summarizeSoraResponseBody(body []byte, maxLen int) string {
+ if len(body) == 0 {
+ return ""
+ }
+ var text string
+ if json.Valid(body) {
+ text = logredact.RedactJSON(body)
+ } else {
+ text = logredact.RedactText(string(body))
+ }
+ text = strings.TrimSpace(text)
+ if maxLen <= 0 || len(text) <= maxLen {
+ return text
+ }
+ return text[:maxLen] + "...(truncated)"
+}
diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go
index a6bf71cd..3e88c9f9 100644
--- a/backend/internal/service/sora_client_test.go
+++ b/backend/internal/service/sora_client_test.go
@@ -4,9 +4,13 @@ package service
import (
"context"
+ "encoding/json"
"net/http"
"net/http/httptest"
+ "strings"
+ "sync/atomic"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
@@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
require.Equal(t, "completed", status.Status)
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
}
+
+func TestNormalizeSoraBaseURL(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ name string
+ raw string
+ want string
+ }{
+ {
+ name: "empty",
+ raw: "",
+ want: "",
+ },
+ {
+ name: "append_backend_for_sora_host",
+ raw: "https://sora.chatgpt.com",
+ want: "https://sora.chatgpt.com/backend",
+ },
+ {
+ name: "convert_backend_api_to_backend",
+ raw: "https://sora.chatgpt.com/backend-api",
+ want: "https://sora.chatgpt.com/backend",
+ },
+ {
+ name: "keep_backend",
+ raw: "https://sora.chatgpt.com/backend",
+ want: "https://sora.chatgpt.com/backend",
+ },
+ {
+ name: "keep_custom_host",
+ raw: "https://example.com/custom-path",
+ want: "https://example.com/custom-path",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := normalizeSoraBaseURL(tt.raw)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
+}
+
+func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
+ t.Parallel()
+ client := NewSoraDirectClient(&config.Config{}, nil, nil)
+
+ err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
+ var upstreamErr *SoraUpstreamError
+ require.ErrorAs(t, err, &upstreamErr)
+ require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
+
+ errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
+ require.ErrorAs(t, errNoHint, &upstreamErr)
+ require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
+}
+
+func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
+ t.Parallel()
+ headers := http.Header{}
+ headers.Set("Authorization", "Bearer secret-token")
+ headers.Set("openai-sentinel-token", "sentinel-secret")
+ headers.Set("X-Test", "ok")
+
+ out := formatSoraHeaders(headers)
+ require.Contains(t, out, `"Authorization":"***"`)
+ require.Contains(t, out, `Sentinel-Token":"***"`)
+ require.Contains(t, out, `"X-Test":"ok"`)
+ require.NotContains(t, out, "secret-token")
+ require.NotContains(t, out, "sentinel-secret")
+}
+
+func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
+ out := summarizeSoraResponseBody(body, 512)
+ require.Contains(t, out, `"access_token":"***"`)
+ require.NotContains(t, out, "abc123")
+}
+
+func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
+ t.Parallel()
+ body := []byte(strings.Repeat("x", 100))
+ out := summarizeSoraResponseBody(body, 10)
+ require.Contains(t, out, "(truncated)")
+}
+
+func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
+ t.Parallel()
+ cache := newOpenAITokenCacheStub()
+ provider := NewOpenAITokenProvider(nil, cache, nil)
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, provider)
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "sora-credential-token",
+ },
+ }
+
+ token, err := client.getAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "sora-credential-token", token)
+ require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
+}
+
+func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
+ t.Parallel()
+ cache := newOpenAITokenCacheStub()
+ account := &Account{
+ ID: 2,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "sora-credential-token",
+ },
+ }
+ cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
+ provider := NewOpenAITokenProvider(nil, cache, nil)
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ UseOpenAITokenProvider: true,
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, provider)
+
+ token, err := client.getAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "provider-token", token)
+ require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
+}
+
+func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "accessToken": "session-access-token",
+ "expires": "2099-01-01T00:00:00Z",
+ })
+ }))
+ defer server.Close()
+
+ origin := soraSessionAuthURL
+ soraSessionAuthURL = server.URL
+ defer func() { soraSessionAuthURL = origin }()
+
+ client := NewSoraDirectClient(&config.Config{}, nil, nil)
+ account := &Account{
+ ID: 10,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "session_token": "session-token",
+ },
+ }
+
+ token, err := client.getAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "session-access-token", token)
+ require.Equal(t, "session-access-token", account.GetCredential("access_token"))
+}
+
+func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+ require.Equal(t, "/oauth/token", r.URL.Path)
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "access_token": "refresh-access-token",
+ "refresh_token": "refresh-token-new",
+ "expires_in": 3600,
+ })
+ }))
+ defer server.Close()
+
+ origin := soraOAuthTokenURL
+ soraOAuthTokenURL = server.URL + "/oauth/token"
+ defer func() { soraOAuthTokenURL = origin }()
+
+ client := NewSoraDirectClient(&config.Config{}, nil, nil)
+ account := &Account{
+ ID: 11,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "refresh_token": "refresh-token-old",
+ },
+ }
+
+ token, err := client.getAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "refresh-access-token", token)
+ require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
+ require.NotNil(t, account.GetCredentialAsTime("expires_at"))
+}
+
+func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
+ t.Parallel()
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Equal(t, "/nf/check", r.URL.Path)
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "rate_limit_and_credit_balance": map[string]any{
+ "estimated_num_videos_remaining": 0,
+ "rate_limit_reached": true,
+ },
+ })
+ }))
+ defer server.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: server.URL,
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ account := &Account{
+ ID: 12,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "ok",
+ "expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
+ },
+ }
+ err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
+ require.Error(t, err)
+ var upstreamErr *SoraUpstreamError
+ require.ErrorAs(t, err, &upstreamErr)
+ require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
+}
+
+func TestShouldAttemptSoraTokenRecover(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
+ require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
+ require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
+ require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
+ require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
+}
diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go
index d7ff297c..8ae89f92 100644
--- a/backend/internal/service/sora_gateway_service.go
+++ b/backend/internal/service/sora_gateway_service.go
@@ -61,6 +61,10 @@ type SoraGatewayService struct {
cfg *config.Config
}
+type soraPreflightChecker interface {
+ PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
+}
+
func NewSoraGatewayService(
soraClient SoraClient,
mediaStorage *SoraMediaStorage,
@@ -112,11 +116,6 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
return nil, fmt.Errorf("unsupported model: %s", reqModel)
}
- if modelCfg.Type == "prompt_enhance" {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
- return nil, fmt.Errorf("prompt-enhance not supported")
- }
-
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
if strings.TrimSpace(prompt) == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
@@ -131,6 +130,41 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
if cancel != nil {
defer cancel()
}
+ if checker, ok := s.soraClient.(soraPreflightChecker); ok {
+ if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
+ return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
+ }
+ }
+
+ if modelCfg.Type == "prompt_enhance" {
+ enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
+ if err != nil {
+ return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
+ }
+ content := strings.TrimSpace(enhancedPrompt)
+ if content == "" {
+ content = prompt
+ }
+ var firstTokenMs *int
+ if clientStream {
+ ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
+ if streamErr != nil {
+ return nil, streamErr
+ }
+ firstTokenMs = ms
+ } else if c != nil {
+ c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
+ }
+ return &ForwardResult{
+ RequestID: "",
+ Model: reqModel,
+ Stream: clientStream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ Usage: ClaudeUsage{},
+ MediaType: "prompt",
+ }, nil
+ }
var imageData []byte
imageFilename := ""
@@ -267,7 +301,7 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
- case 401, 402, 403, 429, 529:
+ case 401, 402, 403, 404, 429, 529:
return true
default:
return statusCode >= 500
@@ -460,7 +494,7 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
- return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
+ return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body}
}
msg := upstreamErr.Message
if override := soraProErrorMessage(model, msg); override != "" {
diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go
index d6bf9eae..f706d052 100644
--- a/backend/internal/service/sora_gateway_service_test.go
+++ b/backend/internal/service/sora_gateway_service_test.go
@@ -18,6 +18,8 @@ type stubSoraClientForPoll struct {
videoStatus *SoraVideoTaskStatus
imageCalls int
videoCalls int
+ enhanced string
+ enhanceErr error
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
@@ -30,6 +32,12 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
return "task-video", nil
}
+func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
+ if s.enhanced != "" {
+ return s.enhanced, s.enhanceErr
+ }
+ return "enhanced prompt", s.enhanceErr
+}
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
s.imageCalls++
return s.imageStatus, nil
@@ -62,6 +70,33 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
require.Equal(t, 1, client.imageCalls)
}
+func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ enhanced: "cinematic prompt",
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ }
+ body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "prompt", result.MediaType)
+ require.Equal(t, "prompt-enhance-short-10s", result.Model)
+}
+
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
@@ -178,6 +213,7 @@ func TestSoraProErrorMessage(t *testing.T) {
func TestShouldFailoverUpstreamError(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc.shouldFailoverUpstreamError(401))
+ require.True(t, svc.shouldFailoverUpstreamError(404))
require.True(t, svc.shouldFailoverUpstreamError(429))
require.True(t, svc.shouldFailoverUpstreamError(500))
require.True(t, svc.shouldFailoverUpstreamError(502))
diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go
index ab095e46..80b20a4b 100644
--- a/backend/internal/service/sora_models.go
+++ b/backend/internal/service/sora_models.go
@@ -17,6 +17,9 @@ type SoraModelConfig struct {
Model string
Size string
RequirePro bool
+ // Prompt-enhance 专用参数
+ ExpansionLevel string
+ DurationS int
}
var soraModelConfigs = map[string]SoraModelConfig{
@@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{
RequirePro: true,
},
"prompt-enhance-short-10s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "short",
+ DurationS: 10,
},
"prompt-enhance-short-15s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "short",
+ DurationS: 15,
},
"prompt-enhance-short-20s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "short",
+ DurationS: 20,
},
"prompt-enhance-medium-10s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "medium",
+ DurationS: 10,
},
"prompt-enhance-medium-15s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "medium",
+ DurationS: 15,
},
"prompt-enhance-medium-20s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "medium",
+ DurationS: 20,
},
"prompt-enhance-long-10s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "long",
+ DurationS: 10,
},
"prompt-enhance-long-15s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "long",
+ DurationS: 15,
},
"prompt-enhance-long-20s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "long",
+ DurationS: 20,
},
}
diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go
index 9de1c164..a37e0d0a 100644
--- a/backend/internal/service/token_refresh_service.go
+++ b/backend/internal/service/token_refresh_service.go
@@ -43,10 +43,13 @@ func NewTokenRefreshService(
stopCh: make(chan struct{}),
}
+ openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
+ openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
+
// 注册平台特定的刷新器
s.refreshers = []TokenRefresher{
NewClaudeTokenRefresher(oauthService),
- NewOpenAITokenRefresher(openaiOAuthService, accountRepo),
+ openAIRefresher,
NewGeminiTokenRefresher(geminiOAuthService),
NewAntigravityTokenRefresher(antigravityOAuthService),
}
diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go
index 46033f75..0dd3cf45 100644
--- a/backend/internal/service/token_refresher.go
+++ b/backend/internal/service/token_refresher.go
@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
+ syncLinkedSora bool
}
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
@@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
r.soraAccountRepo = repo
}
+// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
+func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
+ r.syncLinkedSora = enabled
+}
+
// CanRefresh 检查是否能处理此账号
-// 只处理 openai 平台的 oauth 类型账号
+// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
- return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
- account.Type == AccountTypeOAuth
+ return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
@@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
}
// 异步同步关联的 Sora 账号(不阻塞主流程)
- if r.accountRepo != nil {
+ if r.accountRepo != nil && r.syncLinkedSora {
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
}
diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go
index c7505037..264d7912 100644
--- a/backend/internal/service/token_refresher_test.go
+++ b/backend/internal/service/token_refresher_test.go
@@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
})
}
}
+
+func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
+ refresher := &OpenAITokenRefresher{}
+
+ tests := []struct {
+ name string
+ platform string
+ accType string
+ want bool
+ }{
+ {
+ name: "openai oauth - can refresh",
+ platform: PlatformOpenAI,
+ accType: AccountTypeOAuth,
+ want: true,
+ },
+ {
+ name: "sora oauth - cannot refresh directly",
+ platform: PlatformSora,
+ accType: AccountTypeOAuth,
+ want: false,
+ },
+ {
+ name: "openai apikey - cannot refresh",
+ platform: PlatformOpenAI,
+ accType: AccountTypeAPIKey,
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ account := &Account{
+ Platform: tt.platform,
+ Type: tt.accType,
+ }
+ require.Equal(t, tt.want, refresher.CanRefresh(account))
+ })
+ }
+}
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 5d712f75..652f9e00 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
return NewSoraMediaStorage(cfg)
}
+func ProvideSoraDirectClient(
+ cfg *config.Config,
+ httpUpstream HTTPUpstream,
+ tokenProvider *OpenAITokenProvider,
+ accountRepo AccountRepository,
+ soraAccountRepo SoraAccountRepository,
+) *SoraDirectClient {
+ client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
+ client.SetAccountRepositories(accountRepo, soraAccountRepo)
+ return client
+}
+
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
svc := NewSoraMediaCleanupService(storage, cfg)
@@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet(
NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,
- NewSoraDirectClient,
+ ProvideSoraDirectClient,
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
NewSoraGatewayService,
NewOpenAIGatewayService,
diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go
index 7f37d59c..f7ba5c9e 100644
--- a/backend/internal/web/embed_on.go
+++ b/backend/internal/web/embed_on.go
@@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
+ strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" ||
@@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
+ strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" ||
diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go
index 50f5a323..e2cbcf15 100644
--- a/backend/internal/web/embed_test.go
+++ b/backend/internal/web/embed_test.go
@@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) {
"/api/v1/users",
"/v1/models",
"/v1beta/chat",
+ "/sora/v1/models",
"/antigravity/test",
"/setup/init",
"/health",
@@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) {
"/api/users",
"/v1/models",
"/v1beta/chat",
+ "/sora/v1/models",
"/antigravity/test",
"/setup/init",
"/health",
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 9fd2d391..0ff1ec02 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -388,7 +388,11 @@ sora:
recent_task_limit_max: 200
# Enable debug logs for Sora upstream requests
# 启用 Sora 直连调试日志
+ # 调试日志会输出上游请求尝试、重试、响应摘要;Authorization/openai-sentinel-token 等敏感头会自动脱敏
debug: false
+ # Allow Sora client to fetch token via OpenAI token provider
+ # 是否允许 Sora 客户端通过 OpenAI token provider 取 token(默认 false,避免误走 OpenAI 刷新链路)
+ use_openai_token_provider: false
# Optional custom headers (key-value)
# 额外请求头(键值对)
headers: {}
@@ -431,6 +435,13 @@ sora:
# Cron 调度表达式
schedule: "0 3 * * *"
+# Token refresh behavior
+# token 刷新行为控制
+token_refresh:
+ # Whether OpenAI refresh flow is allowed to sync linked Sora accounts
+ # 是否允许 OpenAI 刷新流程同步覆盖 linked_openai_account_id 关联的 Sora 账号 token
+ sync_linked_sora_accounts: false
+
# =============================================================================
# API Key Auth Cache Configuration
# API Key 认证缓存配置
diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts
index 36bec4e7..e1f502ec 100644
--- a/frontend/src/api/admin/accounts.ts
+++ b/frontend/src/api/admin/accounts.ts
@@ -220,7 +220,7 @@ export async function generateAuthUrl(
*/
export async function exchangeCode(
endpoint: string,
- exchangeData: { session_id: string; code: string; proxy_id?: number }
+ exchangeData: { session_id: string; code: string; state?: string; proxy_id?: number }
): Promise> {
const { data } = await apiClient.post>(endpoint, exchangeData)
return data
@@ -442,7 +442,8 @@ export async function getAntigravityDefaultModelMapping(): Promise> {
const payload: { refresh_token: string; proxy_id?: number } = {
refresh_token: refreshToken
@@ -450,7 +451,29 @@ export async function refreshOpenAIToken(
if (proxyId) {
payload.proxy_id = proxyId
}
- const { data } = await apiClient.post>('/admin/openai/refresh-token', payload)
+ const { data } = await apiClient.post>(endpoint, payload)
+ return data
+}
+
+/**
+ * Validate Sora session token and exchange to access token
+ * @param sessionToken - Sora session token
+ * @param proxyId - Optional proxy ID
+ * @param endpoint - API endpoint path
+ * @returns Token information including access_token
+ */
+export async function validateSoraSessionToken(
+ sessionToken: string,
+ proxyId?: number | null,
+ endpoint: string = '/admin/sora/st2at'
+): Promise> {
+ const payload: { session_token: string; proxy_id?: number } = {
+ session_token: sessionToken
+ }
+ if (proxyId) {
+ payload.proxy_id = proxyId
+ }
+ const { data } = await apiClient.post>(endpoint, payload)
return data
}
@@ -475,6 +498,7 @@ export const accountsAPI = {
generateAuthUrl,
exchangeCode,
refreshOpenAIToken,
+ validateSoraSessionToken,
batchCreate,
batchUpdateCredentials,
bulkUpdate,
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index 85785d6a..8024dfb6 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -109,6 +109,28 @@
OpenAI
+