feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径

- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力
- 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程
- 强化 Sora token 恢复、转发日志与网关路由隔离行为
- 补充后端服务层与路由层相关测试覆盖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-19 08:02:56 +08:00
parent 36bb327024
commit 900cce20a1
39 changed files with 2561 additions and 283 deletions

View File

@@ -162,6 +162,8 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒) // 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token默认关闭
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
} }
type PricingConfig struct { type PricingConfig struct {
@@ -277,6 +279,7 @@ type SoraClientConfig struct {
RecentTaskLimit int `mapstructure:"recent_task_limit"` RecentTaskLimit int `mapstructure:"recent_task_limit"`
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
Debug bool `mapstructure:"debug"` Debug bool `mapstructure:"debug"`
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
Headers map[string]string `mapstructure:"headers"` Headers map[string]string `mapstructure:"headers"`
UserAgent string `mapstructure:"user_agent"` UserAgent string `mapstructure:"user_agent"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
@@ -1116,6 +1119,7 @@ func setDefaults() {
viper.SetDefault("sora.client.recent_task_limit", 50) viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200) viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false) viper.SetDefault("sora.client.debug", false)
viper.SetDefault("sora.client.use_openai_token_provider", false)
viper.SetDefault("sora.client.headers", map[string]string{}) viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false) viper.SetDefault("sora.client.disable_tls_fingerprint", false)
@@ -1137,6 +1141,7 @@ func setDefaults() {
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新适配Google 1小时token viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新适配Google 1小时token
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file // Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET

View File

@@ -1333,6 +1333,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return return
} }
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}
// Handle Claude/Anthropic accounts // Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models // For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() { if account.IsOAuth() {

View File

@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService adminService service.AdminService
} }
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{ return &OpenAIOAuthHandler{
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct { type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct { type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"` RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
// RefreshToken refreshes an OpenAI OAuth token // RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token // POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
return return
} }
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string var proxyURL string
if req.ProxyID != nil { if req.ProxyID != nil {
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
} }
} }
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo) response.Success(c, tokenInfo)
} }
// RefreshAccountToken refreshes token for a specific OpenAI account // ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh // POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return return
} }
// Ensure account is OpenAI platform platform := oauthPlatformFromPath(c)
if !account.IsOpenAI() { if account.Platform != platform {
response.BadRequest(c, "Account is not an OpenAI account") response.BadRequest(c, "Account platform does not match OAuth endpoint")
return return
} }
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount)) response.Success(c, dto.AccountFromService(updatedAccount))
} }
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info // CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth // POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct { var req struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"` Name string `json:"name"`
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info // Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided // Use email as default name if not provided
name := req.Name name := req.Name
if name == "" && tokenInfo.Email != "" { if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email name = tokenInfo.Email
} }
if name == "" { if name == "" {
if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account" name = "OpenAI OAuth Account"
} }
}
// Create account // Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name, Name: name,
Platform: "openai", Platform: platform,
Type: "oauth", Type: "oauth",
Credentials: credentials, Credentials: credentials,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,

View File

@@ -212,6 +212,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
var lastFailoverBody []byte
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
@@ -224,7 +225,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
return return
} }
account := selection.Account account := selection.Account
@@ -287,14 +288,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
failedAccountIDs[account.ID] = struct{}{} failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) lastFailoverBody = failoverErr.ResponseBody
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
return return
} }
lastFailoverStatus = failoverErr.StatusCode lastFailoverStatus = failoverErr.StatusCode
lastFailoverBody = failoverErr.ResponseBody
switchCount++ switchCount++
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
reqLog.Warn("sora.upstream_failover_switching", reqLog.Warn("sora.upstream_failover_switching",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode), zap.Int("upstream_status", failoverErr.StatusCode),
zap.String("upstream_error_code", upstreamErrCode),
zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount), zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches), zap.Int("max_switches", maxAccountSwitches),
) )
@@ -360,17 +366,32 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseBody []byte, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode) status, errType, errMsg := h.mapUpstreamError(statusCode, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
} }
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byte) (int, string, string) {
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if upstreamMessage != "" {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", upstreamMessage
case 429:
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
}
}
switch statusCode { switch statusCode {
case 401: case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403: case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 404:
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
}
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429: case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529: case 529:
@@ -382,6 +403,41 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
} }
} }
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !gjson.Valid(trimmed) {
return "", truncateSoraErrorMessage(trimmed, 256)
}
code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
if code == "" {
code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
}
message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
}
return code, truncateSoraErrorMessage(message, 512)
}
func truncateSoraErrorMessage(s string, maxLen int) string {
if maxLen <= 0 {
return ""
}
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...(truncated)"
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted { if streamStarted {
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)

View File

@@ -43,6 +43,9 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) { func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil return "task-video", nil
} }
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) { func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
} }

View File

@@ -17,6 +17,8 @@ import (
const ( const (
// OAuth Client ID for OpenAI (Codex CLI official) // OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints // OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize" AuthorizeURL = "https://auth.openai.com/oauth/authorize"

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
} }
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
if strings.TrimSpace(clientID) != "" {
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
}
clientIDs := []string{
openai.ClientID,
openai.SoraClientID,
}
seen := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
clientID = strings.TrimSpace(clientID)
if clientID == "" {
continue
}
if _, ok := seen[clientID]; ok {
continue
}
seen[clientID] = struct{}{}
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err == nil {
return tokenResp, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL) client := createOpenAIReqClient(proxyURL)
formData := url.Values{} formData := url.Values{}
formData.Set("grant_type", "refresh_token") formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken) formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID) formData.Set("client_id", clientID)
formData.Set("scope", openai.RefreshScopes) formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse var tokenResp openai.TokenResponse

View File

@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken) require.Equal(s.T(), "rt2", resp.RefreshToken)
} }
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.ClientID {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "invalid_grant")
return
}
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID != customClientID {
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-custom", resp.AccessToken)
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)

View File

@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth // OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h) registerOpenAIOAuthRoutes(admin, h)
// Sora OAuth实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes(admin, h)
// Gemini OAuth // Gemini OAuth
registerGeminiOAuthRoutes(admin, h) registerGeminiOAuthRoutes(admin, h)
@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini") gemini := admin.Group("/gemini")
{ {

View File

@@ -1,6 +1,8 @@
package routes package routes
import ( import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage) gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API // OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses", h.OpenAIGateway.Responses)
} // 明确阻止旧入口误用到 Sora避免客户端把 OpenAI Chat Completions 当作 Sora 入口
gateway.POST("/chat/completions", func(c *gin.Context) {
// Sora Chat Completions c.JSON(http.StatusBadRequest, gin.H{
soraGateway := r.Group("/v1") "error": gin.H{
soraGateway.Use(soraBodyLimit) "type": "invalid_request_error",
soraGateway.Use(clientRequestID) "message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.",
soraGateway.Use(opsErrorLogger) },
soraGateway.Use(gin.HandlerFunc(apiKeyAuth)) })
{ })
soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
} }
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层Gemini SDK/CLI 直连)

View File

@@ -27,11 +27,13 @@ import (
// sseDataPrefix matches SSE data lines with optional whitespace after colon. // sseDataPrefix matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataPrefix = regexp.MustCompile(`^data:\s*`) var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
const ( const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages" testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
) )
// TestEvent represents a SSE event for account testing // TestEvent represents a SSE event for account testing
@@ -502,8 +504,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
@@ -512,7 +515,10 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body))) if isCloudflareChallengeResponse(resp.StatusCode, body) {
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
} }
// 解析 /me 响应,提取用户信息 // 解析 /me 响应,提取用户信息
@@ -531,10 +537,129 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
s.sendEvent(c, TestEvent{Type: "content", Text: info}) s.sendEvent(c, TestEvent{Type: "content", Text: info})
} }
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
if err == nil {
subReq.Header.Set("Authorization", "Bearer "+authToken)
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
subReq.Header.Set("Accept", "application/json")
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if subErr != nil {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
}
}
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil return nil
} }
func parseSoraSubscriptionSummary(body []byte) string {
var subResp struct {
Data []struct {
Plan struct {
ID string `json:"id"`
Title string `json:"title"`
} `json:"plan"`
EndTS string `json:"end_ts"`
} `json:"data"`
}
if err := json.Unmarshal(body, &subResp); err != nil {
return ""
}
if len(subResp.Data) == 0 {
return ""
}
first := subResp.Data[0]
parts := make([]string, 0, 3)
if first.Plan.Title != "" {
parts = append(parts, first.Plan.Title)
}
if first.Plan.ID != "" {
parts = append(parts, first.Plan.ID)
}
if first.EndTS != "" {
parts = append(parts, "end="+first.EndTS)
}
if len(parts) == 0 {
return ""
}
return "Subscription: " + strings.Join(parts, " | ")
}
func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
if s == nil || s.cfg == nil {
return false
}
return s.cfg.Gateway.TLSFingerprint.Enabled && !s.cfg.Sora.Client.DisableTLSFingerprint
}
func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
if statusCode != http.StatusForbidden {
return false
}
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
return strings.Contains(preview, "window._cf_chl_opt") ||
strings.Contains(preview, "just a moment") ||
strings.Contains(preview, "enable javascript and cookies to continue")
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := extractCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := truncateSoraErrorBody(body, 8192)
matches := cloudflareRayPattern.FindStringSubmatch(preview)
if len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
}
func truncateSoraErrorBody(body []byte, max int) string {
if max <= 0 {
max = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= max {
return raw
}
return raw[:max] + "...(truncated)"
}
// testAntigravityAccountConnection tests an Antigravity account's connection // testAntigravityAccountConnection tests an Antigravity account's connection
// 支持 Claude 和 Gemini 两种协议,使用非流式请求 // 支持 Claude 和 Gemini 两种协议,使用非流式请求
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {

View File

@@ -0,0 +1,193 @@
package service
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
resp := newJSONResponse(status, body)
resp.Header.Set(key, value)
return resp
}
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{
Gateway: config.GatewayConfig{
TLSFingerprint: config.TLSFingerprintConfig{
Enabled: true,
},
},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
DisableTLSFingerprint: false,
},
},
},
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 2)
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
require.Equal(t, []bool{true, true}, upstream.tlsFlags)
body := rec.Body.String()
require.Contains(t, body, `"type":"test_start"`)
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 2)
body := rec.Body.String()
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
body := rec.Body.String()
require.Contains(t, body, `"type":"error"`)
require.Contains(t, body, "Cloudflare challenge")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
require.Contains(t, body, `"type":"test_complete","success":true`)
}

View File

@@ -14,6 +14,7 @@ import (
type OpenAIOAuthClient interface { type OpenAIOAuthClient interface {
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
} }
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows

View File

@@ -2,13 +2,20 @@ package service
import ( import (
"context" "context"
"crypto/subtle"
"encoding/json"
"io"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
// OpenAIOAuthService handles OpenAI OAuth authentication flows // OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct { type OpenAIOAuthService struct {
sessionStore *openai.SessionStore sessionStore *openai.SessionStore
@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
type OpenAIExchangeCodeInput struct { type OpenAIExchangeCodeInput struct {
SessionID string SessionID string
Code string Code string
State string
RedirectURI string RedirectURI string
ProxyID *int64 ProxyID *int64
} }
@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
if !ok { if !ok {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
} }
if input.State == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
}
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
}
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL proxyURL := session.ProxyURL
@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// RefreshToken refreshes an OpenAI OAuth token // RefreshToken refreshes an OpenAI OAuth token
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return tokenInfo, nil return tokenInfo, nil
} }
// RefreshAccountToken refreshes token for an OpenAI account // ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() { if strings.TrimSpace(sessionToken) == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
} }
refreshToken := account.GetOpenAIRefreshToken() proxyURL, err := s.resolveProxyURL(ctx, proxyID)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
if err != nil {
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
}
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
req.Header.Set("Accept", "application/json")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
client := newOpenAIOAuthHTTPClient(proxyURL)
resp, err := client.Do(req)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if resp.StatusCode != http.StatusOK {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var sessionResp struct {
AccessToken string `json:"accessToken"`
Expires string `json:"expires"`
User struct {
Email string `json:"email"`
Name string `json:"name"`
} `json:"user"`
}
if err := json.Unmarshal(body, &sessionResp); err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
}
if strings.TrimSpace(sessionResp.AccessToken) == "" {
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
}
expiresAt := time.Now().Add(time.Hour).Unix()
if strings.TrimSpace(sessionResp.Expires) != "" {
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
expiresAt = parsed.Unix()
}
}
expiresIn := expiresAt - time.Now().Unix()
if expiresIn < 0 {
expiresIn = 0
}
return &OpenAITokenInfo{
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
ExpiresIn: expiresIn,
ExpiresAt: expiresAt,
Email: strings.TrimSpace(sessionResp.User.Email),
}, nil
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
}
if account.Type != AccountTypeOAuth {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
}
refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" { if refreshToken == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
} }
@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
} }
} }
return s.RefreshToken(ctx, refreshToken, proxyURL) clientID := account.GetCredential("client_id")
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
} }
// BuildAccountCredentials builds credentials map from token info // BuildAccountCredentials builds credentials map from token info
@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
func (s *OpenAIOAuthService) Stop() { func (s *OpenAIOAuthService) Stop() {
s.sessionStore.Stop() s.sessionStore.Stop()
} }
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
if proxyID == nil {
return "", nil
}
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err != nil {
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy == nil {
return "", nil
}
return proxy.URL(), nil
}
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
transport := &http.Transport{}
if strings.TrimSpace(proxyURL) != "" {
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
transport.Proxy = http.ProxyURL(parsed)
}
}
return &http.Client{
Timeout: 120 * time.Second,
Transport: transport,
}
}

View File

@@ -0,0 +1,69 @@
package service
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type openaiOAuthClientNoopStub struct{}
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at-token", info.AccessToken)
require.Equal(t, "demo@example.com", info.Email)
require.Greater(t, info.ExpiresAt, int64(0))
}
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
}))
defer server.Close()
origin := openAISoraSessionAuthURL
openAISoraSessionAuthURL = server.URL
defer func() { openAISoraSessionAuthURL = origin }()
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
defer svc.Stop()
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}

View File

@@ -0,0 +1,102 @@
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
type openaiOAuthClientStateStub struct {
exchangeCalled int32
}
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.exchangeCalled, 1)
return &openai.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
ExpiresIn: 3600,
}, nil
}
func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
return s.RefreshToken(ctx, refreshToken, proxyURL)
}
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
})
require.Error(t, err)
require.Contains(t, err.Error(), "oauth state is required")
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
}
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
State: "wrong-state",
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid oauth state")
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
}
func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
State: "expected-state",
})
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at", info.AccessToken)
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
_, ok := svc.sessionStore.Get("sid")
require.False(t, ok)
}

View File

@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
} }
expiresAt = account.GetCredentialAsTime("expires_at") expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil { if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
refreshFailed = true
} else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1) p.metrics.refreshFailure.Add(1)
refreshFailed = true // 无法刷新,标记失败 refreshFailed = true // 无法刷新,标记失败
@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新 // 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil { if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
refreshFailed = true
} else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1) p.metrics.refreshFailure.Add(1)
refreshFailed = true refreshFailed = true

View File

@@ -17,12 +17,15 @@ import (
"net/textproto" "net/textproto"
"net/url" "net/url"
"path" "path"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
@@ -34,6 +37,11 @@ const (
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
) )
var (
soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
soraOAuthTokenURL = "https://auth.openai.com/oauth/token"
)
const ( const (
soraPowMaxIteration = 500000 soraPowMaxIteration = 500000
) )
@@ -96,6 +104,7 @@ type SoraClient interface {
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
} }
@@ -160,23 +169,91 @@ type SoraDirectClient struct {
cfg *config.Config cfg *config.Config
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
tokenProvider *OpenAITokenProvider tokenProvider *OpenAITokenProvider
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository
baseURL string
} }
// NewSoraDirectClient 创建 Sora 直连客户端 // NewSoraDirectClient 创建 Sora 直连客户端
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient { func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
baseURL := ""
if cfg != nil {
rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/")
baseURL = normalizeSoraBaseURL(rawBaseURL)
if rawBaseURL != "" && baseURL != rawBaseURL {
log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL))
}
}
return &SoraDirectClient{ return &SoraDirectClient{
cfg: cfg, cfg: cfg,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
tokenProvider: tokenProvider, tokenProvider: tokenProvider,
baseURL: baseURL,
} }
} }
func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
if c == nil {
return
}
c.accountRepo = accountRepo
c.soraAccountRepo = soraAccountRepo
}
// Enabled 判断是否启用 Sora 直连 // Enabled 判断是否启用 Sora 直连
func (c *SoraDirectClient) Enabled() bool { func (c *SoraDirectClient) Enabled() bool {
if c == nil || c.cfg == nil { if c == nil {
return false return false
} }
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != "" if strings.TrimSpace(c.baseURL) != "" {
return true
}
if c.cfg == nil {
return false
}
return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != ""
}
// PreflightCheck 在创建任务前执行账号能力预检。
// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。
func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
if modelCfg.Type != "video" {
return nil
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
headers.Set("Accept", "application/json")
body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
if err != nil {
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
return &SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "当前账号未开通 Sora2 能力或无可用配额",
Headers: upstreamErr.Headers,
Body: upstreamErr.Body,
}
}
return err
}
rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool()
remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining")
if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) {
msg := "当前账号 Sora2 可用配额不足"
if requestedModel != "" {
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
}
return &SoraUpstreamError{
StatusCode: http.StatusTooManyRequests,
Message: msg,
Headers: http.Header{},
}
}
return nil
} }
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
@@ -347,6 +424,45 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
return taskID, nil return taskID, nil
} }
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
if strings.TrimSpace(expansionLevel) == "" {
expansionLevel = "medium"
}
if durationS <= 0 {
durationS = 10
}
payload := map[string]any{
"prompt": prompt,
"expansion_level": expansionLevel,
"duration_s": durationS,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
headers.Set("Content-Type", "application/json")
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
if err != nil {
return "", err
}
enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String())
if enhancedPrompt == "" {
return "", errors.New("enhance_prompt response missing enhanced_prompt")
}
return enhancedPrompt, nil
}
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit()) status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
if err != nil { if err != nil {
@@ -512,9 +628,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
} }
func (c *SoraDirectClient) buildURL(endpoint string) string { func (c *SoraDirectClient) buildURL(endpoint string) string {
base := "" base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/")
if c != nil && c.cfg != nil { if base == "" && c != nil && c.cfg != nil {
base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/") base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)
c.baseURL = base
} }
if base == "" { if base == "" {
return endpoint return endpoint
@@ -540,14 +657,257 @@ func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account)
if account == nil { if account == nil {
return "", errors.New("account is nil") return "", errors.New("account is nil")
} }
if c.tokenProvider != nil {
return c.tokenProvider.GetAccessToken(ctx, account) allowProvider := c.allowOpenAITokenProvider(account)
var providerErr error
if allowProvider && c.tokenProvider != nil {
token, err := c.tokenProvider.GetAccessToken(ctx, account)
if err == nil && strings.TrimSpace(token) != "" {
c.logTokenSource(account, "openai_token_provider")
return token, nil
}
providerErr = err
if err != nil && c.debugEnabled() {
c.debugLogf(
"token_provider_failed account_id=%d platform=%s err=%s",
account.ID,
account.Platform,
logredact.RedactText(err.Error()),
)
}
} }
token := strings.TrimSpace(account.GetCredential("access_token")) token := strings.TrimSpace(account.GetCredential("access_token"))
if token == "" { if token != "" {
return "", errors.New("access_token not found") expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
c.logTokenSource(account, "refresh_token_recovered")
return refreshed, nil
} }
if refreshErr != nil && c.debugEnabled() {
c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error()))
}
}
c.logTokenSource(account, "account_credentials")
return token, nil return token, nil
}
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
c.logTokenSource(account, "session_or_refresh_recovered")
return recovered, nil
}
if recoverErr != nil && c.debugEnabled() {
c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error()))
}
if providerErr != nil {
return "", providerErr
}
if c.tokenProvider != nil && !allowProvider {
c.logTokenSource(account, "account_credentials(provider_disabled)")
}
return "", errors.New("access_token not found")
}
func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
if err == nil && strings.TrimSpace(accessToken) != "" {
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
c.logTokenRecover(account, "session_token", reason, true, nil)
return accessToken, nil
}
c.logTokenRecover(account, "session_token", reason, false, err)
}
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
if refreshToken == "" {
return "", errors.New("session_token/refresh_token not found")
}
accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken)
if err != nil {
c.logTokenRecover(account, "refresh_token", reason, false, err)
return "", err
}
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("refreshed access_token is empty")
}
c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "")
c.logTokenRecover(account, "refresh_token", reason, true, nil)
return accessToken, nil
}
func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
headers := http.Header{}
headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
headers.Set("Accept", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
headers.Set("User-Agent", c.defaultUserAgent())
body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false)
if err != nil {
return "", "", err
}
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
if accessToken == "" {
return "", "", errors.New("session exchange missing accessToken")
}
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
return accessToken, expiresAt, nil
}
func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) {
clientIDs := []string{
strings.TrimSpace(account.GetCredential("client_id")),
openaioauth.SoraClientID,
openaioauth.ClientID,
}
tried := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
if clientID == "" {
continue
}
if _, ok := tried[clientID]; ok {
continue
}
tried[clientID] = struct{}{}
payload := map[string]any{
"client_id": clientID,
"grant_type": "refresh_token",
"refresh_token": refreshToken,
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", "", "", err
}
headers := http.Header{}
headers.Set("Accept", "application/json")
headers.Set("Content-Type", "application/json")
headers.Set("User-Agent", c.defaultUserAgent())
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false)
if err != nil {
lastErr = err
if c.debugEnabled() {
c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error()))
}
continue
}
accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String())
if accessToken == "" {
lastErr = errors.New("oauth refresh response missing access_token")
continue
}
newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String())
expiresIn := gjson.GetBytes(respBody, "expires_in").Int()
expiresAt := ""
if expiresIn > 0 {
expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
}
return accessToken, newRefreshToken, expiresAt, nil
}
if lastErr != nil {
return "", "", "", lastErr
}
return "", "", "", errors.New("no available client_id for refresh_token exchange")
}
func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
if account == nil {
return
}
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
if strings.TrimSpace(accessToken) != "" {
account.Credentials["access_token"] = accessToken
}
if strings.TrimSpace(refreshToken) != "" {
account.Credentials["refresh_token"] = refreshToken
}
if strings.TrimSpace(expiresAt) != "" {
account.Credentials["expires_at"] = expiresAt
}
if strings.TrimSpace(sessionToken) != "" {
account.Credentials["session_token"] = sessionToken
}
if c.accountRepo != nil {
if err := c.accountRepo.Update(ctx, account); err != nil {
if c.debugEnabled() {
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
}
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
}
func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
return
}
updates := make(map[string]any)
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
updates["access_token"] = accessToken
updates["refresh_token"] = refreshToken
}
if strings.TrimSpace(sessionToken) != "" {
updates["session_token"] = sessionToken
}
if len(updates) == 0 {
return
}
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
}
}
func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) {
if !c.debugEnabled() || account == nil {
return
}
if success {
c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
return
}
if err == nil {
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
return
}
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error()))
}
func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool {
if c == nil || c.tokenProvider == nil {
return false
}
if account != nil && account.Platform == PlatformSora {
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
}
return true
}
func (c *SoraDirectClient) logTokenSource(account *Account, source string) {
if !c.debugEnabled() || account == nil {
return
}
c.debugLogf(
"token_selected account_id=%d platform=%s account_type=%s source=%s",
account.ID,
account.Platform,
account.Type,
source,
)
} }
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header { func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
@@ -600,7 +960,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
} }
attempts := maxRetries + 1 attempts := maxRetries + 1
authRecovered := false
authRecoverExtraAttemptGranted := false
var lastErr error
for attempt := 1; attempt <= attempts; attempt++ { for attempt := 1; attempt <= attempts; attempt++ {
if c.debugEnabled() {
c.debugLogf(
"request_start method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t headers=%s",
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
timeout,
len(bodyBytes),
account != nil && account.ProxyID != nil && account.Proxy != nil,
formatSoraHeaders(headers),
)
}
var reader io.Reader var reader io.Reader
if bodyBytes != nil { if bodyBytes != nil {
reader = bytes.NewReader(bodyBytes) reader = bytes.NewReader(bodyBytes)
@@ -618,7 +995,21 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
} }
resp, err := c.doHTTP(req, proxyURL, account) resp, err := c.doHTTP(req, proxyURL, account)
if err != nil { if err != nil {
lastErr = err
if c.debugEnabled() {
c.debugLogf(
"request_transport_error method=%s url=%s attempt=%d/%d err=%s",
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
logredact.RedactText(err.Error()),
)
}
if attempt < attempts && allowRetry { if attempt < attempts && allowRetry {
if c.debugEnabled() {
c.debugLogf("request_retry_scheduled method=%s url=%s reason=transport_error next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), attempt+1, attempts)
}
c.sleepRetry(attempt) c.sleepRetry(attempt)
continue continue
} }
@@ -632,12 +1023,53 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
} }
if c.cfg != nil && c.cfg.Sora.Client.Debug { if c.cfg != nil && c.cfg.Sora.Client.Debug {
log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start)) c.debugLogf(
"response_received method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s",
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
resp.StatusCode,
time.Since(start),
len(respBody),
formatSoraHeaders(resp.Header),
)
} }
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody) if !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil {
if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" {
headers.Set("Authorization", "Bearer "+recovered)
authRecovered = true
if attempt == attempts && !authRecoverExtraAttemptGranted {
attempts++
authRecoverExtraAttemptGranted = true
}
if c.debugEnabled() {
c.debugLogf("request_retry_with_recovered_token method=%s url=%s status=%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode)
}
continue
} else if recoverErr != nil && c.debugEnabled() {
c.debugLogf("request_recover_token_failed method=%s url=%s status=%d err=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error()))
}
}
if c.debugEnabled() {
c.debugLogf(
"response_non_success method=%s url=%s attempt=%d/%d status=%d body=%s",
method,
sanitizeSoraLogURL(urlStr),
attempt,
attempts,
resp.StatusCode,
summarizeSoraResponseBody(respBody, 512),
)
}
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr)
lastErr = upstreamErr
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) { if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
if c.debugEnabled() {
c.debugLogf("request_retry_scheduled method=%s url=%s reason=status_%d next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
}
c.sleepRetry(attempt) c.sleepRetry(attempt)
continue continue
} }
@@ -645,9 +1077,34 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
} }
return respBody, resp.Header, nil return respBody, resp.Header, nil
} }
if lastErr != nil {
return nil, nil, lastErr
}
return nil, nil, errors.New("upstream retries exhausted") return nil, nil, errors.New("upstream retries exhausted")
} }
func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
switch statusCode {
case http.StatusUnauthorized, http.StatusForbidden:
parsed, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil {
return false
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return false
}
// 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。
path := strings.ToLower(strings.TrimSpace(parsed.Path))
if path == "/api/auth/session" {
return false
}
return true
default:
return false
}
}
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
if c.httpUpstream != nil { if c.httpUpstream != nil {
@@ -670,9 +1127,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) {
time.Sleep(backoff) time.Sleep(backoff)
} }
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error { func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error {
msg := strings.TrimSpace(extractUpstreamErrorMessage(body)) msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
msg = sanitizeUpstreamErrorMessage(msg) msg = sanitizeUpstreamErrorMessage(msg)
if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") {
if hint := soraBaseURLNotFoundHint(requestURL); hint != "" {
msg = strings.TrimSpace(msg + " " + hint)
}
}
if msg == "" { if msg == "" {
msg = truncateForLog(body, 256) msg = truncateForLog(body, 256)
} }
@@ -684,6 +1146,45 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b
} }
} }
func normalizeSoraBaseURL(raw string) string {
trimmed := strings.TrimRight(strings.TrimSpace(raw), "/")
if trimmed == "" {
return ""
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return trimmed
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return trimmed
}
pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/")
switch pathVal {
case "", "/":
parsed.Path = "/backend"
case "/backend-api":
parsed.Path = "/backend"
}
return strings.TrimRight(parsed.String(), "/")
}
func soraBaseURLNotFoundHint(requestURL string) string {
parsed, err := url.Parse(strings.TrimSpace(requestURL))
if err != nil || parsed.Host == "" {
return ""
}
host := strings.ToLower(parsed.Hostname())
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
return ""
}
pathVal := strings.TrimSpace(parsed.Path)
if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" {
return ""
}
return "(请检查 sora.client.base_url建议配置为 https://sora.chatgpt.com/backend)"
}
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) { func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
reqID := uuid.NewString() reqID := uuid.NewString()
userAgent := soraRandChoice(soraDesktopUserAgents) userAgent := soraRandChoice(soraDesktopUserAgents)
@@ -901,3 +1402,70 @@ func sanitizeSoraLogURL(raw string) string {
parsed.RawQuery = q.Encode() parsed.RawQuery = q.Encode()
return parsed.String() return parsed.String()
} }
func (c *SoraDirectClient) debugEnabled() bool {
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
}
func (c *SoraDirectClient) debugLogf(format string, args ...any) {
if !c.debugEnabled() {
return
}
log.Printf("[SoraClient] "+format, args...)
}
func formatSoraHeaders(headers http.Header) string {
if len(headers) == 0 {
return "{}"
}
keys := make([]string, 0, len(headers))
for key := range headers {
keys = append(keys, key)
}
sort.Strings(keys)
out := make(map[string]string, len(keys))
for _, key := range keys {
values := headers.Values(key)
if len(values) == 0 {
continue
}
val := strings.Join(values, ",")
if isSensitiveHeader(key) {
out[key] = "***"
continue
}
out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160)
}
encoded, err := json.Marshal(out)
if err != nil {
return "{}"
}
return string(encoded)
}
func isSensitiveHeader(key string) bool {
k := strings.ToLower(strings.TrimSpace(key))
switch k {
case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key":
return true
default:
return false
}
}
func summarizeSoraResponseBody(body []byte, maxLen int) string {
if len(body) == 0 {
return ""
}
var text string
if json.Valid(body) {
text = logredact.RedactJSON(body)
} else {
text = logredact.RedactText(string(body))
}
text = strings.TrimSpace(text)
if maxLen <= 0 || len(text) <= maxLen {
return text
}
return text[:maxLen] + "...(truncated)"
}

View File

@@ -4,9 +4,13 @@ package service
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"sync/atomic"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
require.Equal(t, "completed", status.Status) require.Equal(t, "completed", status.Status)
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs) require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
} }
func TestNormalizeSoraBaseURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw string
want string
}{
{
name: "empty",
raw: "",
want: "",
},
{
name: "append_backend_for_sora_host",
raw: "https://sora.chatgpt.com",
want: "https://sora.chatgpt.com/backend",
},
{
name: "convert_backend_api_to_backend",
raw: "https://sora.chatgpt.com/backend-api",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_backend",
raw: "https://sora.chatgpt.com/backend",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_custom_host",
raw: "https://example.com/custom-path",
want: "https://example.com/custom-path",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeSoraBaseURL(tt.raw)
require.Equal(t, tt.want, got)
})
}
}
func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com",
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
}
func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
t.Parallel()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
require.ErrorAs(t, errNoHint, &upstreamErr)
require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
}
func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
t.Parallel()
headers := http.Header{}
headers.Set("Authorization", "Bearer secret-token")
headers.Set("openai-sentinel-token", "sentinel-secret")
headers.Set("X-Test", "ok")
out := formatSoraHeaders(headers)
require.Contains(t, out, `"Authorization":"***"`)
require.Contains(t, out, `Sentinel-Token":"***"`)
require.Contains(t, out, `"X-Test":"ok"`)
require.NotContains(t, out, "secret-token")
require.NotContains(t, out, "sentinel-secret")
}
func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
t.Parallel()
body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
out := summarizeSoraResponseBody(body, 512)
require.Contains(t, out, `"access_token":"***"`)
require.NotContains(t, out, "abc123")
}
func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
t.Parallel()
body := []byte(strings.Repeat("x", 100))
out := summarizeSoraResponseBody(body, 10)
require.Contains(t, out, "(truncated)")
}
func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "sora-credential-token", token)
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
}
func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
account := &Account{
ID: 2,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
UseOpenAITokenProvider: true,
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "provider-token", token)
require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
}
func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"accessToken": "session-access-token",
"expires": "2099-01-01T00:00:00Z",
})
}))
defer server.Close()
origin := soraSessionAuthURL
soraSessionAuthURL = server.URL
defer func() { soraSessionAuthURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 10,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"session_token": "session-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "session-access-token", token)
require.Equal(t, "session-access-token", account.GetCredential("access_token"))
}
func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/oauth/token", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "refresh-access-token",
"refresh_token": "refresh-token-new",
"expires_in": 3600,
})
}))
defer server.Close()
origin := soraOAuthTokenURL
soraOAuthTokenURL = server.URL + "/oauth/token"
defer func() { soraOAuthTokenURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 11,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "refresh-token-old",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refresh-access-token", token)
require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
require.NotNil(t, account.GetCredentialAsTime("expires_at"))
}
func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Equal(t, "/nf/check", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"rate_limit_and_credit_balance": map[string]any{
"estimated_num_videos_remaining": 0,
"rate_limit_reached": true,
},
})
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{
ID: 12,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ok",
"expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
},
}
err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
require.Error(t, err)
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
}
func TestShouldAttemptSoraTokenRecover(t *testing.T) {
t.Parallel()
require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
}

View File

@@ -61,6 +61,10 @@ type SoraGatewayService struct {
cfg *config.Config cfg *config.Config
} }
type soraPreflightChecker interface {
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
}
func NewSoraGatewayService( func NewSoraGatewayService(
soraClient SoraClient, soraClient SoraClient,
mediaStorage *SoraMediaStorage, mediaStorage *SoraMediaStorage,
@@ -112,11 +116,6 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
return nil, fmt.Errorf("unsupported model: %s", reqModel) return nil, fmt.Errorf("unsupported model: %s", reqModel)
} }
if modelCfg.Type == "prompt_enhance" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
return nil, fmt.Errorf("prompt-enhance not supported")
}
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
if strings.TrimSpace(prompt) == "" { if strings.TrimSpace(prompt) == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
@@ -131,6 +130,41 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
if cancel != nil { if cancel != nil {
defer cancel() defer cancel()
} }
if checker, ok := s.soraClient.(soraPreflightChecker); ok {
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
}
if modelCfg.Type == "prompt_enhance" {
enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
if err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
content := strings.TrimSpace(enhancedPrompt)
if content == "" {
content = prompt
}
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
}
return &ForwardResult{
RequestID: "",
Model: reqModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil
}
var imageData []byte var imageData []byte
imageFilename := "" imageFilename := ""
@@ -267,7 +301,7 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode { switch statusCode {
case 401, 402, 403, 429, 529: case 401, 402, 403, 404, 429, 529:
return true return true
default: default:
return statusCode >= 500 return statusCode >= 500
@@ -460,7 +494,7 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
} }
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode} return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body}
} }
msg := upstreamErr.Message msg := upstreamErr.Message
if override := soraProErrorMessage(model, msg); override != "" { if override := soraProErrorMessage(model, msg); override != "" {

View File

@@ -18,6 +18,8 @@ type stubSoraClientForPoll struct {
videoStatus *SoraVideoTaskStatus videoStatus *SoraVideoTaskStatus
imageCalls int imageCalls int
videoCalls int videoCalls int
enhanced string
enhanceErr error
} }
func (s *stubSoraClientForPoll) Enabled() bool { return true } func (s *stubSoraClientForPoll) Enabled() bool { return true }
@@ -30,6 +32,12 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
return "task-video", nil return "task-video", nil
} }
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
if s.enhanced != "" {
return s.enhanced, s.enhanceErr
}
return "enhanced prompt", s.enhanceErr
}
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
s.imageCalls++ s.imageCalls++
return s.imageStatus, nil return s.imageStatus, nil
@@ -62,6 +70,33 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
require.Equal(t, 1, client.imageCalls) require.Equal(t, 1, client.imageCalls)
} }
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
client := &stubSoraClientForPoll{
enhanced: "cinematic prompt",
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{
ID: 1,
Platform: PlatformSora,
Status: StatusActive,
}
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, "prompt-enhance-short-10s", result.Model)
}
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{ client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{ videoStatus: &SoraVideoTaskStatus{
@@ -178,6 +213,7 @@ func TestSoraProErrorMessage(t *testing.T) {
func TestShouldFailoverUpstreamError(t *testing.T) { func TestShouldFailoverUpstreamError(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc.shouldFailoverUpstreamError(401)) require.True(t, svc.shouldFailoverUpstreamError(401))
require.True(t, svc.shouldFailoverUpstreamError(404))
require.True(t, svc.shouldFailoverUpstreamError(429)) require.True(t, svc.shouldFailoverUpstreamError(429))
require.True(t, svc.shouldFailoverUpstreamError(500)) require.True(t, svc.shouldFailoverUpstreamError(500))
require.True(t, svc.shouldFailoverUpstreamError(502)) require.True(t, svc.shouldFailoverUpstreamError(502))

View File

@@ -17,6 +17,9 @@ type SoraModelConfig struct {
Model string Model string
Size string Size string
RequirePro bool RequirePro bool
// Prompt-enhance 专用参数
ExpansionLevel string
DurationS int
} }
var soraModelConfigs = map[string]SoraModelConfig{ var soraModelConfigs = map[string]SoraModelConfig{
@@ -161,30 +164,48 @@ var soraModelConfigs = map[string]SoraModelConfig{
}, },
"prompt-enhance-short-10s": { "prompt-enhance-short-10s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 10,
}, },
"prompt-enhance-short-15s": { "prompt-enhance-short-15s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 15,
}, },
"prompt-enhance-short-20s": { "prompt-enhance-short-20s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "short",
DurationS: 20,
}, },
"prompt-enhance-medium-10s": { "prompt-enhance-medium-10s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 10,
}, },
"prompt-enhance-medium-15s": { "prompt-enhance-medium-15s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 15,
}, },
"prompt-enhance-medium-20s": { "prompt-enhance-medium-20s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "medium",
DurationS: 20,
}, },
"prompt-enhance-long-10s": { "prompt-enhance-long-10s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 10,
}, },
"prompt-enhance-long-15s": { "prompt-enhance-long-15s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 15,
}, },
"prompt-enhance-long-20s": { "prompt-enhance-long-20s": {
Type: "prompt_enhance", Type: "prompt_enhance",
ExpansionLevel: "long",
DurationS: 20,
}, },
} }

View File

@@ -43,10 +43,13 @@ func NewTokenRefreshService(
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
} }
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
// 注册平台特定的刷新器 // 注册平台特定的刷新器
s.refreshers = []TokenRefresher{ s.refreshers = []TokenRefresher{
NewClaudeTokenRefresher(oauthService), NewClaudeTokenRefresher(oauthService),
NewOpenAITokenRefresher(openaiOAuthService, accountRepo), openAIRefresher,
NewGeminiTokenRefresher(geminiOAuthService), NewGeminiTokenRefresher(geminiOAuthService),
NewAntigravityTokenRefresher(antigravityOAuthService), NewAntigravityTokenRefresher(antigravityOAuthService),
} }

View File

@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
syncLinkedSora bool
} }
// NewOpenAITokenRefresher 创建 OpenAI token刷新器 // NewOpenAITokenRefresher 创建 OpenAI token刷新器
@@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
r.soraAccountRepo = repo r.soraAccountRepo = repo
} }
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
r.syncLinkedSora = enabled
}
// CanRefresh 检查是否能处理此账号 // CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号 // 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) && return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
account.Type == AccountTypeOAuth
} }
// NeedsRefresh 检查token是否需要刷新 // NeedsRefresh 检查token是否需要刷新
@@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
} }
// 异步同步关联的 Sora 账号(不阻塞主流程) // 异步同步关联的 Sora 账号(不阻塞主流程)
if r.accountRepo != nil { if r.accountRepo != nil && r.syncLinkedSora {
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
} }

View File

@@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
}) })
} }
} }
func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
refresher := &OpenAITokenRefresher{}
tests := []struct {
name string
platform string
accType string
want bool
}{
{
name: "openai oauth - can refresh",
platform: PlatformOpenAI,
accType: AccountTypeOAuth,
want: true,
},
{
name: "sora oauth - cannot refresh directly",
platform: PlatformSora,
accType: AccountTypeOAuth,
want: false,
},
{
name: "openai apikey - cannot refresh",
platform: PlatformOpenAI,
accType: AccountTypeAPIKey,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: tt.platform,
Type: tt.accType,
}
require.Equal(t, tt.want, refresher.CanRefresh(account))
})
}
}

View File

@@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
return NewSoraMediaStorage(cfg) return NewSoraMediaStorage(cfg)
} }
func ProvideSoraDirectClient(
cfg *config.Config,
httpUpstream HTTPUpstream,
tokenProvider *OpenAITokenProvider,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
) *SoraDirectClient {
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
client.SetAccountRepositories(accountRepo, soraAccountRepo)
return client
}
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务 // ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService { func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
svc := NewSoraMediaCleanupService(storage, cfg) svc := NewSoraMediaCleanupService(storage, cfg)
@@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet(
NewGatewayService, NewGatewayService,
ProvideSoraMediaStorage, ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService, ProvideSoraMediaCleanupService,
NewSoraDirectClient, ProvideSoraDirectClient,
wire.Bind(new(SoraClient), new(*SoraDirectClient)), wire.Bind(new(SoraClient), new(*SoraDirectClient)),
NewSoraGatewayService, NewSoraGatewayService,
NewOpenAIGatewayService, NewOpenAIGatewayService,

View File

@@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") || if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") || strings.HasPrefix(path, "/setup/") ||
path == "/health" || path == "/health" ||
@@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") || if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") || strings.HasPrefix(path, "/setup/") ||
path == "/health" || path == "/health" ||

View File

@@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) {
"/api/v1/users", "/api/v1/users",
"/v1/models", "/v1/models",
"/v1beta/chat", "/v1beta/chat",
"/sora/v1/models",
"/antigravity/test", "/antigravity/test",
"/setup/init", "/setup/init",
"/health", "/health",
@@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) {
"/api/users", "/api/users",
"/v1/models", "/v1/models",
"/v1beta/chat", "/v1beta/chat",
"/sora/v1/models",
"/antigravity/test", "/antigravity/test",
"/setup/init", "/setup/init",
"/health", "/health",

View File

@@ -388,7 +388,11 @@ sora:
recent_task_limit_max: 200 recent_task_limit_max: 200
# Enable debug logs for Sora upstream requests # Enable debug logs for Sora upstream requests
# 启用 Sora 直连调试日志 # 启用 Sora 直连调试日志
# 调试日志会输出上游请求尝试、重试、响应摘要Authorization/openai-sentinel-token 等敏感头会自动脱敏
debug: false debug: false
# Allow Sora client to fetch token via OpenAI token provider
# 是否允许 Sora 客户端通过 OpenAI token provider 取 token默认 false避免误走 OpenAI 刷新链路)
use_openai_token_provider: false
# Optional custom headers (key-value) # Optional custom headers (key-value)
# 额外请求头(键值对) # 额外请求头(键值对)
headers: {} headers: {}
@@ -431,6 +435,13 @@ sora:
# Cron 调度表达式 # Cron 调度表达式
schedule: "0 3 * * *" schedule: "0 3 * * *"
# Token refresh behavior
# token 刷新行为控制
token_refresh:
# Whether OpenAI refresh flow is allowed to sync linked Sora accounts
# 是否允许 OpenAI 刷新流程同步覆盖 linked_openai_account_id 关联的 Sora 账号 token
sync_linked_sora_accounts: false
# ============================================================================= # =============================================================================
# API Key Auth Cache Configuration # API Key Auth Cache Configuration
# API Key 认证缓存配置 # API Key 认证缓存配置

View File

@@ -220,7 +220,7 @@ export async function generateAuthUrl(
*/ */
export async function exchangeCode( export async function exchangeCode(
endpoint: string, endpoint: string,
exchangeData: { session_id: string; code: string; proxy_id?: number } exchangeData: { session_id: string; code: string; state?: string; proxy_id?: number }
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData) const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData)
return data return data
@@ -442,7 +442,8 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
*/ */
export async function refreshOpenAIToken( export async function refreshOpenAIToken(
refreshToken: string, refreshToken: string,
proxyId?: number | null proxyId?: number | null,
endpoint: string = '/admin/openai/refresh-token'
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
const payload: { refresh_token: string; proxy_id?: number } = { const payload: { refresh_token: string; proxy_id?: number } = {
refresh_token: refreshToken refresh_token: refreshToken
@@ -450,7 +451,29 @@ export async function refreshOpenAIToken(
if (proxyId) { if (proxyId) {
payload.proxy_id = proxyId payload.proxy_id = proxyId
} }
const { data } = await apiClient.post<Record<string, unknown>>('/admin/openai/refresh-token', payload) const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
return data
}
/**
* Validate Sora session token and exchange to access token
* @param sessionToken - Sora session token
* @param proxyId - Optional proxy ID
* @param endpoint - API endpoint path
* @returns Token information including access_token
*/
export async function validateSoraSessionToken(
sessionToken: string,
proxyId?: number | null,
endpoint: string = '/admin/sora/st2at'
): Promise<Record<string, unknown>> {
const payload: { session_token: string; proxy_id?: number } = {
session_token: sessionToken
}
if (proxyId) {
payload.proxy_id = proxyId
}
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
return data return data
} }
@@ -475,6 +498,7 @@ export const accountsAPI = {
generateAuthUrl, generateAuthUrl,
exchangeCode, exchangeCode,
refreshOpenAIToken, refreshOpenAIToken,
validateSoraSessionToken,
batchCreate, batchCreate,
batchUpdateCredentials, batchUpdateCredentials,
bulkUpdate, bulkUpdate,

View File

@@ -109,6 +109,28 @@
</svg> </svg>
OpenAI OpenAI
</button> </button>
<button
type="button"
@click="form.platform = 'sora'"
:class="[
'flex flex-1 items-center justify-center gap-2 rounded-md px-4 py-2.5 text-sm font-medium transition-all',
form.platform === 'sora'
? 'bg-white text-rose-600 shadow-sm dark:bg-dark-600 dark:text-rose-400'
: 'text-gray-600 hover:text-gray-900 dark:text-gray-400 dark:hover:text-gray-200'
]"
>
<svg
class="h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path stroke-linecap="round" stroke-linejoin="round" d="M14.752 11.168l-3.197-2.132A1 1 0 0010 9.87v4.263a1 1 0 001.555.832l3.197-2.132a1 1 0 000-1.664z" />
<path stroke-linecap="round" stroke-linejoin="round" d="M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
Sora
</button>
<button <button
type="button" type="button"
@click="form.platform = 'gemini'" @click="form.platform = 'gemini'"
@@ -150,6 +172,38 @@
</div> </div>
</div> </div>
<!-- Account Type Selection (Sora) -->
<div v-if="form.platform === 'sora'">
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
<div class="mt-2 grid grid-cols-1 gap-3" data-tour="account-form-type">
<button
type="button"
@click="accountCategory = 'oauth-based'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
accountCategory === 'oauth-based'
? 'border-rose-500 bg-rose-50 dark:bg-rose-900/20'
: 'border-gray-200 hover:border-rose-300 dark:border-dark-600 dark:hover:border-rose-700'
]"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'oauth-based'
? 'bg-rose-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="key" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">OAuth</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.chatgptOauth') }}</span>
</div>
</button>
</div>
</div>
<!-- Account Type Selection (Anthropic) --> <!-- Account Type Selection (Anthropic) -->
<div v-if="form.platform === 'anthropic'"> <div v-if="form.platform === 'anthropic'">
<label class="input-label">{{ t('admin.accounts.accountType') }}</label> <label class="input-label">{{ t('admin.accounts.accountType') }}</label>
@@ -1747,32 +1801,6 @@
<!-- Step 2: OAuth Authorization --> <!-- Step 2: OAuth Authorization -->
<div v-else class="space-y-5"> <div v-else class="space-y-5">
<!-- 同时启用 Sora 开关 ( OpenAI OAuth) -->
<div v-if="form.platform === 'openai' && accountCategory === 'oauth-based'" class="mb-4">
<label class="flex items-center justify-between rounded-lg border border-gray-200 p-3 dark:border-dark-600">
<div class="flex items-center gap-3">
<div class="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-rose-100 text-rose-600 dark:bg-rose-900/30 dark:text-rose-400">
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M14.752 11.168l-3.197-2.132A1 1 0 0010 9.87v4.263a1 1 0 001.555.832l3.197-2.132a1 1 0 000-1.664z" />
<path stroke-linecap="round" stroke-linejoin="round" d="M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.accounts.openai.enableSora') }}
</span>
<span class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.enableSoraHint') }}
</span>
</div>
</div>
<label :class="['switch', { 'switch-active': enableSoraOnOpenAIOAuth }]">
<input type="checkbox" v-model="enableSoraOnOpenAIOAuth" class="sr-only" />
<span class="switch-thumb"></span>
</label>
</label>
</div>
<OAuthAuthorizationFlow <OAuthAuthorizationFlow
ref="oauthFlowRef" ref="oauthFlowRef"
:add-method="form.platform === 'anthropic' ? addMethod : 'oauth'" :add-method="form.platform === 'anthropic' ? addMethod : 'oauth'"
@@ -1781,15 +1809,17 @@
:loading="currentOAuthLoading" :loading="currentOAuthLoading"
:error="currentOAuthError" :error="currentOAuthError"
:show-help="form.platform === 'anthropic'" :show-help="form.platform === 'anthropic'"
:show-proxy-warning="form.platform !== 'openai' && !!form.proxy_id" :show-proxy-warning="form.platform !== 'openai' && form.platform !== 'sora' && !!form.proxy_id"
:allow-multiple="form.platform === 'anthropic'" :allow-multiple="form.platform === 'anthropic'"
:show-cookie-option="form.platform === 'anthropic'" :show-cookie-option="form.platform === 'anthropic'"
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'antigravity'" :show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
:show-session-token-option="form.platform === 'sora'"
:platform="form.platform" :platform="form.platform"
:show-project-id="geminiOAuthType === 'code_assist'" :show-project-id="geminiOAuthType === 'code_assist'"
@generate-url="handleGenerateUrl" @generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth" @cookie-auth="handleCookieAuth"
@validate-refresh-token="handleValidateRefreshToken" @validate-refresh-token="handleValidateRefreshToken"
@validate-session-token="handleValidateSessionToken"
/> />
</div> </div>
@@ -2148,6 +2178,7 @@ interface OAuthFlowExposed {
projectId: string projectId: string
sessionKey: string sessionKey: string
refreshToken: string refreshToken: string
sessionToken: string
inputMethod: AuthInputMethod inputMethod: AuthInputMethod
reset: () => void reset: () => void
} }
@@ -2156,7 +2187,7 @@ const { t } = useI18n()
const authStore = useAuthStore() const authStore = useAuthStore()
const oauthStepTitle = computed(() => { const oauthStepTitle = computed(() => {
if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title') if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.oauth.openai.title')
if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title') if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title')
if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title') if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title')
return t('admin.accounts.oauth.title') return t('admin.accounts.oauth.title')
@@ -2164,13 +2195,13 @@ const oauthStepTitle = computed(() => {
// Platform-specific hints for API Key type // Platform-specific hints for API Key type
const baseUrlHint = computed(() => { const baseUrlHint = computed(() => {
if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.baseUrlHint')
if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
return t('admin.accounts.baseUrlHint') return t('admin.accounts.baseUrlHint')
}) })
const apiKeyHint = computed(() => { const apiKeyHint = computed(() => {
if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint') if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.apiKeyHint')
if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint')
return t('admin.accounts.apiKeyHint') return t('admin.accounts.apiKeyHint')
}) })
@@ -2191,34 +2222,36 @@ const appStore = useAppStore()
// OAuth composables // OAuth composables
const oauth = useAccountOAuth() // For Anthropic OAuth const oauth = useAccountOAuth() // For Anthropic OAuth
const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth const openaiOAuth = useOpenAIOAuth({ platform: 'openai' }) // For OpenAI OAuth
const soraOAuth = useOpenAIOAuth({ platform: 'sora' }) // For Sora OAuth
const geminiOAuth = useGeminiOAuth() // For Gemini OAuth const geminiOAuth = useGeminiOAuth() // For Gemini OAuth
const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth
const activeOpenAIOAuth = computed(() => (form.platform === 'sora' ? soraOAuth : openaiOAuth))
// Computed: current OAuth state for template binding // Computed: current OAuth state for template binding
const currentAuthUrl = computed(() => { const currentAuthUrl = computed(() => {
if (form.platform === 'openai') return openaiOAuth.authUrl.value if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.authUrl.value
if (form.platform === 'gemini') return geminiOAuth.authUrl.value if (form.platform === 'gemini') return geminiOAuth.authUrl.value
if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value
return oauth.authUrl.value return oauth.authUrl.value
}) })
const currentSessionId = computed(() => { const currentSessionId = computed(() => {
if (form.platform === 'openai') return openaiOAuth.sessionId.value if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.sessionId.value
if (form.platform === 'gemini') return geminiOAuth.sessionId.value if (form.platform === 'gemini') return geminiOAuth.sessionId.value
if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value
return oauth.sessionId.value return oauth.sessionId.value
}) })
const currentOAuthLoading = computed(() => { const currentOAuthLoading = computed(() => {
if (form.platform === 'openai') return openaiOAuth.loading.value if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.loading.value
if (form.platform === 'gemini') return geminiOAuth.loading.value if (form.platform === 'gemini') return geminiOAuth.loading.value
if (form.platform === 'antigravity') return antigravityOAuth.loading.value if (form.platform === 'antigravity') return antigravityOAuth.loading.value
return oauth.loading.value return oauth.loading.value
}) })
const currentOAuthError = computed(() => { const currentOAuthError = computed(() => {
if (form.platform === 'openai') return openaiOAuth.error.value if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.error.value
if (form.platform === 'gemini') return geminiOAuth.error.value if (form.platform === 'gemini') return geminiOAuth.error.value
if (form.platform === 'antigravity') return antigravityOAuth.error.value if (form.platform === 'antigravity') return antigravityOAuth.error.value
return oauth.error.value return oauth.error.value
@@ -2257,7 +2290,6 @@ const interceptWarmupRequests = ref(false)
const autoPauseOnExpired = ref(true) const autoPauseOnExpired = ref(true)
const openaiPassthroughEnabled = ref(false) const openaiPassthroughEnabled = ref(false)
const codexCLIOnlyEnabled = ref(false) const codexCLIOnlyEnabled = ref(false)
const enableSoraOnOpenAIOAuth = ref(false) // OpenAI OAuth 时同时启用 Sora
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
const upstreamBaseUrl = ref('') // For upstream type: base URL const upstreamBaseUrl = ref('') // For upstream type: base URL
@@ -2398,8 +2430,8 @@ const expiresAtInput = computed({
const canExchangeCode = computed(() => { const canExchangeCode = computed(() => {
const authCode = oauthFlowRef.value?.authCode || '' const authCode = oauthFlowRef.value?.authCode || ''
if (form.platform === 'openai') { if (form.platform === 'openai' || form.platform === 'sora') {
return authCode.trim() && openaiOAuth.sessionId.value && !openaiOAuth.loading.value return authCode.trim() && activeOpenAIOAuth.value.sessionId.value && !activeOpenAIOAuth.value.loading.value
} }
if (form.platform === 'gemini') { if (form.platform === 'gemini') {
return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value
@@ -2459,7 +2491,7 @@ watch(
(newPlatform) => { (newPlatform) => {
// Reset base URL based on platform // Reset base URL based on platform
apiKeyBaseUrl.value = apiKeyBaseUrl.value =
newPlatform === 'openai' (newPlatform === 'openai' || newPlatform === 'sora')
? 'https://api.openai.com' ? 'https://api.openai.com'
: newPlatform === 'gemini' : newPlatform === 'gemini'
? 'https://generativelanguage.googleapis.com' ? 'https://generativelanguage.googleapis.com'
@@ -2485,6 +2517,11 @@ watch(
if (newPlatform !== 'anthropic') { if (newPlatform !== 'anthropic') {
interceptWarmupRequests.value = false interceptWarmupRequests.value = false
} }
if (newPlatform === 'sora') {
accountCategory.value = 'oauth-based'
addMethod.value = 'oauth'
form.type = 'oauth'
}
if (newPlatform !== 'openai') { if (newPlatform !== 'openai') {
openaiPassthroughEnabled.value = false openaiPassthroughEnabled.value = false
codexCLIOnlyEnabled.value = false codexCLIOnlyEnabled.value = false
@@ -2492,6 +2529,7 @@ watch(
// Reset OAuth states // Reset OAuth states
oauth.resetState() oauth.resetState()
openaiOAuth.resetState() openaiOAuth.resetState()
soraOAuth.resetState()
geminiOAuth.resetState() geminiOAuth.resetState()
antigravityOAuth.resetState() antigravityOAuth.resetState()
} }
@@ -2753,7 +2791,6 @@ const resetForm = () => {
autoPauseOnExpired.value = true autoPauseOnExpired.value = true
openaiPassthroughEnabled.value = false openaiPassthroughEnabled.value = false
codexCLIOnlyEnabled.value = false codexCLIOnlyEnabled.value = false
enableSoraOnOpenAIOAuth.value = false
// Reset quota control state // Reset quota control state
windowCostEnabled.value = false windowCostEnabled.value = false
windowCostLimit.value = null windowCostLimit.value = null
@@ -2776,6 +2813,7 @@ const resetForm = () => {
geminiTierAIStudio.value = 'aistudio_free' geminiTierAIStudio.value = 'aistudio_free'
oauth.resetState() oauth.resetState()
openaiOAuth.resetState() openaiOAuth.resetState()
soraOAuth.resetState()
geminiOAuth.resetState() geminiOAuth.resetState()
antigravityOAuth.resetState() antigravityOAuth.resetState()
oauthFlowRef.value?.reset() oauthFlowRef.value?.reset()
@@ -2807,6 +2845,23 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
return Object.keys(extra).length > 0 ? extra : undefined return Object.keys(extra).length > 0 ? extra : undefined
} }
const buildSoraExtra = (
base?: Record<string, unknown>,
linkedOpenAIAccountId?: string | number
): Record<string, unknown> | undefined => {
const extra: Record<string, unknown> = { ...(base || {}) }
if (linkedOpenAIAccountId !== undefined && linkedOpenAIAccountId !== null) {
const id = String(linkedOpenAIAccountId).trim()
if (id) {
extra.linked_openai_account_id = id
}
}
delete extra.openai_passthrough
delete extra.openai_oauth_passthrough
delete extra.codex_cli_only
return Object.keys(extra).length > 0 ? extra : undefined
}
// Helper function to create account with mixed channel warning handling // Helper function to create account with mixed channel warning handling
const doCreateAccount = async (payload: any) => { const doCreateAccount = async (payload: any) => {
submitting.value = true submitting.value = true
@@ -2922,7 +2977,7 @@ const handleSubmit = async () => {
// Determine default base URL based on platform // Determine default base URL based on platform
const defaultBaseUrl = const defaultBaseUrl =
form.platform === 'openai' (form.platform === 'openai' || form.platform === 'sora')
? 'https://api.openai.com' ? 'https://api.openai.com'
: form.platform === 'gemini' : form.platform === 'gemini'
? 'https://generativelanguage.googleapis.com' ? 'https://generativelanguage.googleapis.com'
@@ -2974,14 +3029,15 @@ const goBackToBasicInfo = () => {
step.value = 1 step.value = 1
oauth.resetState() oauth.resetState()
openaiOAuth.resetState() openaiOAuth.resetState()
soraOAuth.resetState()
geminiOAuth.resetState() geminiOAuth.resetState()
antigravityOAuth.resetState() antigravityOAuth.resetState()
oauthFlowRef.value?.reset() oauthFlowRef.value?.reset()
} }
const handleGenerateUrl = async () => { const handleGenerateUrl = async () => {
if (form.platform === 'openai') { if (form.platform === 'openai' || form.platform === 'sora') {
await openaiOAuth.generateAuthUrl(form.proxy_id) await activeOpenAIOAuth.value.generateAuthUrl(form.proxy_id)
} else if (form.platform === 'gemini') { } else if (form.platform === 'gemini') {
await geminiOAuth.generateAuthUrl( await geminiOAuth.generateAuthUrl(
form.proxy_id, form.proxy_id,
@@ -2997,13 +3053,19 @@ const handleGenerateUrl = async () => {
} }
const handleValidateRefreshToken = (rt: string) => { const handleValidateRefreshToken = (rt: string) => {
if (form.platform === 'openai') { if (form.platform === 'openai' || form.platform === 'sora') {
handleOpenAIValidateRT(rt) handleOpenAIValidateRT(rt)
} else if (form.platform === 'antigravity') { } else if (form.platform === 'antigravity') {
handleAntigravityValidateRT(rt) handleAntigravityValidateRT(rt)
} }
} }
const handleValidateSessionToken = (sessionToken: string) => {
if (form.platform === 'sora') {
handleSoraValidateST(sessionToken)
}
}
const formatDateTimeLocal = formatDateTimeLocalInput const formatDateTimeLocal = formatDateTimeLocalInput
const parseDateTimeLocal = parseDateTimeLocalInput const parseDateTimeLocal = parseDateTimeLocalInput
@@ -3039,29 +3101,42 @@ const createAccountAndFinish = async (
// OpenAI OAuth 授权码兑换 // OpenAI OAuth 授权码兑换
const handleOpenAIExchange = async (authCode: string) => { const handleOpenAIExchange = async (authCode: string) => {
if (!authCode.trim() || !openaiOAuth.sessionId.value) return const oauthClient = activeOpenAIOAuth.value
if (!authCode.trim() || !oauthClient.sessionId.value) return
openaiOAuth.loading.value = true oauthClient.loading.value = true
openaiOAuth.error.value = '' oauthClient.error.value = ''
try { try {
const tokenInfo = await openaiOAuth.exchangeAuthCode( const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
if (!stateToUse) {
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
appStore.showError(oauthClient.error.value)
return
}
const tokenInfo = await oauthClient.exchangeAuthCode(
authCode.trim(), authCode.trim(),
openaiOAuth.sessionId.value, oauthClient.sessionId.value,
stateToUse,
form.proxy_id form.proxy_id
) )
if (!tokenInfo) return if (!tokenInfo) return
const credentials = openaiOAuth.buildCredentials(tokenInfo) const credentials = oauthClient.buildCredentials(tokenInfo)
const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
const extra = buildOpenAIExtra(oauthExtra) const extra = buildOpenAIExtra(oauthExtra)
const shouldCreateOpenAI = form.platform === 'openai'
const shouldCreateSora = form.platform === 'sora'
// 应用临时不可调度配置 // 应用临时不可调度配置
if (!applyTempUnschedConfig(credentials)) { if (!applyTempUnschedConfig(credentials)) {
return return
} }
// 1. 创建 OpenAI 账号 let openaiAccountId: string | number | undefined
if (shouldCreateOpenAI) {
const openaiAccount = await adminAPI.accounts.create({ const openaiAccount = await adminAPI.accounts.create({
name: form.name, name: form.name,
notes: form.notes, notes: form.notes,
@@ -3077,29 +3152,21 @@ const handleOpenAIExchange = async (authCode: string) => {
expires_at: form.expires_at, expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value auto_pause_on_expired: autoPauseOnExpired.value
}) })
openaiAccountId = openaiAccount.id
appStore.showSuccess(t('admin.accounts.accountCreated')) appStore.showSuccess(t('admin.accounts.accountCreated'))
}
// 2. 如果启用了 Sora同时创建 Sora 账号 if (shouldCreateSora) {
if (enableSoraOnOpenAIOAuth.value) {
try {
// Sora 使用相同的 OAuth credentials
const soraCredentials = { const soraCredentials = {
access_token: credentials.access_token, access_token: credentials.access_token,
refresh_token: credentials.refresh_token, refresh_token: credentials.refresh_token,
expires_at: credentials.expires_at expires_at: credentials.expires_at
} }
// 建立关联关系 const soraName = shouldCreateOpenAI ? `${form.name} (Sora)` : form.name
const soraExtra: Record<string, unknown> = { const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId)
...(extra || {}),
linked_openai_account_id: String(openaiAccount.id)
}
delete soraExtra.openai_passthrough
delete soraExtra.openai_oauth_passthrough
await adminAPI.accounts.create({ await adminAPI.accounts.create({
name: `${form.name} (Sora)`, name: soraName,
notes: form.notes, notes: form.notes,
platform: 'sora', platform: 'sora',
type: 'oauth', type: 'oauth',
@@ -3113,26 +3180,22 @@ const handleOpenAIExchange = async (authCode: string) => {
expires_at: form.expires_at, expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value auto_pause_on_expired: autoPauseOnExpired.value
}) })
appStore.showSuccess(t('admin.accounts.accountCreated'))
appStore.showSuccess(t('admin.accounts.soraAccountCreated'))
} catch (error: any) {
console.error('创建 Sora 账号失败:', error)
appStore.showWarning(t('admin.accounts.soraAccountFailed'))
}
} }
emit('created') emit('created')
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(openaiOAuth.error.value) appStore.showError(oauthClient.error.value)
} finally { } finally {
openaiOAuth.loading.value = false oauthClient.loading.value = false
} }
} }
// OpenAI 手动 RT 批量验证和创建 // OpenAI 手动 RT 批量验证和创建
const handleOpenAIValidateRT = async (refreshTokenInput: string) => { const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
const oauthClient = activeOpenAIOAuth.value
if (!refreshTokenInput.trim()) return if (!refreshTokenInput.trim()) return
// Parse multiple refresh tokens (one per line) // Parse multiple refresh tokens (one per line)
@@ -3142,39 +3205,44 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
.filter((rt) => rt) .filter((rt) => rt)
if (refreshTokens.length === 0) { if (refreshTokens.length === 0) {
openaiOAuth.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken') oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken')
return return
} }
openaiOAuth.loading.value = true oauthClient.loading.value = true
openaiOAuth.error.value = '' oauthClient.error.value = ''
let successCount = 0 let successCount = 0
let failedCount = 0 let failedCount = 0
const errors: string[] = [] const errors: string[] = []
const shouldCreateOpenAI = form.platform === 'openai'
const shouldCreateSora = form.platform === 'sora'
try { try {
for (let i = 0; i < refreshTokens.length; i++) { for (let i = 0; i < refreshTokens.length; i++) {
try { try {
const tokenInfo = await openaiOAuth.validateRefreshToken( const tokenInfo = await oauthClient.validateRefreshToken(
refreshTokens[i], refreshTokens[i],
form.proxy_id form.proxy_id
) )
if (!tokenInfo) { if (!tokenInfo) {
failedCount++ failedCount++
errors.push(`#${i + 1}: ${openaiOAuth.error.value || 'Validation failed'}`) errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`)
openaiOAuth.error.value = '' oauthClient.error.value = ''
continue continue
} }
const credentials = openaiOAuth.buildCredentials(tokenInfo) const credentials = oauthClient.buildCredentials(tokenInfo)
const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
const extra = buildOpenAIExtra(oauthExtra) const extra = buildOpenAIExtra(oauthExtra)
// Generate account name with index for batch // Generate account name with index for batch
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
await adminAPI.accounts.create({ let openaiAccountId: string | number | undefined
if (shouldCreateOpenAI) {
const openaiAccount = await adminAPI.accounts.create({
name: accountName, name: accountName,
notes: form.notes, notes: form.notes,
platform: 'openai', platform: 'openai',
@@ -3189,6 +3257,34 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
expires_at: form.expires_at, expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value auto_pause_on_expired: autoPauseOnExpired.value
}) })
openaiAccountId = openaiAccount.id
}
if (shouldCreateSora) {
const soraCredentials = {
access_token: credentials.access_token,
refresh_token: credentials.refresh_token,
expires_at: credentials.expires_at
}
const soraName = shouldCreateOpenAI ? `${accountName} (Sora)` : accountName
const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId)
await adminAPI.accounts.create({
name: soraName,
notes: form.notes,
platform: 'sora',
type: 'oauth',
credentials: soraCredentials,
extra: soraExtra,
proxy_id: form.proxy_id,
concurrency: form.concurrency,
priority: form.priority,
rate_multiplier: form.rate_multiplier,
group_ids: form.group_ids,
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
})
}
successCount++ successCount++
} catch (error: any) { } catch (error: any) {
failedCount++ failedCount++
@@ -3210,14 +3306,99 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
appStore.showWarning( appStore.showWarning(
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
) )
openaiOAuth.error.value = errors.join('\n') oauthClient.error.value = errors.join('\n')
emit('created') emit('created')
} else { } else {
openaiOAuth.error.value = errors.join('\n') oauthClient.error.value = errors.join('\n')
appStore.showError(t('admin.accounts.oauth.batchFailed')) appStore.showError(t('admin.accounts.oauth.batchFailed'))
} }
} finally { } finally {
openaiOAuth.loading.value = false oauthClient.loading.value = false
}
}
// Sora 手动 ST 批量验证和创建
const handleSoraValidateST = async (sessionTokenInput: string) => {
const oauthClient = activeOpenAIOAuth.value
if (!sessionTokenInput.trim()) return
const sessionTokens = sessionTokenInput
.split('\n')
.map((st) => st.trim())
.filter((st) => st)
if (sessionTokens.length === 0) {
oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterSessionToken')
return
}
oauthClient.loading.value = true
oauthClient.error.value = ''
let successCount = 0
let failedCount = 0
const errors: string[] = []
try {
for (let i = 0; i < sessionTokens.length; i++) {
try {
const tokenInfo = await oauthClient.validateSessionToken(sessionTokens[i], form.proxy_id)
if (!tokenInfo) {
failedCount++
errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`)
oauthClient.error.value = ''
continue
}
const credentials = oauthClient.buildCredentials(tokenInfo)
credentials.session_token = sessionTokens[i]
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
const soraExtra = buildSoraExtra(oauthExtra)
const accountName = sessionTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
await adminAPI.accounts.create({
name: accountName,
notes: form.notes,
platform: 'sora',
type: 'oauth',
credentials,
extra: soraExtra,
proxy_id: form.proxy_id,
concurrency: form.concurrency,
priority: form.priority,
rate_multiplier: form.rate_multiplier,
group_ids: form.group_ids,
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
})
successCount++
} catch (error: any) {
failedCount++
const errMsg = error.response?.data?.detail || error.message || 'Unknown error'
errors.push(`#${i + 1}: ${errMsg}`)
}
}
if (successCount > 0 && failedCount === 0) {
appStore.showSuccess(
sessionTokens.length > 1
? t('admin.accounts.oauth.batchSuccess', { count: successCount })
: t('admin.accounts.accountCreated')
)
emit('created')
handleClose()
} else if (successCount > 0 && failedCount > 0) {
appStore.showWarning(
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
)
oauthClient.error.value = errors.join('\n')
emit('created')
} else {
oauthClient.error.value = errors.join('\n')
appStore.showError(t('admin.accounts.oauth.batchFailed'))
}
} finally {
oauthClient.loading.value = false
} }
} }
@@ -3462,6 +3643,7 @@ const handleExchangeCode = async () => {
switch (form.platform) { switch (form.platform) {
case 'openai': case 'openai':
case 'sora':
return handleOpenAIExchange(authCode) return handleOpenAIExchange(authCode)
case 'gemini': case 'gemini':
return handleGeminiExchange(authCode) return handleGeminiExchange(authCode)

View File

@@ -48,6 +48,17 @@
t(getOAuthKey('refreshTokenAuth')) t(getOAuthKey('refreshTokenAuth'))
}}</span> }}</span>
</label> </label>
<label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2">
<input
v-model="inputMethod"
type="radio"
value="session_token"
class="text-blue-600 focus:ring-blue-500"
/>
<span class="text-sm text-blue-900 dark:text-blue-200">{{
t(getOAuthKey('sessionTokenAuth'))
}}</span>
</label>
</div> </div>
</div> </div>
@@ -135,6 +146,87 @@
</div> </div>
</div> </div>
<!-- Session Token Input (Sora) -->
<div v-if="inputMethod === 'session_token'" class="space-y-4">
<div
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
>
<p class="mb-3 text-sm text-blue-700 dark:text-blue-300">
{{ t(getOAuthKey('sessionTokenDesc')) }}
</p>
<div class="mb-4">
<label
class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-700 dark:text-gray-300"
>
<Icon name="key" size="sm" class="text-blue-500" />
Session Token
<span
v-if="parsedSessionTokenCount > 1"
class="rounded-full bg-blue-500 px-2 py-0.5 text-xs text-white"
>
{{ t('admin.accounts.oauth.keysCount', { count: parsedSessionTokenCount }) }}
</span>
</label>
<textarea
v-model="sessionTokenInput"
rows="3"
class="input w-full resize-y font-mono text-sm"
:placeholder="t(getOAuthKey('sessionTokenPlaceholder'))"
></textarea>
<p
v-if="parsedSessionTokenCount > 1"
class="mt-1 text-xs text-blue-600 dark:text-blue-400"
>
{{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedSessionTokenCount }) }}
</p>
</div>
<div
v-if="error"
class="mb-4 rounded-lg border border-red-200 bg-red-50 p-3 dark:border-red-700 dark:bg-red-900/30"
>
<p class="whitespace-pre-line text-sm text-red-600 dark:text-red-400">
{{ error }}
</p>
</div>
<button
type="button"
class="btn btn-primary w-full"
:disabled="loading || !sessionTokenInput.trim()"
@click="handleValidateSessionToken"
>
<svg
v-if="loading"
class="-ml-1 mr-2 h-4 w-4 animate-spin"
fill="none"
viewBox="0 0 24 24"
>
<circle
class="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
stroke-width="4"
></circle>
<path
class="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
></path>
</svg>
<Icon v-else name="sparkles" size="sm" class="mr-2" />
{{
loading
? t(getOAuthKey('validating'))
: t(getOAuthKey('validateAndCreate'))
}}
</button>
</div>
</div>
<!-- Cookie Auto-Auth Form --> <!-- Cookie Auto-Auth Form -->
<div v-if="inputMethod === 'cookie'" class="space-y-4"> <div v-if="inputMethod === 'cookie'" class="space-y-4">
<div <div
@@ -525,9 +617,10 @@ interface Props {
methodLabel?: string methodLabel?: string
showCookieOption?: boolean // Whether to show cookie auto-auth option showCookieOption?: boolean // Whether to show cookie auto-auth option
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only) showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
showSessionTokenOption?: boolean // Whether to show session token input option (Sora only)
platform?: AccountPlatform // Platform type for different UI/text platform?: AccountPlatform // Platform type for different UI/text
showProjectId?: boolean // New prop to control project ID visibility showProjectId?: boolean // New prop to control project ID visibility
} }
const props = withDefaults(defineProps<Props>(), { const props = withDefaults(defineProps<Props>(), {
authUrl: '', authUrl: '',
@@ -540,6 +633,7 @@ const props = withDefaults(defineProps<Props>(), {
methodLabel: 'Authorization Method', methodLabel: 'Authorization Method',
showCookieOption: true, showCookieOption: true,
showRefreshTokenOption: false, showRefreshTokenOption: false,
showSessionTokenOption: false,
platform: 'anthropic', platform: 'anthropic',
showProjectId: true showProjectId: true
}) })
@@ -549,6 +643,7 @@ const emit = defineEmits<{
'exchange-code': [code: string] 'exchange-code': [code: string]
'cookie-auth': [sessionKey: string] 'cookie-auth': [sessionKey: string]
'validate-refresh-token': [refreshToken: string] 'validate-refresh-token': [refreshToken: string]
'validate-session-token': [sessionToken: string]
'update:inputMethod': [method: AuthInputMethod] 'update:inputMethod': [method: AuthInputMethod]
}>() }>()
@@ -587,12 +682,13 @@ const inputMethod = ref<AuthInputMethod>(props.showCookieOption ? 'manual' : 'ma
const authCodeInput = ref('') const authCodeInput = ref('')
const sessionKeyInput = ref('') const sessionKeyInput = ref('')
const refreshTokenInput = ref('') const refreshTokenInput = ref('')
const sessionTokenInput = ref('')
const showHelpDialog = ref(false) const showHelpDialog = ref(false)
const oauthState = ref('') const oauthState = ref('')
const projectId = ref('') const projectId = ref('')
// Computed: show method selection when either cookie or refresh token option is enabled // Computed: show method selection when either cookie or refresh token option is enabled
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption) const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption)
// Clipboard // Clipboard
const { copied, copyToClipboard } = useClipboard() const { copied, copyToClipboard } = useClipboard()
@@ -613,6 +709,13 @@ const parsedRefreshTokenCount = computed(() => {
.filter((rt) => rt).length .filter((rt) => rt).length
}) })
const parsedSessionTokenCount = computed(() => {
return sessionTokenInput.value
.split('\n')
.map((st) => st.trim())
.filter((st) => st).length
})
// Watchers // Watchers
watch(inputMethod, (newVal) => { watch(inputMethod, (newVal) => {
emit('update:inputMethod', newVal) emit('update:inputMethod', newVal)
@@ -631,7 +734,7 @@ watch(authCodeInput, (newVal) => {
const url = new URL(trimmed) const url = new URL(trimmed)
const code = url.searchParams.get('code') const code = url.searchParams.get('code')
const stateParam = url.searchParams.get('state') const stateParam = url.searchParams.get('state')
if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) { if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) {
oauthState.value = stateParam oauthState.value = stateParam
} }
if (code && code !== trimmed) { if (code && code !== trimmed) {
@@ -642,7 +745,7 @@ watch(authCodeInput, (newVal) => {
// If URL parsing fails, try regex extraction // If URL parsing fails, try regex extraction
const match = trimmed.match(/[?&]code=([^&]+)/) const match = trimmed.match(/[?&]code=([^&]+)/)
const stateMatch = trimmed.match(/[?&]state=([^&]+)/) const stateMatch = trimmed.match(/[?&]state=([^&]+)/)
if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) { if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) {
oauthState.value = stateMatch[1] oauthState.value = stateMatch[1]
} }
if (match && match[1] && match[1] !== trimmed) { if (match && match[1] && match[1] !== trimmed) {
@@ -680,6 +783,12 @@ const handleValidateRefreshToken = () => {
} }
} }
const handleValidateSessionToken = () => {
if (sessionTokenInput.value.trim()) {
emit('validate-session-token', sessionTokenInput.value.trim())
}
}
// Expose methods and state // Expose methods and state
defineExpose({ defineExpose({
authCode: authCodeInput, authCode: authCodeInput,
@@ -687,6 +796,7 @@ defineExpose({
projectId, projectId,
sessionKey: sessionKeyInput, sessionKey: sessionKeyInput,
refreshToken: refreshTokenInput, refreshToken: refreshTokenInput,
sessionToken: sessionTokenInput,
inputMethod, inputMethod,
reset: () => { reset: () => {
authCodeInput.value = '' authCodeInput.value = ''
@@ -694,6 +804,7 @@ defineExpose({
projectId.value = '' projectId.value = ''
sessionKeyInput.value = '' sessionKeyInput.value = ''
refreshTokenInput.value = '' refreshTokenInput.value = ''
sessionTokenInput.value = ''
inputMethod.value = 'manual' inputMethod.value = 'manual'
showHelpDialog.value = false showHelpDialog.value = false
} }

View File

@@ -14,7 +14,7 @@
<div <div
:class="[ :class="[
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br', 'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
isOpenAI isOpenAILike
? 'from-green-500 to-green-600' ? 'from-green-500 to-green-600'
: isGemini : isGemini
? 'from-blue-500 to-blue-600' ? 'from-blue-500 to-blue-600'
@@ -33,6 +33,8 @@
{{ {{
isOpenAI isOpenAI
? t('admin.accounts.openaiAccount') ? t('admin.accounts.openaiAccount')
: isSora
? t('admin.accounts.soraAccount')
: isGemini : isGemini
? t('admin.accounts.geminiAccount') ? t('admin.accounts.geminiAccount')
: isAntigravity : isAntigravity
@@ -128,7 +130,7 @@
:show-cookie-option="isAnthropic" :show-cookie-option="isAnthropic"
:allow-multiple="false" :allow-multiple="false"
:method-label="t('admin.accounts.inputMethod')" :method-label="t('admin.accounts.inputMethod')"
:platform="isOpenAI ? 'openai' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'" :platform="isOpenAI ? 'openai' : isSora ? 'sora' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
:show-project-id="isGemini && geminiOAuthType === 'code_assist'" :show-project-id="isGemini && geminiOAuthType === 'code_assist'"
@generate-url="handleGenerateUrl" @generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth" @cookie-auth="handleCookieAuth"
@@ -224,7 +226,8 @@ const { t } = useI18n()
// OAuth composables // OAuth composables
const claudeOAuth = useAccountOAuth() const claudeOAuth = useAccountOAuth()
const openaiOAuth = useOpenAIOAuth() const openaiOAuth = useOpenAIOAuth({ platform: 'openai' })
const soraOAuth = useOpenAIOAuth({ platform: 'sora' })
const geminiOAuth = useGeminiOAuth() const geminiOAuth = useGeminiOAuth()
const antigravityOAuth = useAntigravityOAuth() const antigravityOAuth = useAntigravityOAuth()
@@ -237,31 +240,34 @@ const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_as
// Computed - check platform // Computed - check platform
const isOpenAI = computed(() => props.account?.platform === 'openai') const isOpenAI = computed(() => props.account?.platform === 'openai')
const isSora = computed(() => props.account?.platform === 'sora')
const isOpenAILike = computed(() => isOpenAI.value || isSora.value)
const isGemini = computed(() => props.account?.platform === 'gemini') const isGemini = computed(() => props.account?.platform === 'gemini')
const isAnthropic = computed(() => props.account?.platform === 'anthropic') const isAnthropic = computed(() => props.account?.platform === 'anthropic')
const isAntigravity = computed(() => props.account?.platform === 'antigravity') const isAntigravity = computed(() => props.account?.platform === 'antigravity')
const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth))
// Computed - current OAuth state based on platform // Computed - current OAuth state based on platform
const currentAuthUrl = computed(() => { const currentAuthUrl = computed(() => {
if (isOpenAI.value) return openaiOAuth.authUrl.value if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value
if (isGemini.value) return geminiOAuth.authUrl.value if (isGemini.value) return geminiOAuth.authUrl.value
if (isAntigravity.value) return antigravityOAuth.authUrl.value if (isAntigravity.value) return antigravityOAuth.authUrl.value
return claudeOAuth.authUrl.value return claudeOAuth.authUrl.value
}) })
const currentSessionId = computed(() => { const currentSessionId = computed(() => {
if (isOpenAI.value) return openaiOAuth.sessionId.value if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value
if (isGemini.value) return geminiOAuth.sessionId.value if (isGemini.value) return geminiOAuth.sessionId.value
if (isAntigravity.value) return antigravityOAuth.sessionId.value if (isAntigravity.value) return antigravityOAuth.sessionId.value
return claudeOAuth.sessionId.value return claudeOAuth.sessionId.value
}) })
const currentLoading = computed(() => { const currentLoading = computed(() => {
if (isOpenAI.value) return openaiOAuth.loading.value if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value
if (isGemini.value) return geminiOAuth.loading.value if (isGemini.value) return geminiOAuth.loading.value
if (isAntigravity.value) return antigravityOAuth.loading.value if (isAntigravity.value) return antigravityOAuth.loading.value
return claudeOAuth.loading.value return claudeOAuth.loading.value
}) })
const currentError = computed(() => { const currentError = computed(() => {
if (isOpenAI.value) return openaiOAuth.error.value if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value
if (isGemini.value) return geminiOAuth.error.value if (isGemini.value) return geminiOAuth.error.value
if (isAntigravity.value) return antigravityOAuth.error.value if (isAntigravity.value) return antigravityOAuth.error.value
return claudeOAuth.error.value return claudeOAuth.error.value
@@ -269,8 +275,8 @@ const currentError = computed(() => {
// Computed // Computed
const isManualInputMethod = computed(() => { const isManualInputMethod = computed(() => {
// OpenAI/Gemini/Antigravity always use manual input (no cookie auth option) // OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option)
return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
}) })
const canExchangeCode = computed(() => { const canExchangeCode = computed(() => {
@@ -313,6 +319,7 @@ const resetState = () => {
geminiOAuthType.value = 'code_assist' geminiOAuthType.value = 'code_assist'
claudeOAuth.resetState() claudeOAuth.resetState()
openaiOAuth.resetState() openaiOAuth.resetState()
soraOAuth.resetState()
geminiOAuth.resetState() geminiOAuth.resetState()
antigravityOAuth.resetState() antigravityOAuth.resetState()
oauthFlowRef.value?.reset() oauthFlowRef.value?.reset()
@@ -325,8 +332,8 @@ const handleClose = () => {
const handleGenerateUrl = async () => { const handleGenerateUrl = async () => {
if (!props.account) return if (!props.account) return
if (isOpenAI.value) { if (isOpenAILike.value) {
await openaiOAuth.generateAuthUrl(props.account.proxy_id) await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id)
} else if (isGemini.value) { } else if (isGemini.value) {
const creds = (props.account.credentials || {}) as Record<string, unknown> const creds = (props.account.credentials || {}) as Record<string, unknown>
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
@@ -345,21 +352,29 @@ const handleExchangeCode = async () => {
const authCode = oauthFlowRef.value?.authCode || '' const authCode = oauthFlowRef.value?.authCode || ''
if (!authCode.trim()) return if (!authCode.trim()) return
if (isOpenAI.value) { if (isOpenAILike.value) {
// OpenAI OAuth flow // OpenAI OAuth flow
const sessionId = openaiOAuth.sessionId.value const oauthClient = activeOpenAIOAuth.value
const sessionId = oauthClient.sessionId.value
if (!sessionId) return if (!sessionId) return
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
if (!stateToUse) {
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
appStore.showError(oauthClient.error.value)
return
}
const tokenInfo = await openaiOAuth.exchangeAuthCode( const tokenInfo = await oauthClient.exchangeAuthCode(
authCode.trim(), authCode.trim(),
sessionId, sessionId,
stateToUse,
props.account.proxy_id props.account.proxy_id
) )
if (!tokenInfo) return if (!tokenInfo) return
// Build credentials and extra info // Build credentials and extra info
const credentials = openaiOAuth.buildCredentials(tokenInfo) const credentials = oauthClient.buildCredentials(tokenInfo)
const extra = openaiOAuth.buildExtraInfo(tokenInfo) const extra = oauthClient.buildExtraInfo(tokenInfo)
try { try {
// Update account with new credentials // Update account with new credentials
@@ -376,8 +391,8 @@ const handleExchangeCode = async () => {
emit('reauthorized') emit('reauthorized')
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(openaiOAuth.error.value) appStore.showError(oauthClient.error.value)
} }
} else if (isGemini.value) { } else if (isGemini.value) {
const sessionId = geminiOAuth.sessionId.value const sessionId = geminiOAuth.sessionId.value
@@ -490,7 +505,7 @@ const handleExchangeCode = async () => {
} }
const handleCookieAuth = async (sessionKey: string) => { const handleCookieAuth = async (sessionKey: string) => {
if (!props.account || isOpenAI.value) return if (!props.account || isOpenAILike.value) return
claudeOAuth.loading.value = true claudeOAuth.loading.value = true
claudeOAuth.error.value = '' claudeOAuth.error.value = ''

View File

@@ -238,6 +238,11 @@ const loadAvailableModels = async () => {
availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') || availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') ||
availableModels.value.find((m) => m.id === 'gemini-3-pro-preview') availableModels.value.find((m) => m.id === 'gemini-3-pro-preview')
selectedModelId.value = preferred?.id || availableModels.value[0].id selectedModelId.value = preferred?.id || availableModels.value[0].id
} else if (props.account.platform === 'sora') {
const preferred =
availableModels.value.find((m) => m.id === 'gpt-image') ||
availableModels.value.find((m) => !m.id.startsWith('prompt-enhance'))
selectedModelId.value = preferred?.id || availableModels.value[0].id
} else { } else {
// Try to select Sonnet as default, otherwise use first model // Try to select Sonnet as default, otherwise use first model
const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet')) const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet'))

View File

@@ -14,7 +14,7 @@
<div <div
:class="[ :class="[
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br', 'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
isOpenAI isOpenAILike
? 'from-green-500 to-green-600' ? 'from-green-500 to-green-600'
: isGemini : isGemini
? 'from-blue-500 to-blue-600' ? 'from-blue-500 to-blue-600'
@@ -33,6 +33,8 @@
{{ {{
isOpenAI isOpenAI
? t('admin.accounts.openaiAccount') ? t('admin.accounts.openaiAccount')
: isSora
? t('admin.accounts.soraAccount')
: isGemini : isGemini
? t('admin.accounts.geminiAccount') ? t('admin.accounts.geminiAccount')
: isAntigravity : isAntigravity
@@ -128,7 +130,7 @@
:show-cookie-option="isAnthropic" :show-cookie-option="isAnthropic"
:allow-multiple="false" :allow-multiple="false"
:method-label="t('admin.accounts.inputMethod')" :method-label="t('admin.accounts.inputMethod')"
:platform="isOpenAI ? 'openai' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'" :platform="isOpenAI ? 'openai' : isSora ? 'sora' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
:show-project-id="isGemini && geminiOAuthType === 'code_assist'" :show-project-id="isGemini && geminiOAuthType === 'code_assist'"
@generate-url="handleGenerateUrl" @generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth" @cookie-auth="handleCookieAuth"
@@ -224,7 +226,8 @@ const { t } = useI18n()
// OAuth composables // OAuth composables
const claudeOAuth = useAccountOAuth() const claudeOAuth = useAccountOAuth()
const openaiOAuth = useOpenAIOAuth() const openaiOAuth = useOpenAIOAuth({ platform: 'openai' })
const soraOAuth = useOpenAIOAuth({ platform: 'sora' })
const geminiOAuth = useGeminiOAuth() const geminiOAuth = useGeminiOAuth()
const antigravityOAuth = useAntigravityOAuth() const antigravityOAuth = useAntigravityOAuth()
@@ -237,31 +240,34 @@ const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_as
// Computed - check platform // Computed - check platform
const isOpenAI = computed(() => props.account?.platform === 'openai') const isOpenAI = computed(() => props.account?.platform === 'openai')
const isSora = computed(() => props.account?.platform === 'sora')
const isOpenAILike = computed(() => isOpenAI.value || isSora.value)
const isGemini = computed(() => props.account?.platform === 'gemini') const isGemini = computed(() => props.account?.platform === 'gemini')
const isAnthropic = computed(() => props.account?.platform === 'anthropic') const isAnthropic = computed(() => props.account?.platform === 'anthropic')
const isAntigravity = computed(() => props.account?.platform === 'antigravity') const isAntigravity = computed(() => props.account?.platform === 'antigravity')
const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth))
// Computed - current OAuth state based on platform // Computed - current OAuth state based on platform
const currentAuthUrl = computed(() => { const currentAuthUrl = computed(() => {
if (isOpenAI.value) return openaiOAuth.authUrl.value if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value
if (isGemini.value) return geminiOAuth.authUrl.value if (isGemini.value) return geminiOAuth.authUrl.value
if (isAntigravity.value) return antigravityOAuth.authUrl.value if (isAntigravity.value) return antigravityOAuth.authUrl.value
return claudeOAuth.authUrl.value return claudeOAuth.authUrl.value
}) })
const currentSessionId = computed(() => { const currentSessionId = computed(() => {
if (isOpenAI.value) return openaiOAuth.sessionId.value if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value
if (isGemini.value) return geminiOAuth.sessionId.value if (isGemini.value) return geminiOAuth.sessionId.value
if (isAntigravity.value) return antigravityOAuth.sessionId.value if (isAntigravity.value) return antigravityOAuth.sessionId.value
return claudeOAuth.sessionId.value return claudeOAuth.sessionId.value
}) })
const currentLoading = computed(() => { const currentLoading = computed(() => {
if (isOpenAI.value) return openaiOAuth.loading.value if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value
if (isGemini.value) return geminiOAuth.loading.value if (isGemini.value) return geminiOAuth.loading.value
if (isAntigravity.value) return antigravityOAuth.loading.value if (isAntigravity.value) return antigravityOAuth.loading.value
return claudeOAuth.loading.value return claudeOAuth.loading.value
}) })
const currentError = computed(() => { const currentError = computed(() => {
if (isOpenAI.value) return openaiOAuth.error.value if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value
if (isGemini.value) return geminiOAuth.error.value if (isGemini.value) return geminiOAuth.error.value
if (isAntigravity.value) return antigravityOAuth.error.value if (isAntigravity.value) return antigravityOAuth.error.value
return claudeOAuth.error.value return claudeOAuth.error.value
@@ -269,8 +275,8 @@ const currentError = computed(() => {
// Computed // Computed
const isManualInputMethod = computed(() => { const isManualInputMethod = computed(() => {
// OpenAI/Gemini/Antigravity always use manual input (no cookie auth option) // OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option)
return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
}) })
const canExchangeCode = computed(() => { const canExchangeCode = computed(() => {
@@ -313,6 +319,7 @@ const resetState = () => {
geminiOAuthType.value = 'code_assist' geminiOAuthType.value = 'code_assist'
claudeOAuth.resetState() claudeOAuth.resetState()
openaiOAuth.resetState() openaiOAuth.resetState()
soraOAuth.resetState()
geminiOAuth.resetState() geminiOAuth.resetState()
antigravityOAuth.resetState() antigravityOAuth.resetState()
oauthFlowRef.value?.reset() oauthFlowRef.value?.reset()
@@ -325,8 +332,8 @@ const handleClose = () => {
const handleGenerateUrl = async () => { const handleGenerateUrl = async () => {
if (!props.account) return if (!props.account) return
if (isOpenAI.value) { if (isOpenAILike.value) {
await openaiOAuth.generateAuthUrl(props.account.proxy_id) await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id)
} else if (isGemini.value) { } else if (isGemini.value) {
const creds = (props.account.credentials || {}) as Record<string, unknown> const creds = (props.account.credentials || {}) as Record<string, unknown>
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
@@ -345,21 +352,29 @@ const handleExchangeCode = async () => {
const authCode = oauthFlowRef.value?.authCode || '' const authCode = oauthFlowRef.value?.authCode || ''
if (!authCode.trim()) return if (!authCode.trim()) return
if (isOpenAI.value) { if (isOpenAILike.value) {
// OpenAI OAuth flow // OpenAI OAuth flow
const sessionId = openaiOAuth.sessionId.value const oauthClient = activeOpenAIOAuth.value
const sessionId = oauthClient.sessionId.value
if (!sessionId) return if (!sessionId) return
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
if (!stateToUse) {
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
appStore.showError(oauthClient.error.value)
return
}
const tokenInfo = await openaiOAuth.exchangeAuthCode( const tokenInfo = await oauthClient.exchangeAuthCode(
authCode.trim(), authCode.trim(),
sessionId, sessionId,
stateToUse,
props.account.proxy_id props.account.proxy_id
) )
if (!tokenInfo) return if (!tokenInfo) return
// Build credentials and extra info // Build credentials and extra info
const credentials = openaiOAuth.buildCredentials(tokenInfo) const credentials = oauthClient.buildCredentials(tokenInfo)
const extra = openaiOAuth.buildExtraInfo(tokenInfo) const extra = oauthClient.buildExtraInfo(tokenInfo)
try { try {
// Update account with new credentials // Update account with new credentials
@@ -376,8 +391,8 @@ const handleExchangeCode = async () => {
emit('reauthorized', updatedAccount) emit('reauthorized', updatedAccount)
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
appStore.showError(openaiOAuth.error.value) appStore.showError(oauthClient.error.value)
} }
} else if (isGemini.value) { } else if (isGemini.value) {
const sessionId = geminiOAuth.sessionId.value const sessionId = geminiOAuth.sessionId.value
@@ -490,7 +505,7 @@ const handleExchangeCode = async () => {
} }
const handleCookieAuth = async (sessionKey: string) => { const handleCookieAuth = async (sessionKey: string) => {
if (!props.account || isOpenAI.value) return if (!props.account || isOpenAILike.value) return
claudeOAuth.loading.value = true claudeOAuth.loading.value = true
claudeOAuth.error.value = '' claudeOAuth.error.value = ''

View File

@@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
export type AddMethod = 'oauth' | 'setup-token' export type AddMethod = 'oauth' | 'setup-token'
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token'
export interface OAuthState { export interface OAuthState {
authUrl: string authUrl: string

View File

@@ -19,12 +19,21 @@ export interface OpenAITokenInfo {
[key: string]: unknown [key: string]: unknown
} }
export function useOpenAIOAuth() { export type OpenAIOAuthPlatform = 'openai' | 'sora'
interface UseOpenAIOAuthOptions {
platform?: OpenAIOAuthPlatform
}
export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
const appStore = useAppStore() const appStore = useAppStore()
const oauthPlatform = options?.platform ?? 'openai'
const endpointPrefix = oauthPlatform === 'sora' ? '/admin/sora' : '/admin/openai'
// State // State
const authUrl = ref('') const authUrl = ref('')
const sessionId = ref('') const sessionId = ref('')
const oauthState = ref('')
const loading = ref(false) const loading = ref(false)
const error = ref('') const error = ref('')
@@ -32,6 +41,7 @@ export function useOpenAIOAuth() {
const resetState = () => { const resetState = () => {
authUrl.value = '' authUrl.value = ''
sessionId.value = '' sessionId.value = ''
oauthState.value = ''
loading.value = false loading.value = false
error.value = '' error.value = ''
} }
@@ -44,6 +54,7 @@ export function useOpenAIOAuth() {
loading.value = true loading.value = true
authUrl.value = '' authUrl.value = ''
sessionId.value = '' sessionId.value = ''
oauthState.value = ''
error.value = '' error.value = ''
try { try {
@@ -56,11 +67,17 @@ export function useOpenAIOAuth() {
} }
const response = await adminAPI.accounts.generateAuthUrl( const response = await adminAPI.accounts.generateAuthUrl(
'/admin/openai/generate-auth-url', `${endpointPrefix}/generate-auth-url`,
payload payload
) )
authUrl.value = response.auth_url authUrl.value = response.auth_url
sessionId.value = response.session_id sessionId.value = response.session_id
try {
const parsed = new URL(response.auth_url)
oauthState.value = parsed.searchParams.get('state') || ''
} catch {
oauthState.value = ''
}
return true return true
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || 'Failed to generate OpenAI auth URL' error.value = err.response?.data?.detail || 'Failed to generate OpenAI auth URL'
@@ -75,10 +92,11 @@ export function useOpenAIOAuth() {
const exchangeAuthCode = async ( const exchangeAuthCode = async (
code: string, code: string,
currentSessionId: string, currentSessionId: string,
state: string,
proxyId?: number | null proxyId?: number | null
): Promise<OpenAITokenInfo | null> => { ): Promise<OpenAITokenInfo | null> => {
if (!code.trim() || !currentSessionId) { if (!code.trim() || !currentSessionId || !state.trim()) {
error.value = 'Missing auth code or session ID' error.value = 'Missing auth code, session ID, or state'
return null return null
} }
@@ -86,15 +104,16 @@ export function useOpenAIOAuth() {
error.value = '' error.value = ''
try { try {
const payload: { session_id: string; code: string; proxy_id?: number } = { const payload: { session_id: string; code: string; state: string; proxy_id?: number } = {
session_id: currentSessionId, session_id: currentSessionId,
code: code.trim() code: code.trim(),
state: state.trim()
} }
if (proxyId) { if (proxyId) {
payload.proxy_id = proxyId payload.proxy_id = proxyId
} }
const tokenInfo = await adminAPI.accounts.exchangeCode('/admin/openai/exchange-code', payload) const tokenInfo = await adminAPI.accounts.exchangeCode(`${endpointPrefix}/exchange-code`, payload)
return tokenInfo as OpenAITokenInfo return tokenInfo as OpenAITokenInfo
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || 'Failed to exchange OpenAI auth code' error.value = err.response?.data?.detail || 'Failed to exchange OpenAI auth code'
@@ -120,7 +139,11 @@ export function useOpenAIOAuth() {
try { try {
// Use dedicated refresh-token endpoint // Use dedicated refresh-token endpoint
const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(refreshToken.trim(), proxyId) const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(
refreshToken.trim(),
proxyId,
`${endpointPrefix}/refresh-token`
)
return tokenInfo as OpenAITokenInfo return tokenInfo as OpenAITokenInfo
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || 'Failed to validate refresh token' error.value = err.response?.data?.detail || 'Failed to validate refresh token'
@@ -131,6 +154,33 @@ export function useOpenAIOAuth() {
} }
} }
// Validate Sora session token and get access token
const validateSessionToken = async (
sessionToken: string,
proxyId?: number | null
): Promise<OpenAITokenInfo | null> => {
if (!sessionToken.trim()) {
error.value = 'Missing session token'
return null
}
loading.value = true
error.value = ''
try {
const tokenInfo = await adminAPI.accounts.validateSoraSessionToken(
sessionToken.trim(),
proxyId,
`${endpointPrefix}/st2at`
)
return tokenInfo as OpenAITokenInfo
} catch (err: any) {
error.value = err.response?.data?.detail || 'Failed to validate session token'
appStore.showError(error.value)
return null
} finally {
loading.value = false
}
}
// Build credentials for OpenAI OAuth account // Build credentials for OpenAI OAuth account
const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => { const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => {
const creds: Record<string, unknown> = { const creds: Record<string, unknown> = {
@@ -172,6 +222,7 @@ export function useOpenAIOAuth() {
// State // State
authUrl, authUrl,
sessionId, sessionId,
oauthState,
loading, loading,
error, error,
// Methods // Methods
@@ -179,6 +230,7 @@ export function useOpenAIOAuth() {
generateAuthUrl, generateAuthUrl,
exchangeAuthCode, exchangeAuthCode,
validateRefreshToken, validateRefreshToken,
validateSessionToken,
buildCredentials, buildCredentials,
buildExtraInfo buildExtraInfo
} }

View File

@@ -1740,9 +1740,13 @@ export default {
refreshTokenAuth: 'Manual RT Input', refreshTokenAuth: 'Manual RT Input',
refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.',
refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line', refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line',
sessionTokenAuth: 'Manual ST Input',
sessionTokenDesc: 'Enter your existing Sora Session Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.',
sessionTokenPlaceholder: 'Paste your Sora Session Token...\nSupports multiple, one per line',
validating: 'Validating...', validating: 'Validating...',
validateAndCreate: 'Validate & Create Account', validateAndCreate: 'Validate & Create Account',
pleaseEnterRefreshToken: 'Please enter Refresh Token' pleaseEnterRefreshToken: 'Please enter Refresh Token',
pleaseEnterSessionToken: 'Please enter Session Token'
}, },
// Gemini specific // Gemini specific
gemini: { gemini: {
@@ -1963,6 +1967,7 @@ export default {
reAuthorizeAccount: 'Re-Authorize Account', reAuthorizeAccount: 'Re-Authorize Account',
claudeCodeAccount: 'Claude Code Account', claudeCodeAccount: 'Claude Code Account',
openaiAccount: 'OpenAI Account', openaiAccount: 'OpenAI Account',
soraAccount: 'Sora Account',
geminiAccount: 'Gemini Account', geminiAccount: 'Gemini Account',
antigravityAccount: 'Antigravity Account', antigravityAccount: 'Antigravity Account',
inputMethod: 'Input Method', inputMethod: 'Input Method',

View File

@@ -1879,9 +1879,13 @@ export default {
refreshTokenAuth: '手动输入 RT', refreshTokenAuth: '手动输入 RT',
refreshTokenDesc: '输入您已有的 OpenAI Refresh Token支持批量输入每行一个系统将自动验证并创建账号。', refreshTokenDesc: '输入您已有的 OpenAI Refresh Token支持批量输入每行一个系统将自动验证并创建账号。',
refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个每行一个', refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个每行一个',
sessionTokenAuth: '手动输入 ST',
sessionTokenDesc: '输入您已有的 Sora Session Token支持批量输入每行一个系统将自动验证并创建账号。',
sessionTokenPlaceholder: '粘贴您的 Sora Session Token...\n支持多个每行一个',
validating: '验证中...', validating: '验证中...',
validateAndCreate: '验证并创建账号', validateAndCreate: '验证并创建账号',
pleaseEnterRefreshToken: '请输入 Refresh Token' pleaseEnterRefreshToken: '请输入 Refresh Token',
pleaseEnterSessionToken: '请输入 Session Token'
}, },
// Gemini specific // Gemini specific
gemini: { gemini: {
@@ -2097,6 +2101,7 @@ export default {
reAuthorizeAccount: '重新授权账号', reAuthorizeAccount: '重新授权账号',
claudeCodeAccount: 'Claude Code 账号', claudeCodeAccount: 'Claude Code 账号',
openaiAccount: 'OpenAI 账号', openaiAccount: 'OpenAI 账号',
soraAccount: 'Sora 账号',
geminiAccount: 'Gemini 账号', geminiAccount: 'Gemini 账号',
antigravityAccount: 'Antigravity 账号', antigravityAccount: 'Antigravity 账号',
inputMethod: '输入方式', inputMethod: '输入方式',