feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径
- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力 - 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程 - 强化 Sora token 恢复、转发日志与网关路由隔离行为 - 补充后端服务层与路由层相关测试覆盖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -162,6 +162,8 @@ type TokenRefreshConfig struct {
|
|||||||
MaxRetries int `mapstructure:"max_retries"`
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
// 重试退避基础时间(秒)
|
// 重试退避基础时间(秒)
|
||||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||||
|
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
|
||||||
|
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PricingConfig struct {
|
type PricingConfig struct {
|
||||||
@@ -277,6 +279,7 @@ type SoraClientConfig struct {
|
|||||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||||
Debug bool `mapstructure:"debug"`
|
Debug bool `mapstructure:"debug"`
|
||||||
|
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
|
||||||
Headers map[string]string `mapstructure:"headers"`
|
Headers map[string]string `mapstructure:"headers"`
|
||||||
UserAgent string `mapstructure:"user_agent"`
|
UserAgent string `mapstructure:"user_agent"`
|
||||||
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
||||||
@@ -1116,6 +1119,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("sora.client.recent_task_limit", 50)
|
viper.SetDefault("sora.client.recent_task_limit", 50)
|
||||||
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
||||||
viper.SetDefault("sora.client.debug", false)
|
viper.SetDefault("sora.client.debug", false)
|
||||||
|
viper.SetDefault("sora.client.use_openai_token_provider", false)
|
||||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
||||||
@@ -1137,6 +1141,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
||||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||||
|
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
|
||||||
|
|
||||||
// Gemini OAuth - configure via environment variables or config file
|
// Gemini OAuth - configure via environment variables or config file
|
||||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||||
|
|||||||
@@ -1333,6 +1333,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Sora accounts
|
||||||
|
if account.Platform == service.PlatformSora {
|
||||||
|
response.Success(c, service.DefaultSoraModels(nil))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Claude/Anthropic accounts
|
// Handle Claude/Anthropic accounts
|
||||||
// For OAuth and Setup-Token accounts: return default models
|
// For OAuth and Setup-Token accounts: return default models
|
||||||
if account.IsOAuth() {
|
if account.IsOAuth() {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
|
|||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func oauthPlatformFromPath(c *gin.Context) string {
|
||||||
|
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
||||||
|
return service.PlatformSora
|
||||||
|
}
|
||||||
|
return service.PlatformOpenAI
|
||||||
|
}
|
||||||
|
|
||||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||||
return &OpenAIOAuthHandler{
|
return &OpenAIOAuthHandler{
|
||||||
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
|||||||
type OpenAIExchangeCodeRequest struct {
|
type OpenAIExchangeCodeRequest struct {
|
||||||
SessionID string `json:"session_id" binding:"required"`
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
Code string `json:"code" binding:"required"`
|
Code string `json:"code" binding:"required"`
|
||||||
|
State string `json:"state" binding:"required"`
|
||||||
RedirectURI string `json:"redirect_uri"`
|
RedirectURI string `json:"redirect_uri"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
}
|
}
|
||||||
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
|||||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||||
SessionID: req.SessionID,
|
SessionID: req.SessionID,
|
||||||
Code: req.Code,
|
Code: req.Code,
|
||||||
|
State: req.State,
|
||||||
RedirectURI: req.RedirectURI,
|
RedirectURI: req.RedirectURI,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
})
|
})
|
||||||
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
|||||||
|
|
||||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||||
type OpenAIRefreshTokenRequest struct {
|
type OpenAIRefreshTokenRequest struct {
|
||||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
RT string `json:"rt"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken refreshes an OpenAI OAuth token
|
// RefreshToken refreshes an OpenAI OAuth token
|
||||||
// POST /api/v1/admin/openai/refresh-token
|
// POST /api/v1/admin/openai/refresh-token
|
||||||
|
// POST /api/v1/admin/sora/rt2at
|
||||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||||
var req OpenAIRefreshTokenRequest
|
var req OpenAIRefreshTokenRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||||
|
if refreshToken == "" {
|
||||||
|
refreshToken = strings.TrimSpace(req.RT)
|
||||||
|
}
|
||||||
|
if refreshToken == "" {
|
||||||
|
response.BadRequest(c, "refresh_token is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
if req.ProxyID != nil {
|
if req.ProxyID != nil {
|
||||||
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
response.Success(c, tokenInfo)
|
response.Success(c, tokenInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
||||||
|
// POST /api/v1/admin/sora/st2at
|
||||||
|
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
SessionToken string `json:"session_token"`
|
||||||
|
ST string `json:"st"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionToken := strings.TrimSpace(req.SessionToken)
|
||||||
|
if sessionToken == "" {
|
||||||
|
sessionToken = strings.TrimSpace(req.ST)
|
||||||
|
}
|
||||||
|
if sessionToken == "" {
|
||||||
|
response.BadRequest(c, "session_token is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, tokenInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
||||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||||
|
// POST /api/v1/admin/sora/accounts/:id/refresh
|
||||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure account is OpenAI platform
|
platform := oauthPlatformFromPath(c)
|
||||||
if !account.IsOpenAI() {
|
if account.Platform != platform {
|
||||||
response.BadRequest(c, "Account is not an OpenAI account")
|
response.BadRequest(c, "Account platform does not match OAuth endpoint")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
|||||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
||||||
// POST /api/v1/admin/openai/create-from-oauth
|
// POST /api/v1/admin/openai/create-from-oauth
|
||||||
|
// POST /api/v1/admin/sora/create-from-oauth
|
||||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||||
var req struct {
|
var req struct {
|
||||||
SessionID string `json:"session_id" binding:"required"`
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
Code string `json:"code" binding:"required"`
|
Code string `json:"code" binding:"required"`
|
||||||
|
State string `json:"state" binding:"required"`
|
||||||
RedirectURI string `json:"redirect_uri"`
|
RedirectURI string `json:"redirect_uri"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||||
SessionID: req.SessionID,
|
SessionID: req.SessionID,
|
||||||
Code: req.Code,
|
Code: req.Code,
|
||||||
|
State: req.State,
|
||||||
RedirectURI: req.RedirectURI,
|
RedirectURI: req.RedirectURI,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
})
|
})
|
||||||
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
// Build credentials from token info
|
// Build credentials from token info
|
||||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
|
platform := oauthPlatformFromPath(c)
|
||||||
|
|
||||||
// Use email as default name if not provided
|
// Use email as default name if not provided
|
||||||
name := req.Name
|
name := req.Name
|
||||||
if name == "" && tokenInfo.Email != "" {
|
if name == "" && tokenInfo.Email != "" {
|
||||||
name = tokenInfo.Email
|
name = tokenInfo.Email
|
||||||
}
|
}
|
||||||
if name == "" {
|
if name == "" {
|
||||||
|
if platform == service.PlatformSora {
|
||||||
|
name = "Sora OAuth Account"
|
||||||
|
} else {
|
||||||
name = "OpenAI OAuth Account"
|
name = "OpenAI OAuth Account"
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create account
|
// Create account
|
||||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||||
Name: name,
|
Name: name,
|
||||||
Platform: "openai",
|
Platform: platform,
|
||||||
Type: "oauth",
|
Type: "oauth",
|
||||||
Credentials: credentials,
|
Credentials: credentials,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
|
|||||||
@@ -212,6 +212,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
|
var lastFailoverBody []byte
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||||
@@ -224,7 +225,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
@@ -287,14 +288,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
lastFailoverBody = failoverErr.ResponseBody
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
lastFailoverBody = failoverErr.ResponseBody
|
||||||
switchCount++
|
switchCount++
|
||||||
|
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
||||||
reqLog.Warn("sora.upstream_failover_switching",
|
reqLog.Warn("sora.upstream_failover_switching",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.String("upstream_error_code", upstreamErrCode),
|
||||||
|
zap.String("upstream_error_message", upstreamErrMsg),
|
||||||
zap.Int("switch_count", switchCount),
|
zap.Int("switch_count", switchCount),
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
zap.Int("max_switches", maxAccountSwitches),
|
||||||
)
|
)
|
||||||
@@ -360,17 +366,32 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
|||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseBody []byte, streamStarted bool) {
|
||||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
status, errType, errMsg := h.mapUpstreamError(statusCode, responseBody)
|
||||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byte) (int, string, string) {
|
||||||
|
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||||
|
if upstreamMessage != "" {
|
||||||
|
switch statusCode {
|
||||||
|
case 401, 403, 404, 500, 502, 503, 504:
|
||||||
|
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
||||||
|
case 429:
|
||||||
|
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401:
|
case 401:
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||||
case 403:
|
case 403:
|
||||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||||
|
case 404:
|
||||||
|
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
|
||||||
|
}
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
|
||||||
case 429:
|
case 429:
|
||||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||||
case 529:
|
case 529:
|
||||||
@@ -382,6 +403,41 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||||
|
trimmed := strings.TrimSpace(string(body))
|
||||||
|
if trimmed == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
if !gjson.Valid(trimmed) {
|
||||||
|
return "", truncateSoraErrorMessage(trimmed, 256)
|
||||||
|
}
|
||||||
|
code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
|
||||||
|
if code == "" {
|
||||||
|
code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
|
||||||
|
}
|
||||||
|
message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
|
||||||
|
if message == "" {
|
||||||
|
message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
|
||||||
|
}
|
||||||
|
if message == "" {
|
||||||
|
message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
|
||||||
|
}
|
||||||
|
if message == "" {
|
||||||
|
message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
|
||||||
|
}
|
||||||
|
return code, truncateSoraErrorMessage(message, 512)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateSoraErrorMessage(s string, maxLen int) string {
|
||||||
|
if maxLen <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(s) <= maxLen {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxLen] + "...(truncated)"
|
||||||
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||||
if streamStarted {
|
if streamStarted {
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
|||||||
@@ -43,6 +43,9 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
|
|||||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||||
return "task-video", nil
|
return "task-video", nil
|
||||||
}
|
}
|
||||||
|
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||||
|
return "enhanced prompt", nil
|
||||||
|
}
|
||||||
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
||||||
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import (
|
|||||||
const (
|
const (
|
||||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
|
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
|
||||||
|
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
|
||||||
|
|
||||||
// OAuth endpoints
|
// OAuth endpoints
|
||||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||||
|
if strings.TrimSpace(clientID) != "" {
|
||||||
|
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
|
||||||
|
}
|
||||||
|
|
||||||
|
clientIDs := []string{
|
||||||
|
openai.ClientID,
|
||||||
|
openai.SoraClientID,
|
||||||
|
}
|
||||||
|
seen := make(map[string]struct{}, len(clientIDs))
|
||||||
|
var lastErr error
|
||||||
|
for _, clientID := range clientIDs {
|
||||||
|
clientID = strings.TrimSpace(clientID)
|
||||||
|
if clientID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[clientID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[clientID] = struct{}{}
|
||||||
|
|
||||||
|
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||||
|
if err == nil {
|
||||||
|
return tokenResp, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||||
client := createOpenAIReqClient(proxyURL)
|
client := createOpenAIReqClient(proxyURL)
|
||||||
|
|
||||||
formData := url.Values{}
|
formData := url.Values{}
|
||||||
formData.Set("grant_type", "refresh_token")
|
formData.Set("grant_type", "refresh_token")
|
||||||
formData.Set("refresh_token", refreshToken)
|
formData.Set("refresh_token", refreshToken)
|
||||||
formData.Set("client_id", openai.ClientID)
|
formData.Set("client_id", clientID)
|
||||||
formData.Set("scope", openai.RefreshScopes)
|
formData.Set("scope", openai.RefreshScopes)
|
||||||
|
|
||||||
var tokenResp openai.TokenResponse
|
var tokenResp openai.TokenResponse
|
||||||
|
|||||||
@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
|||||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
|
||||||
|
var seenClientIDs []string
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientID := r.PostForm.Get("client_id")
|
||||||
|
seenClientIDs = append(seenClientIDs, clientID)
|
||||||
|
if clientID == openai.ClientID {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = io.WriteString(w, "invalid_grant")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if clientID == openai.SoraClientID {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||||
|
require.NoError(s.T(), err, "RefreshToken")
|
||||||
|
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||||
|
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
|
||||||
|
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||||
|
const customClientID = "custom-client-id"
|
||||||
|
var seenClientIDs []string
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientID := r.PostForm.Get("client_id")
|
||||||
|
seenClientIDs = append(seenClientIDs, clientID)
|
||||||
|
if clientID != customClientID {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
|
||||||
|
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||||
|
require.Equal(s.T(), "at-custom", resp.AccessToken)
|
||||||
|
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
|
||||||
|
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
|
|||||||
|
|
||||||
// OpenAI OAuth
|
// OpenAI OAuth
|
||||||
registerOpenAIOAuthRoutes(admin, h)
|
registerOpenAIOAuthRoutes(admin, h)
|
||||||
|
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
|
||||||
|
registerSoraOAuthRoutes(admin, h)
|
||||||
|
|
||||||
// Gemini OAuth
|
// Gemini OAuth
|
||||||
registerGeminiOAuthRoutes(admin, h)
|
registerGeminiOAuthRoutes(admin, h)
|
||||||
@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
|
sora := admin.Group("/sora")
|
||||||
|
{
|
||||||
|
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||||
|
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||||
|
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||||
|
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
|
||||||
|
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
|
||||||
|
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||||
|
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
gemini := admin.Group("/gemini")
|
gemini := admin.Group("/gemini")
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package routes
|
package routes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
|
|||||||
gateway.GET("/usage", h.Gateway.Usage)
|
gateway.GET("/usage", h.Gateway.Usage)
|
||||||
// OpenAI Responses API
|
// OpenAI Responses API
|
||||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||||
}
|
// 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口
|
||||||
|
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||||
// Sora Chat Completions
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
soraGateway := r.Group("/v1")
|
"error": gin.H{
|
||||||
soraGateway.Use(soraBodyLimit)
|
"type": "invalid_request_error",
|
||||||
soraGateway.Use(clientRequestID)
|
"message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.",
|
||||||
soraGateway.Use(opsErrorLogger)
|
},
|
||||||
soraGateway.Use(gin.HandlerFunc(apiKeyAuth))
|
})
|
||||||
{
|
})
|
||||||
soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||||
|
|||||||
@@ -27,11 +27,13 @@ import (
|
|||||||
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
|
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
|
||||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||||
|
var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||||
|
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestEvent represents a SSE event for account testing
|
// TestEvent represents a SSE event for account testing
|
||||||
@@ -502,8 +504,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
|
||||||
|
|
||||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
}
|
}
|
||||||
@@ -512,7 +515,10 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body)))
|
if isCloudflareChallengeResponse(resp.StatusCode, body) {
|
||||||
|
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
|
||||||
|
}
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析 /me 响应,提取用户信息
|
// 解析 /me 响应,提取用户信息
|
||||||
@@ -531,10 +537,129 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
s.sendEvent(c, TestEvent{Type: "content", Text: info})
|
s.sendEvent(c, TestEvent{Type: "content", Text: info})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
|
||||||
|
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
|
||||||
|
if err == nil {
|
||||||
|
subReq.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
|
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
|
subReq.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||||
|
if subErr != nil {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
|
||||||
|
} else {
|
||||||
|
subBody, _ := io.ReadAll(subResp.Body)
|
||||||
|
_ = subResp.Body.Close()
|
||||||
|
if subResp.StatusCode == http.StatusOK {
|
||||||
|
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||||
|
} else {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
|
||||||
|
} else {
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseSoraSubscriptionSummary(body []byte) string {
|
||||||
|
var subResp struct {
|
||||||
|
Data []struct {
|
||||||
|
Plan struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
} `json:"plan"`
|
||||||
|
EndTS string `json:"end_ts"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &subResp); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(subResp.Data) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
first := subResp.Data[0]
|
||||||
|
parts := make([]string, 0, 3)
|
||||||
|
if first.Plan.Title != "" {
|
||||||
|
parts = append(parts, first.Plan.Title)
|
||||||
|
}
|
||||||
|
if first.Plan.ID != "" {
|
||||||
|
parts = append(parts, first.Plan.ID)
|
||||||
|
}
|
||||||
|
if first.EndTS != "" {
|
||||||
|
parts = append(parts, "end="+first.EndTS)
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "Subscription: " + strings.Join(parts, " | ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.cfg.Gateway.TLSFingerprint.Enabled && !s.cfg.Sora.Client.DisableTLSFingerprint
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
|
||||||
|
if statusCode != http.StatusForbidden {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
|
||||||
|
return strings.Contains(preview, "window._cf_chl_opt") ||
|
||||||
|
strings.Contains(preview, "just a moment") ||
|
||||||
|
strings.Contains(preview, "enable javascript and cookies to continue")
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||||
|
rayID := extractCloudflareRayID(headers, body)
|
||||||
|
if rayID == "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCloudflareRayID(headers http.Header, body []byte) string {
|
||||||
|
if headers != nil {
|
||||||
|
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||||
|
if rayID != "" {
|
||||||
|
return rayID
|
||||||
|
}
|
||||||
|
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||||
|
if rayID != "" {
|
||||||
|
return rayID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
preview := truncateSoraErrorBody(body, 8192)
|
||||||
|
matches := cloudflareRayPattern.FindStringSubmatch(preview)
|
||||||
|
if len(matches) >= 2 {
|
||||||
|
return strings.TrimSpace(matches[1])
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateSoraErrorBody(body []byte, max int) string {
|
||||||
|
if max <= 0 {
|
||||||
|
max = 512
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(string(body))
|
||||||
|
if len(raw) <= max {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return raw[:max] + "...(truncated)"
|
||||||
|
}
|
||||||
|
|
||||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||||
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
||||||
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||||
|
|||||||
193
backend/internal/service/account_test_service_sora_test.go
Normal file
193
backend/internal/service/account_test_service_sora_test.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type queuedHTTPUpstream struct {
|
||||||
|
responses []*http.Response
|
||||||
|
requests []*http.Request
|
||||||
|
tlsFlags []bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||||
|
return nil, fmt.Errorf("unexpected Do call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||||
|
u.requests = append(u.requests, req)
|
||||||
|
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
|
||||||
|
if len(u.responses) == 0 {
|
||||||
|
return nil, fmt.Errorf("no mocked response")
|
||||||
|
}
|
||||||
|
resp := u.responses[0]
|
||||||
|
u.responses = u.responses[1:]
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newJSONResponse(status int, body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: status,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
|
||||||
|
resp := newJSONResponse(status, body)
|
||||||
|
resp.Header.Set(key, value)
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||||
|
return c, rec
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||||
|
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{
|
||||||
|
httpUpstream: upstream,
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
TLSFingerprint: config.TLSFingerprintConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
DisableTLSFingerprint: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "test_token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, rec := newSoraTestContext()
|
||||||
|
err := svc.testSoraAccountConnection(c, account)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, upstream.requests, 2)
|
||||||
|
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
|
||||||
|
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
|
||||||
|
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
|
||||||
|
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
|
||||||
|
require.Equal(t, []bool{true, true}, upstream.tlsFlags)
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, `"type":"test_start"`)
|
||||||
|
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
|
||||||
|
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
|
||||||
|
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||||
|
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "test_token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, rec := newSoraTestContext()
|
||||||
|
err := svc.testSoraAccountConnection(c, account)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, upstream.requests, 2)
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "Sora connection OK - User: demo-user")
|
||||||
|
require.Contains(t, body, "Subscription check returned 403")
|
||||||
|
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "test_token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, rec := newSoraTestContext()
|
||||||
|
err := svc.testSoraAccountConnection(c, account)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||||
|
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, `"type":"error"`)
|
||||||
|
require.Contains(t, body, "Cloudflare challenge")
|
||||||
|
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
|
||||||
|
upstream := &queuedHTTPUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||||
|
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &AccountTestService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "test_token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, rec := newSoraTestContext()
|
||||||
|
err := svc.testSoraAccountConnection(c, account)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
|
||||||
|
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||||
|
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
type OpenAIOAuthClient interface {
|
type OpenAIOAuthClient interface {
|
||||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||||
|
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||||
|
|||||||
@@ -2,13 +2,20 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||||
|
|
||||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||||
type OpenAIOAuthService struct {
|
type OpenAIOAuthService struct {
|
||||||
sessionStore *openai.SessionStore
|
sessionStore *openai.SessionStore
|
||||||
@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
|||||||
type OpenAIExchangeCodeInput struct {
|
type OpenAIExchangeCodeInput struct {
|
||||||
SessionID string
|
SessionID string
|
||||||
Code string
|
Code string
|
||||||
|
State string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
ProxyID *int64
|
ProxyID *int64
|
||||||
}
|
}
|
||||||
@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
||||||
}
|
}
|
||||||
|
if input.State == "" {
|
||||||
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
|
||||||
|
}
|
||||||
|
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
|
||||||
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
|
||||||
|
}
|
||||||
|
|
||||||
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
||||||
proxyURL := session.ProxyURL
|
proxyURL := session.ProxyURL
|
||||||
@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
|||||||
|
|
||||||
// RefreshToken refreshes an OpenAI OAuth token
|
// RefreshToken refreshes an OpenAI OAuth token
|
||||||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
|
||||||
|
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
|
||||||
|
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
|||||||
return tokenInfo, nil
|
return tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshAccountToken refreshes token for an OpenAI account
|
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
||||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
||||||
if !account.IsOpenAI() {
|
if strings.TrimSpace(sessionToken) == "" {
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken := account.GetOpenAIRefreshToken()
|
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
|
|
||||||
|
client := newOpenAIOAuthHTTPClient(proxyURL)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var sessionResp struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
Expires string `json:"expires"`
|
||||||
|
User struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
} `json:"user"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &sessionResp); err != nil {
|
||||||
|
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(sessionResp.AccessToken) == "" {
|
||||||
|
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Hour).Unix()
|
||||||
|
if strings.TrimSpace(sessionResp.Expires) != "" {
|
||||||
|
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
|
||||||
|
expiresAt = parsed.Unix()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expiresIn := expiresAt - time.Now().Unix()
|
||||||
|
if expiresIn < 0 {
|
||||||
|
expiresIn = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OpenAITokenInfo{
|
||||||
|
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Email: strings.TrimSpace(sessionResp.User.Email),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
||||||
|
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||||
|
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
||||||
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
|
||||||
|
}
|
||||||
|
if account.Type != AccountTypeOAuth {
|
||||||
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken := account.GetCredential("refresh_token")
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||||
}
|
}
|
||||||
@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
clientID := account.GetCredential("client_id")
|
||||||
|
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildAccountCredentials builds credentials map from token info
|
// BuildAccountCredentials builds credentials map from token info
|
||||||
@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
|||||||
func (s *OpenAIOAuthService) Stop() {
|
func (s *OpenAIOAuthService) Stop() {
|
||||||
s.sessionStore.Stop()
|
s.sessionStore.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
||||||
|
if proxyID == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||||
|
if err != nil {
|
||||||
|
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||||
|
}
|
||||||
|
if proxy == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return proxy.URL(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
|
||||||
|
transport := &http.Transport{}
|
||||||
|
if strings.TrimSpace(proxyURL) != "" {
|
||||||
|
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
|
||||||
|
transport.Proxy = http.ProxyURL(parsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: 120 * time.Second,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openaiOAuthClientNoopStub struct{}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, http.MethodGet, r.Method)
|
||||||
|
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
origin := openAISoraSessionAuthURL
|
||||||
|
openAISoraSessionAuthURL = server.URL
|
||||||
|
defer func() { openAISoraSessionAuthURL = origin }()
|
||||||
|
|
||||||
|
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
require.Equal(t, "at-token", info.AccessToken)
|
||||||
|
require.Equal(t, "demo@example.com", info.Email)
|
||||||
|
require.Greater(t, info.ExpiresAt, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
origin := openAISoraSessionAuthURL
|
||||||
|
openAISoraSessionAuthURL = server.URL
|
||||||
|
defer func() { openAISoraSessionAuthURL = origin }()
|
||||||
|
|
||||||
|
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing access token")
|
||||||
|
}
|
||||||
102
backend/internal/service/openai_oauth_service_state_test.go
Normal file
102
backend/internal/service/openai_oauth_service_state_test.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openaiOAuthClientStateStub struct {
|
||||||
|
exchangeCalled int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
atomic.AddInt32(&s.exchangeCalled, 1)
|
||||||
|
return &openai.TokenResponse{
|
||||||
|
AccessToken: "at",
|
||||||
|
RefreshToken: "rt",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||||
|
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
|
||||||
|
client := &openaiOAuthClientStateStub{}
|
||||||
|
svc := NewOpenAIOAuthService(nil, client)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||||
|
State: "expected-state",
|
||||||
|
CodeVerifier: "verifier",
|
||||||
|
RedirectURI: openai.DefaultRedirectURI,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||||
|
SessionID: "sid",
|
||||||
|
Code: "auth-code",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "oauth state is required")
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
|
||||||
|
client := &openaiOAuthClientStateStub{}
|
||||||
|
svc := NewOpenAIOAuthService(nil, client)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||||
|
State: "expected-state",
|
||||||
|
CodeVerifier: "verifier",
|
||||||
|
RedirectURI: openai.DefaultRedirectURI,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||||
|
SessionID: "sid",
|
||||||
|
Code: "auth-code",
|
||||||
|
State: "wrong-state",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid oauth state")
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
|
||||||
|
client := &openaiOAuthClientStateStub{}
|
||||||
|
svc := NewOpenAIOAuthService(nil, client)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||||
|
State: "expected-state",
|
||||||
|
CodeVerifier: "verifier",
|
||||||
|
RedirectURI: openai.DefaultRedirectURI,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||||
|
SessionID: "sid",
|
||||||
|
Code: "auth-code",
|
||||||
|
State: "expected-state",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
require.Equal(t, "at", info.AccessToken)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
|
||||||
|
|
||||||
|
_, ok := svc.sessionStore.Get("sid")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||||
if p.openAIOAuthService == nil {
|
if account.Platform == PlatformSora {
|
||||||
|
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
||||||
|
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||||
|
refreshFailed = true
|
||||||
|
} else if p.openAIOAuthService == nil {
|
||||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||||
p.metrics.refreshFailure.Add(1)
|
p.metrics.refreshFailure.Add(1)
|
||||||
refreshFailed = true // 无法刷新,标记失败
|
refreshFailed = true // 无法刷新,标记失败
|
||||||
@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||||
if p.openAIOAuthService == nil {
|
if account.Platform == PlatformSora {
|
||||||
|
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
|
||||||
|
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||||
|
refreshFailed = true
|
||||||
|
} else if p.openAIOAuthService == nil {
|
||||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||||
p.metrics.refreshFailure.Add(1)
|
p.metrics.refreshFailure.Add(1)
|
||||||
refreshFailed = true
|
refreshFailed = true
|
||||||
|
|||||||
@@ -17,12 +17,15 @@ import (
|
|||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"golang.org/x/crypto/sha3"
|
"golang.org/x/crypto/sha3"
|
||||||
@@ -34,6 +37,11 @@ const (
|
|||||||
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
|
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||||
|
soraOAuthTokenURL = "https://auth.openai.com/oauth/token"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
soraPowMaxIteration = 500000
|
soraPowMaxIteration = 500000
|
||||||
)
|
)
|
||||||
@@ -96,6 +104,7 @@ type SoraClient interface {
|
|||||||
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
|
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
|
||||||
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
|
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
|
||||||
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
|
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
|
||||||
|
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
|
||||||
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
|
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
|
||||||
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
|
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
|
||||||
}
|
}
|
||||||
@@ -160,23 +169,91 @@ type SoraDirectClient struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
tokenProvider *OpenAITokenProvider
|
tokenProvider *OpenAITokenProvider
|
||||||
|
accountRepo AccountRepository
|
||||||
|
soraAccountRepo SoraAccountRepository
|
||||||
|
baseURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSoraDirectClient 创建 Sora 直连客户端
|
// NewSoraDirectClient 创建 Sora 直连客户端
|
||||||
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
|
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
|
||||||
|
baseURL := ""
|
||||||
|
if cfg != nil {
|
||||||
|
rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/")
|
||||||
|
baseURL = normalizeSoraBaseURL(rawBaseURL)
|
||||||
|
if rawBaseURL != "" && baseURL != rawBaseURL {
|
||||||
|
log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL))
|
||||||
|
}
|
||||||
|
}
|
||||||
return &SoraDirectClient{
|
return &SoraDirectClient{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
tokenProvider: tokenProvider,
|
tokenProvider: tokenProvider,
|
||||||
|
baseURL: baseURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.accountRepo = accountRepo
|
||||||
|
c.soraAccountRepo = soraAccountRepo
|
||||||
|
}
|
||||||
|
|
||||||
// Enabled 判断是否启用 Sora 直连
|
// Enabled 判断是否启用 Sora 直连
|
||||||
func (c *SoraDirectClient) Enabled() bool {
|
func (c *SoraDirectClient) Enabled() bool {
|
||||||
if c == nil || c.cfg == nil {
|
if c == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
|
if strings.TrimSpace(c.baseURL) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if c.cfg == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreflightCheck 在创建任务前执行账号能力预检。
|
||||||
|
// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。
|
||||||
|
func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
|
||||||
|
if modelCfg.Type != "video" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
token, err := c.getAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||||
|
headers.Set("Accept", "application/json")
|
||||||
|
body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
var upstreamErr *SoraUpstreamError
|
||||||
|
if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
|
||||||
|
return &SoraUpstreamError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
Message: "当前账号未开通 Sora2 能力或无可用配额",
|
||||||
|
Headers: upstreamErr.Headers,
|
||||||
|
Body: upstreamErr.Body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool()
|
||||||
|
remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining")
|
||||||
|
if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) {
|
||||||
|
msg := "当前账号 Sora2 可用配额不足"
|
||||||
|
if requestedModel != "" {
|
||||||
|
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
|
||||||
|
}
|
||||||
|
return &SoraUpstreamError{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Message: msg,
|
||||||
|
Headers: http.Header{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
||||||
@@ -347,6 +424,45 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
|||||||
return taskID, nil
|
return taskID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||||
|
token, err := c.getAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(expansionLevel) == "" {
|
||||||
|
expansionLevel = "medium"
|
||||||
|
}
|
||||||
|
if durationS <= 0 {
|
||||||
|
durationS = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"prompt": prompt,
|
||||||
|
"expansion_level": expansionLevel,
|
||||||
|
"duration_s": durationS,
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||||
|
headers.Set("Content-Type", "application/json")
|
||||||
|
headers.Set("Accept", "application/json")
|
||||||
|
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String())
|
||||||
|
if enhancedPrompt == "" {
|
||||||
|
return "", errors.New("enhance_prompt response missing enhanced_prompt")
|
||||||
|
}
|
||||||
|
return enhancedPrompt, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||||
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
|
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -512,9 +628,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) buildURL(endpoint string) string {
|
func (c *SoraDirectClient) buildURL(endpoint string) string {
|
||||||
base := ""
|
base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/")
|
||||||
if c != nil && c.cfg != nil {
|
if base == "" && c != nil && c.cfg != nil {
|
||||||
base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/")
|
base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)
|
||||||
|
c.baseURL = base
|
||||||
}
|
}
|
||||||
if base == "" {
|
if base == "" {
|
||||||
return endpoint
|
return endpoint
|
||||||
@@ -540,14 +657,257 @@ func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account)
|
|||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
}
|
}
|
||||||
if c.tokenProvider != nil {
|
|
||||||
return c.tokenProvider.GetAccessToken(ctx, account)
|
allowProvider := c.allowOpenAITokenProvider(account)
|
||||||
|
var providerErr error
|
||||||
|
if allowProvider && c.tokenProvider != nil {
|
||||||
|
token, err := c.tokenProvider.GetAccessToken(ctx, account)
|
||||||
|
if err == nil && strings.TrimSpace(token) != "" {
|
||||||
|
c.logTokenSource(account, "openai_token_provider")
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
providerErr = err
|
||||||
|
if err != nil && c.debugEnabled() {
|
||||||
|
c.debugLogf(
|
||||||
|
"token_provider_failed account_id=%d platform=%s err=%s",
|
||||||
|
account.ID,
|
||||||
|
account.Platform,
|
||||||
|
logredact.RedactText(err.Error()),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
token := strings.TrimSpace(account.GetCredential("access_token"))
|
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||||
if token == "" {
|
if token != "" {
|
||||||
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
|
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
|
||||||
|
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
|
||||||
|
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
|
||||||
|
c.logTokenSource(account, "refresh_token_recovered")
|
||||||
|
return refreshed, nil
|
||||||
|
}
|
||||||
|
if refreshErr != nil && c.debugEnabled() {
|
||||||
|
c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.logTokenSource(account, "account_credentials")
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
|
||||||
|
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
|
||||||
|
c.logTokenSource(account, "session_or_refresh_recovered")
|
||||||
|
return recovered, nil
|
||||||
|
}
|
||||||
|
if recoverErr != nil && c.debugEnabled() {
|
||||||
|
c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error()))
|
||||||
|
}
|
||||||
|
if providerErr != nil {
|
||||||
|
return "", providerErr
|
||||||
|
}
|
||||||
|
if c.tokenProvider != nil && !allowProvider {
|
||||||
|
c.logTokenSource(account, "account_credentials(provider_disabled)")
|
||||||
|
}
|
||||||
return "", errors.New("access_token not found")
|
return "", errors.New("access_token not found")
|
||||||
}
|
}
|
||||||
return token, nil
|
|
||||||
|
func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
|
||||||
|
if account == nil {
|
||||||
|
return "", errors.New("account is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
|
||||||
|
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
|
||||||
|
if err == nil && strings.TrimSpace(accessToken) != "" {
|
||||||
|
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
|
||||||
|
c.logTokenRecover(account, "session_token", reason, true, nil)
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
c.logTokenRecover(account, "session_token", reason, false, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
||||||
|
if refreshToken == "" {
|
||||||
|
return "", errors.New("session_token/refresh_token not found")
|
||||||
|
}
|
||||||
|
accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
c.logTokenRecover(account, "refresh_token", reason, false, err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
|
return "", errors.New("refreshed access_token is empty")
|
||||||
|
}
|
||||||
|
c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "")
|
||||||
|
c.logTokenRecover(account, "refresh_token", reason, true, nil)
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
|
||||||
|
headers.Set("Accept", "application/json")
|
||||||
|
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
headers.Set("User-Agent", c.defaultUserAgent())
|
||||||
|
body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", "", errors.New("session exchange missing accessToken")
|
||||||
|
}
|
||||||
|
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
|
||||||
|
return accessToken, expiresAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) {
|
||||||
|
clientIDs := []string{
|
||||||
|
strings.TrimSpace(account.GetCredential("client_id")),
|
||||||
|
openaioauth.SoraClientID,
|
||||||
|
openaioauth.ClientID,
|
||||||
|
}
|
||||||
|
tried := make(map[string]struct{}, len(clientIDs))
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for _, clientID := range clientIDs {
|
||||||
|
if clientID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := tried[clientID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tried[clientID] = struct{}{}
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"client_id": clientID,
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refreshToken,
|
||||||
|
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
|
||||||
|
}
|
||||||
|
bodyBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", "", err
|
||||||
|
}
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Accept", "application/json")
|
||||||
|
headers.Set("Content-Type", "application/json")
|
||||||
|
headers.Set("User-Agent", c.defaultUserAgent())
|
||||||
|
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
if c.debugEnabled() {
|
||||||
|
c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error()))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String())
|
||||||
|
if accessToken == "" {
|
||||||
|
lastErr = errors.New("oauth refresh response missing access_token")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String())
|
||||||
|
expiresIn := gjson.GetBytes(respBody, "expires_in").Int()
|
||||||
|
expiresAt := ""
|
||||||
|
if expiresIn > 0 {
|
||||||
|
expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
return accessToken, newRefreshToken, expiresAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return "", "", "", lastErr
|
||||||
|
}
|
||||||
|
return "", "", "", errors.New("no available client_id for refresh_token exchange")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
|
||||||
|
if account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if account.Credentials == nil {
|
||||||
|
account.Credentials = make(map[string]any)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(accessToken) != "" {
|
||||||
|
account.Credentials["access_token"] = accessToken
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(refreshToken) != "" {
|
||||||
|
account.Credentials["refresh_token"] = refreshToken
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(expiresAt) != "" {
|
||||||
|
account.Credentials["expires_at"] = expiresAt
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(sessionToken) != "" {
|
||||||
|
account.Credentials["session_token"] = sessionToken
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.accountRepo != nil {
|
||||||
|
if err := c.accountRepo.Update(ctx, account); err != nil {
|
||||||
|
if c.debugEnabled() {
|
||||||
|
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
|
||||||
|
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updates := make(map[string]any)
|
||||||
|
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
|
||||||
|
updates["access_token"] = accessToken
|
||||||
|
updates["refresh_token"] = refreshToken
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(sessionToken) != "" {
|
||||||
|
updates["session_token"] = sessionToken
|
||||||
|
}
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
|
||||||
|
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) {
|
||||||
|
if !c.debugEnabled() || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if success {
|
||||||
|
c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool {
|
||||||
|
if c == nil || c.tokenProvider == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if account != nil && account.Platform == PlatformSora {
|
||||||
|
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) logTokenSource(account *Account, source string) {
|
||||||
|
if !c.debugEnabled() || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.debugLogf(
|
||||||
|
"token_selected account_id=%d platform=%s account_type=%s source=%s",
|
||||||
|
account.ID,
|
||||||
|
account.Platform,
|
||||||
|
account.Type,
|
||||||
|
source,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
|
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
|
||||||
@@ -600,7 +960,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
|||||||
}
|
}
|
||||||
|
|
||||||
attempts := maxRetries + 1
|
attempts := maxRetries + 1
|
||||||
|
authRecovered := false
|
||||||
|
authRecoverExtraAttemptGranted := false
|
||||||
|
var lastErr error
|
||||||
for attempt := 1; attempt <= attempts; attempt++ {
|
for attempt := 1; attempt <= attempts; attempt++ {
|
||||||
|
if c.debugEnabled() {
|
||||||
|
c.debugLogf(
|
||||||
|
"request_start method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t headers=%s",
|
||||||
|
method,
|
||||||
|
sanitizeSoraLogURL(urlStr),
|
||||||
|
attempt,
|
||||||
|
attempts,
|
||||||
|
timeout,
|
||||||
|
len(bodyBytes),
|
||||||
|
account != nil && account.ProxyID != nil && account.Proxy != nil,
|
||||||
|
formatSoraHeaders(headers),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
if bodyBytes != nil {
|
if bodyBytes != nil {
|
||||||
reader = bytes.NewReader(bodyBytes)
|
reader = bytes.NewReader(bodyBytes)
|
||||||
@@ -618,7 +995,21 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
|||||||
}
|
}
|
||||||
resp, err := c.doHTTP(req, proxyURL, account)
|
resp, err := c.doHTTP(req, proxyURL, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
if c.debugEnabled() {
|
||||||
|
c.debugLogf(
|
||||||
|
"request_transport_error method=%s url=%s attempt=%d/%d err=%s",
|
||||||
|
method,
|
||||||
|
sanitizeSoraLogURL(urlStr),
|
||||||
|
attempt,
|
||||||
|
attempts,
|
||||||
|
logredact.RedactText(err.Error()),
|
||||||
|
)
|
||||||
|
}
|
||||||
if attempt < attempts && allowRetry {
|
if attempt < attempts && allowRetry {
|
||||||
|
if c.debugEnabled() {
|
||||||
|
c.debugLogf("request_retry_scheduled method=%s url=%s reason=transport_error next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), attempt+1, attempts)
|
||||||
|
}
|
||||||
c.sleepRetry(attempt)
|
c.sleepRetry(attempt)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -632,12 +1023,53 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c.cfg != nil && c.cfg.Sora.Client.Debug {
|
if c.cfg != nil && c.cfg.Sora.Client.Debug {
|
||||||
log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start))
|
c.debugLogf(
|
||||||
|
"response_received method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s",
|
||||||
|
method,
|
||||||
|
sanitizeSoraLogURL(urlStr),
|
||||||
|
attempt,
|
||||||
|
attempts,
|
||||||
|
resp.StatusCode,
|
||||||
|
time.Since(start),
|
||||||
|
len(respBody),
|
||||||
|
formatSoraHeaders(resp.Header),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||||
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody)
|
if !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil {
|
||||||
|
if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" {
|
||||||
|
headers.Set("Authorization", "Bearer "+recovered)
|
||||||
|
authRecovered = true
|
||||||
|
if attempt == attempts && !authRecoverExtraAttemptGranted {
|
||||||
|
attempts++
|
||||||
|
authRecoverExtraAttemptGranted = true
|
||||||
|
}
|
||||||
|
if c.debugEnabled() {
|
||||||
|
c.debugLogf("request_retry_with_recovered_token method=%s url=%s status=%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else if recoverErr != nil && c.debugEnabled() {
|
||||||
|
c.debugLogf("request_recover_token_failed method=%s url=%s status=%d err=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.debugEnabled() {
|
||||||
|
c.debugLogf(
|
||||||
|
"response_non_success method=%s url=%s attempt=%d/%d status=%d body=%s",
|
||||||
|
method,
|
||||||
|
sanitizeSoraLogURL(urlStr),
|
||||||
|
attempt,
|
||||||
|
attempts,
|
||||||
|
resp.StatusCode,
|
||||||
|
summarizeSoraResponseBody(respBody, 512),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr)
|
||||||
|
lastErr = upstreamErr
|
||||||
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
|
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
|
||||||
|
if c.debugEnabled() {
|
||||||
|
c.debugLogf("request_retry_scheduled method=%s url=%s reason=status_%d next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
|
||||||
|
}
|
||||||
c.sleepRetry(attempt)
|
c.sleepRetry(attempt)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -645,9 +1077,34 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
|||||||
}
|
}
|
||||||
return respBody, resp.Header, nil
|
return respBody, resp.Header, nil
|
||||||
}
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, nil, lastErr
|
||||||
|
}
|
||||||
return nil, nil, errors.New("upstream retries exhausted")
|
return nil, nil, errors.New("upstream retries exhausted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
|
||||||
|
switch statusCode {
|
||||||
|
case http.StatusUnauthorized, http.StatusForbidden:
|
||||||
|
parsed, err := url.Parse(strings.TrimSpace(rawURL))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
host := strings.ToLower(parsed.Hostname())
|
||||||
|
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。
|
||||||
|
path := strings.ToLower(strings.TrimSpace(parsed.Path))
|
||||||
|
if path == "/api/auth/session" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||||
enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
|
enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
|
||||||
if c.httpUpstream != nil {
|
if c.httpUpstream != nil {
|
||||||
@@ -670,9 +1127,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) {
|
|||||||
time.Sleep(backoff)
|
time.Sleep(backoff)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error {
|
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error {
|
||||||
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
msg = sanitizeUpstreamErrorMessage(msg)
|
msg = sanitizeUpstreamErrorMessage(msg)
|
||||||
|
if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") {
|
||||||
|
if hint := soraBaseURLNotFoundHint(requestURL); hint != "" {
|
||||||
|
msg = strings.TrimSpace(msg + " " + hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
if msg == "" {
|
if msg == "" {
|
||||||
msg = truncateForLog(body, 256)
|
msg = truncateForLog(body, 256)
|
||||||
}
|
}
|
||||||
@@ -684,6 +1146,45 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeSoraBaseURL(raw string) string {
|
||||||
|
trimmed := strings.TrimRight(strings.TrimSpace(raw), "/")
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(trimmed)
|
||||||
|
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
host := strings.ToLower(parsed.Hostname())
|
||||||
|
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/")
|
||||||
|
switch pathVal {
|
||||||
|
case "", "/":
|
||||||
|
parsed.Path = "/backend"
|
||||||
|
case "/backend-api":
|
||||||
|
parsed.Path = "/backend"
|
||||||
|
}
|
||||||
|
return strings.TrimRight(parsed.String(), "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraBaseURLNotFoundHint(requestURL string) string {
|
||||||
|
parsed, err := url.Parse(strings.TrimSpace(requestURL))
|
||||||
|
if err != nil || parsed.Host == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
host := strings.ToLower(parsed.Hostname())
|
||||||
|
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
pathVal := strings.TrimSpace(parsed.Path)
|
||||||
|
if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
|
||||||
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
|
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||||
reqID := uuid.NewString()
|
reqID := uuid.NewString()
|
||||||
userAgent := soraRandChoice(soraDesktopUserAgents)
|
userAgent := soraRandChoice(soraDesktopUserAgents)
|
||||||
@@ -901,3 +1402,70 @@ func sanitizeSoraLogURL(raw string) string {
|
|||||||
parsed.RawQuery = q.Encode()
|
parsed.RawQuery = q.Encode()
|
||||||
return parsed.String()
|
return parsed.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) debugEnabled() bool {
|
||||||
|
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) debugLogf(format string, args ...any) {
|
||||||
|
if !c.debugEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("[SoraClient] "+format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatSoraHeaders(headers http.Header) string {
|
||||||
|
if len(headers) == 0 {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
keys := make([]string, 0, len(headers))
|
||||||
|
for key := range headers {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
out := make(map[string]string, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
values := headers.Values(key)
|
||||||
|
if len(values) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
val := strings.Join(values, ",")
|
||||||
|
if isSensitiveHeader(key) {
|
||||||
|
out[key] = "***"
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160)
|
||||||
|
}
|
||||||
|
encoded, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
return string(encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSensitiveHeader(key string) bool {
|
||||||
|
k := strings.ToLower(strings.TrimSpace(key))
|
||||||
|
switch k {
|
||||||
|
case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeSoraResponseBody(body []byte, maxLen int) string {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var text string
|
||||||
|
if json.Valid(body) {
|
||||||
|
text = logredact.RedactJSON(body)
|
||||||
|
} else {
|
||||||
|
text = logredact.RedactText(string(body))
|
||||||
|
}
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if maxLen <= 0 || len(text) <= maxLen {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
return text[:maxLen] + "...(truncated)"
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,9 +4,13 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
|
|||||||
require.Equal(t, "completed", status.Status)
|
require.Equal(t, "completed", status.Status)
|
||||||
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
|
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeSoraBaseURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
raw string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
raw: "",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "append_backend_for_sora_host",
|
||||||
|
raw: "https://sora.chatgpt.com",
|
||||||
|
want: "https://sora.chatgpt.com/backend",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "convert_backend_api_to_backend",
|
||||||
|
raw: "https://sora.chatgpt.com/backend-api",
|
||||||
|
want: "https://sora.chatgpt.com/backend",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keep_backend",
|
||||||
|
raw: "https://sora.chatgpt.com/backend",
|
||||||
|
want: "https://sora.chatgpt.com/backend",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keep_custom_host",
|
||||||
|
raw: "https://example.com/custom-path",
|
||||||
|
want: "https://example.com/custom-path",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := normalizeSoraBaseURL(tt.raw)
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: "https://sora.chatgpt.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, nil, nil)
|
||||||
|
require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||||
|
|
||||||
|
err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
|
||||||
|
var upstreamErr *SoraUpstreamError
|
||||||
|
require.ErrorAs(t, err, &upstreamErr)
|
||||||
|
require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
|
||||||
|
|
||||||
|
errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
|
||||||
|
require.ErrorAs(t, errNoHint, &upstreamErr)
|
||||||
|
require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Authorization", "Bearer secret-token")
|
||||||
|
headers.Set("openai-sentinel-token", "sentinel-secret")
|
||||||
|
headers.Set("X-Test", "ok")
|
||||||
|
|
||||||
|
out := formatSoraHeaders(headers)
|
||||||
|
require.Contains(t, out, `"Authorization":"***"`)
|
||||||
|
require.Contains(t, out, `Sentinel-Token":"***"`)
|
||||||
|
require.Contains(t, out, `"X-Test":"ok"`)
|
||||||
|
require.NotContains(t, out, "secret-token")
|
||||||
|
require.NotContains(t, out, "sentinel-secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
|
||||||
|
out := summarizeSoraResponseBody(body, 512)
|
||||||
|
require.Contains(t, out, `"access_token":"***"`)
|
||||||
|
require.NotContains(t, out, "abc123")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := []byte(strings.Repeat("x", 100))
|
||||||
|
out := summarizeSoraResponseBody(body, 10)
|
||||||
|
require.Contains(t, out, "(truncated)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cache := newOpenAITokenCacheStub()
|
||||||
|
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: "https://sora.chatgpt.com/backend",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, nil, provider)
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "sora-credential-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := client.getAccessToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "sora-credential-token", token)
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cache := newOpenAITokenCacheStub()
|
||||||
|
account := &Account{
|
||||||
|
ID: 2,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "sora-credential-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
|
||||||
|
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: "https://sora.chatgpt.com/backend",
|
||||||
|
UseOpenAITokenProvider: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, nil, provider)
|
||||||
|
|
||||||
|
token, err := client.getAccessToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "provider-token", token)
|
||||||
|
require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, http.MethodGet, r.Method)
|
||||||
|
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"accessToken": "session-access-token",
|
||||||
|
"expires": "2099-01-01T00:00:00Z",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
origin := soraSessionAuthURL
|
||||||
|
soraSessionAuthURL = server.URL
|
||||||
|
defer func() { soraSessionAuthURL = origin }()
|
||||||
|
|
||||||
|
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 10,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"session_token": "session-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := client.getAccessToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "session-access-token", token)
|
||||||
|
require.Equal(t, "session-access-token", account.GetCredential("access_token"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, http.MethodPost, r.Method)
|
||||||
|
require.Equal(t, "/oauth/token", r.URL.Path)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"access_token": "refresh-access-token",
|
||||||
|
"refresh_token": "refresh-token-new",
|
||||||
|
"expires_in": 3600,
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
origin := soraOAuthTokenURL
|
||||||
|
soraOAuthTokenURL = server.URL + "/oauth/token"
|
||||||
|
defer func() { soraOAuthTokenURL = origin }()
|
||||||
|
|
||||||
|
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 11,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"refresh_token": "refresh-token-old",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := client.getAccessToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "refresh-access-token", token)
|
||||||
|
require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
|
||||||
|
require.NotNil(t, account.GetCredentialAsTime("expires_at"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, http.MethodGet, r.Method)
|
||||||
|
require.Equal(t, "/nf/check", r.URL.Path)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"rate_limit_and_credit_balance": map[string]any{
|
||||||
|
"estimated_num_videos_remaining": 0,
|
||||||
|
"rate_limit_reached": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 12,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "ok",
|
||||||
|
"expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
|
||||||
|
require.Error(t, err)
|
||||||
|
var upstreamErr *SoraUpstreamError
|
||||||
|
require.ErrorAs(t, err, &upstreamErr)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldAttemptSoraTokenRecover(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
|
||||||
|
require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
|
||||||
|
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
|
||||||
|
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
|
||||||
|
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,6 +61,10 @@ type SoraGatewayService struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type soraPreflightChecker interface {
|
||||||
|
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
func NewSoraGatewayService(
|
func NewSoraGatewayService(
|
||||||
soraClient SoraClient,
|
soraClient SoraClient,
|
||||||
mediaStorage *SoraMediaStorage,
|
mediaStorage *SoraMediaStorage,
|
||||||
@@ -112,11 +116,6 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
|
||||||
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
||||||
}
|
}
|
||||||
if modelCfg.Type == "prompt_enhance" {
|
|
||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
|
|
||||||
return nil, fmt.Errorf("prompt-enhance not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
||||||
if strings.TrimSpace(prompt) == "" {
|
if strings.TrimSpace(prompt) == "" {
|
||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||||
@@ -131,6 +130,41 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
if cancel != nil {
|
if cancel != nil {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
if checker, ok := s.soraClient.(soraPreflightChecker); ok {
|
||||||
|
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelCfg.Type == "prompt_enhance" {
|
||||||
|
enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
|
||||||
|
if err != nil {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
content := strings.TrimSpace(enhancedPrompt)
|
||||||
|
if content == "" {
|
||||||
|
content = prompt
|
||||||
|
}
|
||||||
|
var firstTokenMs *int
|
||||||
|
if clientStream {
|
||||||
|
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||||
|
if streamErr != nil {
|
||||||
|
return nil, streamErr
|
||||||
|
}
|
||||||
|
firstTokenMs = ms
|
||||||
|
} else if c != nil {
|
||||||
|
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
|
||||||
|
}
|
||||||
|
return &ForwardResult{
|
||||||
|
RequestID: "",
|
||||||
|
Model: reqModel,
|
||||||
|
Stream: clientStream,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
Usage: ClaudeUsage{},
|
||||||
|
MediaType: "prompt",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
var imageData []byte
|
var imageData []byte
|
||||||
imageFilename := ""
|
imageFilename := ""
|
||||||
@@ -267,7 +301,7 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
|
|||||||
|
|
||||||
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401, 402, 403, 429, 529:
|
case 401, 402, 403, 404, 429, 529:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return statusCode >= 500
|
return statusCode >= 500
|
||||||
@@ -460,7 +494,7 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
|
|||||||
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||||
}
|
}
|
||||||
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
||||||
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
|
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body}
|
||||||
}
|
}
|
||||||
msg := upstreamErr.Message
|
msg := upstreamErr.Message
|
||||||
if override := soraProErrorMessage(model, msg); override != "" {
|
if override := soraProErrorMessage(model, msg); override != "" {
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ type stubSoraClientForPoll struct {
|
|||||||
videoStatus *SoraVideoTaskStatus
|
videoStatus *SoraVideoTaskStatus
|
||||||
imageCalls int
|
imageCalls int
|
||||||
videoCalls int
|
videoCalls int
|
||||||
|
enhanced string
|
||||||
|
enhanceErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
||||||
@@ -30,6 +32,12 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
|
|||||||
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||||
return "task-video", nil
|
return "task-video", nil
|
||||||
}
|
}
|
||||||
|
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||||
|
if s.enhanced != "" {
|
||||||
|
return s.enhanced, s.enhanceErr
|
||||||
|
}
|
||||||
|
return "enhanced prompt", s.enhanceErr
|
||||||
|
}
|
||||||
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||||
s.imageCalls++
|
s.imageCalls++
|
||||||
return s.imageStatus, nil
|
return s.imageStatus, nil
|
||||||
@@ -62,6 +70,33 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
|
|||||||
require.Equal(t, 1, client.imageCalls)
|
require.Equal(t, 1, client.imageCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
||||||
|
client := &stubSoraClientForPoll{
|
||||||
|
enhanced: "cinematic prompt",
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformSora,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "prompt", result.MediaType)
|
||||||
|
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||||
client := &stubSoraClientForPoll{
|
client := &stubSoraClientForPoll{
|
||||||
videoStatus: &SoraVideoTaskStatus{
|
videoStatus: &SoraVideoTaskStatus{
|
||||||
@@ -178,6 +213,7 @@ func TestSoraProErrorMessage(t *testing.T) {
|
|||||||
func TestShouldFailoverUpstreamError(t *testing.T) {
|
func TestShouldFailoverUpstreamError(t *testing.T) {
|
||||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(401))
|
require.True(t, svc.shouldFailoverUpstreamError(401))
|
||||||
|
require.True(t, svc.shouldFailoverUpstreamError(404))
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(429))
|
require.True(t, svc.shouldFailoverUpstreamError(429))
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(500))
|
require.True(t, svc.shouldFailoverUpstreamError(500))
|
||||||
require.True(t, svc.shouldFailoverUpstreamError(502))
|
require.True(t, svc.shouldFailoverUpstreamError(502))
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ type SoraModelConfig struct {
|
|||||||
Model string
|
Model string
|
||||||
Size string
|
Size string
|
||||||
RequirePro bool
|
RequirePro bool
|
||||||
|
// Prompt-enhance 专用参数
|
||||||
|
ExpansionLevel string
|
||||||
|
DurationS int
|
||||||
}
|
}
|
||||||
|
|
||||||
var soraModelConfigs = map[string]SoraModelConfig{
|
var soraModelConfigs = map[string]SoraModelConfig{
|
||||||
@@ -161,30 +164,48 @@ var soraModelConfigs = map[string]SoraModelConfig{
|
|||||||
},
|
},
|
||||||
"prompt-enhance-short-10s": {
|
"prompt-enhance-short-10s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "short",
|
||||||
|
DurationS: 10,
|
||||||
},
|
},
|
||||||
"prompt-enhance-short-15s": {
|
"prompt-enhance-short-15s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "short",
|
||||||
|
DurationS: 15,
|
||||||
},
|
},
|
||||||
"prompt-enhance-short-20s": {
|
"prompt-enhance-short-20s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "short",
|
||||||
|
DurationS: 20,
|
||||||
},
|
},
|
||||||
"prompt-enhance-medium-10s": {
|
"prompt-enhance-medium-10s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "medium",
|
||||||
|
DurationS: 10,
|
||||||
},
|
},
|
||||||
"prompt-enhance-medium-15s": {
|
"prompt-enhance-medium-15s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "medium",
|
||||||
|
DurationS: 15,
|
||||||
},
|
},
|
||||||
"prompt-enhance-medium-20s": {
|
"prompt-enhance-medium-20s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "medium",
|
||||||
|
DurationS: 20,
|
||||||
},
|
},
|
||||||
"prompt-enhance-long-10s": {
|
"prompt-enhance-long-10s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "long",
|
||||||
|
DurationS: 10,
|
||||||
},
|
},
|
||||||
"prompt-enhance-long-15s": {
|
"prompt-enhance-long-15s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "long",
|
||||||
|
DurationS: 15,
|
||||||
},
|
},
|
||||||
"prompt-enhance-long-20s": {
|
"prompt-enhance-long-20s": {
|
||||||
Type: "prompt_enhance",
|
Type: "prompt_enhance",
|
||||||
|
ExpansionLevel: "long",
|
||||||
|
DurationS: 20,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,10 +43,13 @@ func NewTokenRefreshService(
|
|||||||
stopCh: make(chan struct{}),
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||||
|
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
||||||
|
|
||||||
// 注册平台特定的刷新器
|
// 注册平台特定的刷新器
|
||||||
s.refreshers = []TokenRefresher{
|
s.refreshers = []TokenRefresher{
|
||||||
NewClaudeTokenRefresher(oauthService),
|
NewClaudeTokenRefresher(oauthService),
|
||||||
NewOpenAITokenRefresher(openaiOAuthService, accountRepo),
|
openAIRefresher,
|
||||||
NewGeminiTokenRefresher(geminiOAuthService),
|
NewGeminiTokenRefresher(geminiOAuthService),
|
||||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
NewAntigravityTokenRefresher(antigravityOAuthService),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
|
|||||||
openaiOAuthService *OpenAIOAuthService
|
openaiOAuthService *OpenAIOAuthService
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
||||||
|
syncLinkedSora bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||||
@@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
|
|||||||
r.soraAccountRepo = repo
|
r.soraAccountRepo = repo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
|
||||||
|
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
|
||||||
|
r.syncLinkedSora = enabled
|
||||||
|
}
|
||||||
|
|
||||||
// CanRefresh 检查是否能处理此账号
|
// CanRefresh 检查是否能处理此账号
|
||||||
// 只处理 openai 平台的 oauth 类型账号
|
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
|
||||||
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||||
return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
|
return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
|
||||||
account.Type == AccountTypeOAuth
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NeedsRefresh 检查token是否需要刷新
|
// NeedsRefresh 检查token是否需要刷新
|
||||||
@@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
||||||
if r.accountRepo != nil {
|
if r.accountRepo != nil && r.syncLinkedSora {
|
||||||
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
|
||||||
|
refresher := &OpenAITokenRefresher{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
platform string
|
||||||
|
accType string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "openai oauth - can refresh",
|
||||||
|
platform: PlatformOpenAI,
|
||||||
|
accType: AccountTypeOAuth,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sora oauth - cannot refresh directly",
|
||||||
|
platform: PlatformSora,
|
||||||
|
accType: AccountTypeOAuth,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "openai apikey - cannot refresh",
|
||||||
|
platform: PlatformOpenAI,
|
||||||
|
accType: AccountTypeAPIKey,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: tt.platform,
|
||||||
|
Type: tt.accType,
|
||||||
|
}
|
||||||
|
require.Equal(t, tt.want, refresher.CanRefresh(account))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
|||||||
return NewSoraMediaStorage(cfg)
|
return NewSoraMediaStorage(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ProvideSoraDirectClient(
|
||||||
|
cfg *config.Config,
|
||||||
|
httpUpstream HTTPUpstream,
|
||||||
|
tokenProvider *OpenAITokenProvider,
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
soraAccountRepo SoraAccountRepository,
|
||||||
|
) *SoraDirectClient {
|
||||||
|
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
|
||||||
|
client.SetAccountRepositories(accountRepo, soraAccountRepo)
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
|
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
|
||||||
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
||||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||||
@@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewGatewayService,
|
NewGatewayService,
|
||||||
ProvideSoraMediaStorage,
|
ProvideSoraMediaStorage,
|
||||||
ProvideSoraMediaCleanupService,
|
ProvideSoraMediaCleanupService,
|
||||||
NewSoraDirectClient,
|
ProvideSoraDirectClient,
|
||||||
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
||||||
NewSoraGatewayService,
|
NewSoraGatewayService,
|
||||||
NewOpenAIGatewayService,
|
NewOpenAIGatewayService,
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
|
|||||||
if strings.HasPrefix(path, "/api/") ||
|
if strings.HasPrefix(path, "/api/") ||
|
||||||
strings.HasPrefix(path, "/v1/") ||
|
strings.HasPrefix(path, "/v1/") ||
|
||||||
strings.HasPrefix(path, "/v1beta/") ||
|
strings.HasPrefix(path, "/v1beta/") ||
|
||||||
|
strings.HasPrefix(path, "/sora/") ||
|
||||||
strings.HasPrefix(path, "/antigravity/") ||
|
strings.HasPrefix(path, "/antigravity/") ||
|
||||||
strings.HasPrefix(path, "/setup/") ||
|
strings.HasPrefix(path, "/setup/") ||
|
||||||
path == "/health" ||
|
path == "/health" ||
|
||||||
@@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
|||||||
if strings.HasPrefix(path, "/api/") ||
|
if strings.HasPrefix(path, "/api/") ||
|
||||||
strings.HasPrefix(path, "/v1/") ||
|
strings.HasPrefix(path, "/v1/") ||
|
||||||
strings.HasPrefix(path, "/v1beta/") ||
|
strings.HasPrefix(path, "/v1beta/") ||
|
||||||
|
strings.HasPrefix(path, "/sora/") ||
|
||||||
strings.HasPrefix(path, "/antigravity/") ||
|
strings.HasPrefix(path, "/antigravity/") ||
|
||||||
strings.HasPrefix(path, "/setup/") ||
|
strings.HasPrefix(path, "/setup/") ||
|
||||||
path == "/health" ||
|
path == "/health" ||
|
||||||
|
|||||||
@@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) {
|
|||||||
"/api/v1/users",
|
"/api/v1/users",
|
||||||
"/v1/models",
|
"/v1/models",
|
||||||
"/v1beta/chat",
|
"/v1beta/chat",
|
||||||
|
"/sora/v1/models",
|
||||||
"/antigravity/test",
|
"/antigravity/test",
|
||||||
"/setup/init",
|
"/setup/init",
|
||||||
"/health",
|
"/health",
|
||||||
@@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) {
|
|||||||
"/api/users",
|
"/api/users",
|
||||||
"/v1/models",
|
"/v1/models",
|
||||||
"/v1beta/chat",
|
"/v1beta/chat",
|
||||||
|
"/sora/v1/models",
|
||||||
"/antigravity/test",
|
"/antigravity/test",
|
||||||
"/setup/init",
|
"/setup/init",
|
||||||
"/health",
|
"/health",
|
||||||
|
|||||||
@@ -388,7 +388,11 @@ sora:
|
|||||||
recent_task_limit_max: 200
|
recent_task_limit_max: 200
|
||||||
# Enable debug logs for Sora upstream requests
|
# Enable debug logs for Sora upstream requests
|
||||||
# 启用 Sora 直连调试日志
|
# 启用 Sora 直连调试日志
|
||||||
|
# 调试日志会输出上游请求尝试、重试、响应摘要;Authorization/openai-sentinel-token 等敏感头会自动脱敏
|
||||||
debug: false
|
debug: false
|
||||||
|
# Allow Sora client to fetch token via OpenAI token provider
|
||||||
|
# 是否允许 Sora 客户端通过 OpenAI token provider 取 token(默认 false,避免误走 OpenAI 刷新链路)
|
||||||
|
use_openai_token_provider: false
|
||||||
# Optional custom headers (key-value)
|
# Optional custom headers (key-value)
|
||||||
# 额外请求头(键值对)
|
# 额外请求头(键值对)
|
||||||
headers: {}
|
headers: {}
|
||||||
@@ -431,6 +435,13 @@ sora:
|
|||||||
# Cron 调度表达式
|
# Cron 调度表达式
|
||||||
schedule: "0 3 * * *"
|
schedule: "0 3 * * *"
|
||||||
|
|
||||||
|
# Token refresh behavior
|
||||||
|
# token 刷新行为控制
|
||||||
|
token_refresh:
|
||||||
|
# Whether OpenAI refresh flow is allowed to sync linked Sora accounts
|
||||||
|
# 是否允许 OpenAI 刷新流程同步覆盖 linked_openai_account_id 关联的 Sora 账号 token
|
||||||
|
sync_linked_sora_accounts: false
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# API Key Auth Cache Configuration
|
# API Key Auth Cache Configuration
|
||||||
# API Key 认证缓存配置
|
# API Key 认证缓存配置
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ export async function generateAuthUrl(
|
|||||||
*/
|
*/
|
||||||
export async function exchangeCode(
|
export async function exchangeCode(
|
||||||
endpoint: string,
|
endpoint: string,
|
||||||
exchangeData: { session_id: string; code: string; proxy_id?: number }
|
exchangeData: { session_id: string; code: string; state?: string; proxy_id?: number }
|
||||||
): Promise<Record<string, unknown>> {
|
): Promise<Record<string, unknown>> {
|
||||||
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData)
|
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, exchangeData)
|
||||||
return data
|
return data
|
||||||
@@ -442,7 +442,8 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
|
|||||||
*/
|
*/
|
||||||
export async function refreshOpenAIToken(
|
export async function refreshOpenAIToken(
|
||||||
refreshToken: string,
|
refreshToken: string,
|
||||||
proxyId?: number | null
|
proxyId?: number | null,
|
||||||
|
endpoint: string = '/admin/openai/refresh-token'
|
||||||
): Promise<Record<string, unknown>> {
|
): Promise<Record<string, unknown>> {
|
||||||
const payload: { refresh_token: string; proxy_id?: number } = {
|
const payload: { refresh_token: string; proxy_id?: number } = {
|
||||||
refresh_token: refreshToken
|
refresh_token: refreshToken
|
||||||
@@ -450,7 +451,29 @@ export async function refreshOpenAIToken(
|
|||||||
if (proxyId) {
|
if (proxyId) {
|
||||||
payload.proxy_id = proxyId
|
payload.proxy_id = proxyId
|
||||||
}
|
}
|
||||||
const { data } = await apiClient.post<Record<string, unknown>>('/admin/openai/refresh-token', payload)
|
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate Sora session token and exchange to access token
|
||||||
|
* @param sessionToken - Sora session token
|
||||||
|
* @param proxyId - Optional proxy ID
|
||||||
|
* @param endpoint - API endpoint path
|
||||||
|
* @returns Token information including access_token
|
||||||
|
*/
|
||||||
|
export async function validateSoraSessionToken(
|
||||||
|
sessionToken: string,
|
||||||
|
proxyId?: number | null,
|
||||||
|
endpoint: string = '/admin/sora/st2at'
|
||||||
|
): Promise<Record<string, unknown>> {
|
||||||
|
const payload: { session_token: string; proxy_id?: number } = {
|
||||||
|
session_token: sessionToken
|
||||||
|
}
|
||||||
|
if (proxyId) {
|
||||||
|
payload.proxy_id = proxyId
|
||||||
|
}
|
||||||
|
const { data } = await apiClient.post<Record<string, unknown>>(endpoint, payload)
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,6 +498,7 @@ export const accountsAPI = {
|
|||||||
generateAuthUrl,
|
generateAuthUrl,
|
||||||
exchangeCode,
|
exchangeCode,
|
||||||
refreshOpenAIToken,
|
refreshOpenAIToken,
|
||||||
|
validateSoraSessionToken,
|
||||||
batchCreate,
|
batchCreate,
|
||||||
batchUpdateCredentials,
|
batchUpdateCredentials,
|
||||||
bulkUpdate,
|
bulkUpdate,
|
||||||
|
|||||||
@@ -109,6 +109,28 @@
|
|||||||
</svg>
|
</svg>
|
||||||
OpenAI
|
OpenAI
|
||||||
</button>
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="form.platform = 'sora'"
|
||||||
|
:class="[
|
||||||
|
'flex flex-1 items-center justify-center gap-2 rounded-md px-4 py-2.5 text-sm font-medium transition-all',
|
||||||
|
form.platform === 'sora'
|
||||||
|
? 'bg-white text-rose-600 shadow-sm dark:bg-dark-600 dark:text-rose-400'
|
||||||
|
: 'text-gray-600 hover:text-gray-900 dark:text-gray-400 dark:hover:text-gray-200'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
class="h-4 w-4"
|
||||||
|
fill="none"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="2"
|
||||||
|
>
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M14.752 11.168l-3.197-2.132A1 1 0 0010 9.87v4.263a1 1 0 001.555.832l3.197-2.132a1 1 0 000-1.664z" />
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||||
|
</svg>
|
||||||
|
Sora
|
||||||
|
</button>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
@click="form.platform = 'gemini'"
|
@click="form.platform = 'gemini'"
|
||||||
@@ -150,6 +172,38 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Account Type Selection (Sora) -->
|
||||||
|
<div v-if="form.platform === 'sora'">
|
||||||
|
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||||
|
<div class="mt-2 grid grid-cols-1 gap-3" data-tour="account-form-type">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="accountCategory = 'oauth-based'"
|
||||||
|
:class="[
|
||||||
|
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||||
|
accountCategory === 'oauth-based'
|
||||||
|
? 'border-rose-500 bg-rose-50 dark:bg-rose-900/20'
|
||||||
|
: 'border-gray-200 hover:border-rose-300 dark:border-dark-600 dark:hover:border-rose-700'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
:class="[
|
||||||
|
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||||
|
accountCategory === 'oauth-based'
|
||||||
|
? 'bg-rose-500 text-white'
|
||||||
|
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<Icon name="key" size="sm" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span class="block text-sm font-medium text-gray-900 dark:text-white">OAuth</span>
|
||||||
|
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.chatgptOauth') }}</span>
|
||||||
|
</div>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Account Type Selection (Anthropic) -->
|
<!-- Account Type Selection (Anthropic) -->
|
||||||
<div v-if="form.platform === 'anthropic'">
|
<div v-if="form.platform === 'anthropic'">
|
||||||
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
<label class="input-label">{{ t('admin.accounts.accountType') }}</label>
|
||||||
@@ -1747,32 +1801,6 @@
|
|||||||
|
|
||||||
<!-- Step 2: OAuth Authorization -->
|
<!-- Step 2: OAuth Authorization -->
|
||||||
<div v-else class="space-y-5">
|
<div v-else class="space-y-5">
|
||||||
<!-- 同时启用 Sora 开关 (仅 OpenAI OAuth) -->
|
|
||||||
<div v-if="form.platform === 'openai' && accountCategory === 'oauth-based'" class="mb-4">
|
|
||||||
<label class="flex items-center justify-between rounded-lg border border-gray-200 p-3 dark:border-dark-600">
|
|
||||||
<div class="flex items-center gap-3">
|
|
||||||
<div class="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-rose-100 text-rose-600 dark:bg-rose-900/30 dark:text-rose-400">
|
|
||||||
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
|
||||||
<path stroke-linecap="round" stroke-linejoin="round" d="M14.752 11.168l-3.197-2.132A1 1 0 0010 9.87v4.263a1 1 0 001.555.832l3.197-2.132a1 1 0 000-1.664z" />
|
|
||||||
<path stroke-linecap="round" stroke-linejoin="round" d="M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
|
||||||
</svg>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">
|
|
||||||
{{ t('admin.accounts.openai.enableSora') }}
|
|
||||||
</span>
|
|
||||||
<span class="text-xs text-gray-500 dark:text-gray-400">
|
|
||||||
{{ t('admin.accounts.openai.enableSoraHint') }}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<label :class="['switch', { 'switch-active': enableSoraOnOpenAIOAuth }]">
|
|
||||||
<input type="checkbox" v-model="enableSoraOnOpenAIOAuth" class="sr-only" />
|
|
||||||
<span class="switch-thumb"></span>
|
|
||||||
</label>
|
|
||||||
</label>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<OAuthAuthorizationFlow
|
<OAuthAuthorizationFlow
|
||||||
ref="oauthFlowRef"
|
ref="oauthFlowRef"
|
||||||
:add-method="form.platform === 'anthropic' ? addMethod : 'oauth'"
|
:add-method="form.platform === 'anthropic' ? addMethod : 'oauth'"
|
||||||
@@ -1781,15 +1809,17 @@
|
|||||||
:loading="currentOAuthLoading"
|
:loading="currentOAuthLoading"
|
||||||
:error="currentOAuthError"
|
:error="currentOAuthError"
|
||||||
:show-help="form.platform === 'anthropic'"
|
:show-help="form.platform === 'anthropic'"
|
||||||
:show-proxy-warning="form.platform !== 'openai' && !!form.proxy_id"
|
:show-proxy-warning="form.platform !== 'openai' && form.platform !== 'sora' && !!form.proxy_id"
|
||||||
:allow-multiple="form.platform === 'anthropic'"
|
:allow-multiple="form.platform === 'anthropic'"
|
||||||
:show-cookie-option="form.platform === 'anthropic'"
|
:show-cookie-option="form.platform === 'anthropic'"
|
||||||
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'antigravity'"
|
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'"
|
||||||
|
:show-session-token-option="form.platform === 'sora'"
|
||||||
:platform="form.platform"
|
:platform="form.platform"
|
||||||
:show-project-id="geminiOAuthType === 'code_assist'"
|
:show-project-id="geminiOAuthType === 'code_assist'"
|
||||||
@generate-url="handleGenerateUrl"
|
@generate-url="handleGenerateUrl"
|
||||||
@cookie-auth="handleCookieAuth"
|
@cookie-auth="handleCookieAuth"
|
||||||
@validate-refresh-token="handleValidateRefreshToken"
|
@validate-refresh-token="handleValidateRefreshToken"
|
||||||
|
@validate-session-token="handleValidateSessionToken"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
@@ -2148,6 +2178,7 @@ interface OAuthFlowExposed {
|
|||||||
projectId: string
|
projectId: string
|
||||||
sessionKey: string
|
sessionKey: string
|
||||||
refreshToken: string
|
refreshToken: string
|
||||||
|
sessionToken: string
|
||||||
inputMethod: AuthInputMethod
|
inputMethod: AuthInputMethod
|
||||||
reset: () => void
|
reset: () => void
|
||||||
}
|
}
|
||||||
@@ -2156,7 +2187,7 @@ const { t } = useI18n()
|
|||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
|
|
||||||
const oauthStepTitle = computed(() => {
|
const oauthStepTitle = computed(() => {
|
||||||
if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title')
|
if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.oauth.openai.title')
|
||||||
if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title')
|
if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title')
|
||||||
if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title')
|
if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title')
|
||||||
return t('admin.accounts.oauth.title')
|
return t('admin.accounts.oauth.title')
|
||||||
@@ -2164,13 +2195,13 @@ const oauthStepTitle = computed(() => {
|
|||||||
|
|
||||||
// Platform-specific hints for API Key type
|
// Platform-specific hints for API Key type
|
||||||
const baseUrlHint = computed(() => {
|
const baseUrlHint = computed(() => {
|
||||||
if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint')
|
if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.baseUrlHint')
|
||||||
if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
|
if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint')
|
||||||
return t('admin.accounts.baseUrlHint')
|
return t('admin.accounts.baseUrlHint')
|
||||||
})
|
})
|
||||||
|
|
||||||
const apiKeyHint = computed(() => {
|
const apiKeyHint = computed(() => {
|
||||||
if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint')
|
if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.apiKeyHint')
|
||||||
if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint')
|
if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint')
|
||||||
return t('admin.accounts.apiKeyHint')
|
return t('admin.accounts.apiKeyHint')
|
||||||
})
|
})
|
||||||
@@ -2191,34 +2222,36 @@ const appStore = useAppStore()
|
|||||||
|
|
||||||
// OAuth composables
|
// OAuth composables
|
||||||
const oauth = useAccountOAuth() // For Anthropic OAuth
|
const oauth = useAccountOAuth() // For Anthropic OAuth
|
||||||
const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth
|
const openaiOAuth = useOpenAIOAuth({ platform: 'openai' }) // For OpenAI OAuth
|
||||||
|
const soraOAuth = useOpenAIOAuth({ platform: 'sora' }) // For Sora OAuth
|
||||||
const geminiOAuth = useGeminiOAuth() // For Gemini OAuth
|
const geminiOAuth = useGeminiOAuth() // For Gemini OAuth
|
||||||
const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth
|
const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth
|
||||||
|
const activeOpenAIOAuth = computed(() => (form.platform === 'sora' ? soraOAuth : openaiOAuth))
|
||||||
|
|
||||||
// Computed: current OAuth state for template binding
|
// Computed: current OAuth state for template binding
|
||||||
const currentAuthUrl = computed(() => {
|
const currentAuthUrl = computed(() => {
|
||||||
if (form.platform === 'openai') return openaiOAuth.authUrl.value
|
if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.authUrl.value
|
||||||
if (form.platform === 'gemini') return geminiOAuth.authUrl.value
|
if (form.platform === 'gemini') return geminiOAuth.authUrl.value
|
||||||
if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value
|
if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value
|
||||||
return oauth.authUrl.value
|
return oauth.authUrl.value
|
||||||
})
|
})
|
||||||
|
|
||||||
const currentSessionId = computed(() => {
|
const currentSessionId = computed(() => {
|
||||||
if (form.platform === 'openai') return openaiOAuth.sessionId.value
|
if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.sessionId.value
|
||||||
if (form.platform === 'gemini') return geminiOAuth.sessionId.value
|
if (form.platform === 'gemini') return geminiOAuth.sessionId.value
|
||||||
if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value
|
if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value
|
||||||
return oauth.sessionId.value
|
return oauth.sessionId.value
|
||||||
})
|
})
|
||||||
|
|
||||||
const currentOAuthLoading = computed(() => {
|
const currentOAuthLoading = computed(() => {
|
||||||
if (form.platform === 'openai') return openaiOAuth.loading.value
|
if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.loading.value
|
||||||
if (form.platform === 'gemini') return geminiOAuth.loading.value
|
if (form.platform === 'gemini') return geminiOAuth.loading.value
|
||||||
if (form.platform === 'antigravity') return antigravityOAuth.loading.value
|
if (form.platform === 'antigravity') return antigravityOAuth.loading.value
|
||||||
return oauth.loading.value
|
return oauth.loading.value
|
||||||
})
|
})
|
||||||
|
|
||||||
const currentOAuthError = computed(() => {
|
const currentOAuthError = computed(() => {
|
||||||
if (form.platform === 'openai') return openaiOAuth.error.value
|
if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.error.value
|
||||||
if (form.platform === 'gemini') return geminiOAuth.error.value
|
if (form.platform === 'gemini') return geminiOAuth.error.value
|
||||||
if (form.platform === 'antigravity') return antigravityOAuth.error.value
|
if (form.platform === 'antigravity') return antigravityOAuth.error.value
|
||||||
return oauth.error.value
|
return oauth.error.value
|
||||||
@@ -2257,7 +2290,6 @@ const interceptWarmupRequests = ref(false)
|
|||||||
const autoPauseOnExpired = ref(true)
|
const autoPauseOnExpired = ref(true)
|
||||||
const openaiPassthroughEnabled = ref(false)
|
const openaiPassthroughEnabled = ref(false)
|
||||||
const codexCLIOnlyEnabled = ref(false)
|
const codexCLIOnlyEnabled = ref(false)
|
||||||
const enableSoraOnOpenAIOAuth = ref(false) // OpenAI OAuth 时同时启用 Sora
|
|
||||||
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
||||||
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
|
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
|
||||||
const upstreamBaseUrl = ref('') // For upstream type: base URL
|
const upstreamBaseUrl = ref('') // For upstream type: base URL
|
||||||
@@ -2398,8 +2430,8 @@ const expiresAtInput = computed({
|
|||||||
|
|
||||||
const canExchangeCode = computed(() => {
|
const canExchangeCode = computed(() => {
|
||||||
const authCode = oauthFlowRef.value?.authCode || ''
|
const authCode = oauthFlowRef.value?.authCode || ''
|
||||||
if (form.platform === 'openai') {
|
if (form.platform === 'openai' || form.platform === 'sora') {
|
||||||
return authCode.trim() && openaiOAuth.sessionId.value && !openaiOAuth.loading.value
|
return authCode.trim() && activeOpenAIOAuth.value.sessionId.value && !activeOpenAIOAuth.value.loading.value
|
||||||
}
|
}
|
||||||
if (form.platform === 'gemini') {
|
if (form.platform === 'gemini') {
|
||||||
return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value
|
return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value
|
||||||
@@ -2459,7 +2491,7 @@ watch(
|
|||||||
(newPlatform) => {
|
(newPlatform) => {
|
||||||
// Reset base URL based on platform
|
// Reset base URL based on platform
|
||||||
apiKeyBaseUrl.value =
|
apiKeyBaseUrl.value =
|
||||||
newPlatform === 'openai'
|
(newPlatform === 'openai' || newPlatform === 'sora')
|
||||||
? 'https://api.openai.com'
|
? 'https://api.openai.com'
|
||||||
: newPlatform === 'gemini'
|
: newPlatform === 'gemini'
|
||||||
? 'https://generativelanguage.googleapis.com'
|
? 'https://generativelanguage.googleapis.com'
|
||||||
@@ -2485,6 +2517,11 @@ watch(
|
|||||||
if (newPlatform !== 'anthropic') {
|
if (newPlatform !== 'anthropic') {
|
||||||
interceptWarmupRequests.value = false
|
interceptWarmupRequests.value = false
|
||||||
}
|
}
|
||||||
|
if (newPlatform === 'sora') {
|
||||||
|
accountCategory.value = 'oauth-based'
|
||||||
|
addMethod.value = 'oauth'
|
||||||
|
form.type = 'oauth'
|
||||||
|
}
|
||||||
if (newPlatform !== 'openai') {
|
if (newPlatform !== 'openai') {
|
||||||
openaiPassthroughEnabled.value = false
|
openaiPassthroughEnabled.value = false
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
@@ -2492,6 +2529,7 @@ watch(
|
|||||||
// Reset OAuth states
|
// Reset OAuth states
|
||||||
oauth.resetState()
|
oauth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
}
|
}
|
||||||
@@ -2753,7 +2791,6 @@ const resetForm = () => {
|
|||||||
autoPauseOnExpired.value = true
|
autoPauseOnExpired.value = true
|
||||||
openaiPassthroughEnabled.value = false
|
openaiPassthroughEnabled.value = false
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
enableSoraOnOpenAIOAuth.value = false
|
|
||||||
// Reset quota control state
|
// Reset quota control state
|
||||||
windowCostEnabled.value = false
|
windowCostEnabled.value = false
|
||||||
windowCostLimit.value = null
|
windowCostLimit.value = null
|
||||||
@@ -2776,6 +2813,7 @@ const resetForm = () => {
|
|||||||
geminiTierAIStudio.value = 'aistudio_free'
|
geminiTierAIStudio.value = 'aistudio_free'
|
||||||
oauth.resetState()
|
oauth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
@@ -2807,6 +2845,23 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
|
|||||||
return Object.keys(extra).length > 0 ? extra : undefined
|
return Object.keys(extra).length > 0 ? extra : undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const buildSoraExtra = (
|
||||||
|
base?: Record<string, unknown>,
|
||||||
|
linkedOpenAIAccountId?: string | number
|
||||||
|
): Record<string, unknown> | undefined => {
|
||||||
|
const extra: Record<string, unknown> = { ...(base || {}) }
|
||||||
|
if (linkedOpenAIAccountId !== undefined && linkedOpenAIAccountId !== null) {
|
||||||
|
const id = String(linkedOpenAIAccountId).trim()
|
||||||
|
if (id) {
|
||||||
|
extra.linked_openai_account_id = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete extra.openai_passthrough
|
||||||
|
delete extra.openai_oauth_passthrough
|
||||||
|
delete extra.codex_cli_only
|
||||||
|
return Object.keys(extra).length > 0 ? extra : undefined
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to create account with mixed channel warning handling
|
// Helper function to create account with mixed channel warning handling
|
||||||
const doCreateAccount = async (payload: any) => {
|
const doCreateAccount = async (payload: any) => {
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
@@ -2922,7 +2977,7 @@ const handleSubmit = async () => {
|
|||||||
|
|
||||||
// Determine default base URL based on platform
|
// Determine default base URL based on platform
|
||||||
const defaultBaseUrl =
|
const defaultBaseUrl =
|
||||||
form.platform === 'openai'
|
(form.platform === 'openai' || form.platform === 'sora')
|
||||||
? 'https://api.openai.com'
|
? 'https://api.openai.com'
|
||||||
: form.platform === 'gemini'
|
: form.platform === 'gemini'
|
||||||
? 'https://generativelanguage.googleapis.com'
|
? 'https://generativelanguage.googleapis.com'
|
||||||
@@ -2974,14 +3029,15 @@ const goBackToBasicInfo = () => {
|
|||||||
step.value = 1
|
step.value = 1
|
||||||
oauth.resetState()
|
oauth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleGenerateUrl = async () => {
|
const handleGenerateUrl = async () => {
|
||||||
if (form.platform === 'openai') {
|
if (form.platform === 'openai' || form.platform === 'sora') {
|
||||||
await openaiOAuth.generateAuthUrl(form.proxy_id)
|
await activeOpenAIOAuth.value.generateAuthUrl(form.proxy_id)
|
||||||
} else if (form.platform === 'gemini') {
|
} else if (form.platform === 'gemini') {
|
||||||
await geminiOAuth.generateAuthUrl(
|
await geminiOAuth.generateAuthUrl(
|
||||||
form.proxy_id,
|
form.proxy_id,
|
||||||
@@ -2997,13 +3053,19 @@ const handleGenerateUrl = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handleValidateRefreshToken = (rt: string) => {
|
const handleValidateRefreshToken = (rt: string) => {
|
||||||
if (form.platform === 'openai') {
|
if (form.platform === 'openai' || form.platform === 'sora') {
|
||||||
handleOpenAIValidateRT(rt)
|
handleOpenAIValidateRT(rt)
|
||||||
} else if (form.platform === 'antigravity') {
|
} else if (form.platform === 'antigravity') {
|
||||||
handleAntigravityValidateRT(rt)
|
handleAntigravityValidateRT(rt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleValidateSessionToken = (sessionToken: string) => {
|
||||||
|
if (form.platform === 'sora') {
|
||||||
|
handleSoraValidateST(sessionToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const formatDateTimeLocal = formatDateTimeLocalInput
|
const formatDateTimeLocal = formatDateTimeLocalInput
|
||||||
const parseDateTimeLocal = parseDateTimeLocalInput
|
const parseDateTimeLocal = parseDateTimeLocalInput
|
||||||
|
|
||||||
@@ -3039,29 +3101,42 @@ const createAccountAndFinish = async (
|
|||||||
|
|
||||||
// OpenAI OAuth 授权码兑换
|
// OpenAI OAuth 授权码兑换
|
||||||
const handleOpenAIExchange = async (authCode: string) => {
|
const handleOpenAIExchange = async (authCode: string) => {
|
||||||
if (!authCode.trim() || !openaiOAuth.sessionId.value) return
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
|
if (!authCode.trim() || !oauthClient.sessionId.value) return
|
||||||
|
|
||||||
openaiOAuth.loading.value = true
|
oauthClient.loading.value = true
|
||||||
openaiOAuth.error.value = ''
|
oauthClient.error.value = ''
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const tokenInfo = await openaiOAuth.exchangeAuthCode(
|
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
|
||||||
|
if (!stateToUse) {
|
||||||
|
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
|
||||||
|
appStore.showError(oauthClient.error.value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const tokenInfo = await oauthClient.exchangeAuthCode(
|
||||||
authCode.trim(),
|
authCode.trim(),
|
||||||
openaiOAuth.sessionId.value,
|
oauthClient.sessionId.value,
|
||||||
|
stateToUse,
|
||||||
form.proxy_id
|
form.proxy_id
|
||||||
)
|
)
|
||||||
if (!tokenInfo) return
|
if (!tokenInfo) return
|
||||||
|
|
||||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
||||||
const extra = buildOpenAIExtra(oauthExtra)
|
const extra = buildOpenAIExtra(oauthExtra)
|
||||||
|
const shouldCreateOpenAI = form.platform === 'openai'
|
||||||
|
const shouldCreateSora = form.platform === 'sora'
|
||||||
|
|
||||||
// 应用临时不可调度配置
|
// 应用临时不可调度配置
|
||||||
if (!applyTempUnschedConfig(credentials)) {
|
if (!applyTempUnschedConfig(credentials)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 创建 OpenAI 账号
|
let openaiAccountId: string | number | undefined
|
||||||
|
|
||||||
|
if (shouldCreateOpenAI) {
|
||||||
const openaiAccount = await adminAPI.accounts.create({
|
const openaiAccount = await adminAPI.accounts.create({
|
||||||
name: form.name,
|
name: form.name,
|
||||||
notes: form.notes,
|
notes: form.notes,
|
||||||
@@ -3077,29 +3152,21 @@ const handleOpenAIExchange = async (authCode: string) => {
|
|||||||
expires_at: form.expires_at,
|
expires_at: form.expires_at,
|
||||||
auto_pause_on_expired: autoPauseOnExpired.value
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
})
|
})
|
||||||
|
openaiAccountId = openaiAccount.id
|
||||||
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||||
|
}
|
||||||
|
|
||||||
// 2. 如果启用了 Sora,同时创建 Sora 账号
|
if (shouldCreateSora) {
|
||||||
if (enableSoraOnOpenAIOAuth.value) {
|
|
||||||
try {
|
|
||||||
// Sora 使用相同的 OAuth credentials
|
|
||||||
const soraCredentials = {
|
const soraCredentials = {
|
||||||
access_token: credentials.access_token,
|
access_token: credentials.access_token,
|
||||||
refresh_token: credentials.refresh_token,
|
refresh_token: credentials.refresh_token,
|
||||||
expires_at: credentials.expires_at
|
expires_at: credentials.expires_at
|
||||||
}
|
}
|
||||||
|
|
||||||
// 建立关联关系
|
const soraName = shouldCreateOpenAI ? `${form.name} (Sora)` : form.name
|
||||||
const soraExtra: Record<string, unknown> = {
|
const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId)
|
||||||
...(extra || {}),
|
|
||||||
linked_openai_account_id: String(openaiAccount.id)
|
|
||||||
}
|
|
||||||
delete soraExtra.openai_passthrough
|
|
||||||
delete soraExtra.openai_oauth_passthrough
|
|
||||||
|
|
||||||
await adminAPI.accounts.create({
|
await adminAPI.accounts.create({
|
||||||
name: `${form.name} (Sora)`,
|
name: soraName,
|
||||||
notes: form.notes,
|
notes: form.notes,
|
||||||
platform: 'sora',
|
platform: 'sora',
|
||||||
type: 'oauth',
|
type: 'oauth',
|
||||||
@@ -3113,26 +3180,22 @@ const handleOpenAIExchange = async (authCode: string) => {
|
|||||||
expires_at: form.expires_at,
|
expires_at: form.expires_at,
|
||||||
auto_pause_on_expired: autoPauseOnExpired.value
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
})
|
})
|
||||||
|
appStore.showSuccess(t('admin.accounts.accountCreated'))
|
||||||
appStore.showSuccess(t('admin.accounts.soraAccountCreated'))
|
|
||||||
} catch (error: any) {
|
|
||||||
console.error('创建 Sora 账号失败:', error)
|
|
||||||
appStore.showWarning(t('admin.accounts.soraAccountFailed'))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
emit('created')
|
emit('created')
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
appStore.showError(openaiOAuth.error.value)
|
appStore.showError(oauthClient.error.value)
|
||||||
} finally {
|
} finally {
|
||||||
openaiOAuth.loading.value = false
|
oauthClient.loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI 手动 RT 批量验证和创建
|
// OpenAI 手动 RT 批量验证和创建
|
||||||
const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
||||||
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
if (!refreshTokenInput.trim()) return
|
if (!refreshTokenInput.trim()) return
|
||||||
|
|
||||||
// Parse multiple refresh tokens (one per line)
|
// Parse multiple refresh tokens (one per line)
|
||||||
@@ -3142,39 +3205,44 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
|||||||
.filter((rt) => rt)
|
.filter((rt) => rt)
|
||||||
|
|
||||||
if (refreshTokens.length === 0) {
|
if (refreshTokens.length === 0) {
|
||||||
openaiOAuth.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken')
|
oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiOAuth.loading.value = true
|
oauthClient.loading.value = true
|
||||||
openaiOAuth.error.value = ''
|
oauthClient.error.value = ''
|
||||||
|
|
||||||
let successCount = 0
|
let successCount = 0
|
||||||
let failedCount = 0
|
let failedCount = 0
|
||||||
const errors: string[] = []
|
const errors: string[] = []
|
||||||
|
const shouldCreateOpenAI = form.platform === 'openai'
|
||||||
|
const shouldCreateSora = form.platform === 'sora'
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for (let i = 0; i < refreshTokens.length; i++) {
|
for (let i = 0; i < refreshTokens.length; i++) {
|
||||||
try {
|
try {
|
||||||
const tokenInfo = await openaiOAuth.validateRefreshToken(
|
const tokenInfo = await oauthClient.validateRefreshToken(
|
||||||
refreshTokens[i],
|
refreshTokens[i],
|
||||||
form.proxy_id
|
form.proxy_id
|
||||||
)
|
)
|
||||||
if (!tokenInfo) {
|
if (!tokenInfo) {
|
||||||
failedCount++
|
failedCount++
|
||||||
errors.push(`#${i + 1}: ${openaiOAuth.error.value || 'Validation failed'}`)
|
errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`)
|
||||||
openaiOAuth.error.value = ''
|
oauthClient.error.value = ''
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
||||||
const extra = buildOpenAIExtra(oauthExtra)
|
const extra = buildOpenAIExtra(oauthExtra)
|
||||||
|
|
||||||
// Generate account name with index for batch
|
// Generate account name with index for batch
|
||||||
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||||
|
|
||||||
await adminAPI.accounts.create({
|
let openaiAccountId: string | number | undefined
|
||||||
|
|
||||||
|
if (shouldCreateOpenAI) {
|
||||||
|
const openaiAccount = await adminAPI.accounts.create({
|
||||||
name: accountName,
|
name: accountName,
|
||||||
notes: form.notes,
|
notes: form.notes,
|
||||||
platform: 'openai',
|
platform: 'openai',
|
||||||
@@ -3189,6 +3257,34 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
|||||||
expires_at: form.expires_at,
|
expires_at: form.expires_at,
|
||||||
auto_pause_on_expired: autoPauseOnExpired.value
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
})
|
})
|
||||||
|
openaiAccountId = openaiAccount.id
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldCreateSora) {
|
||||||
|
const soraCredentials = {
|
||||||
|
access_token: credentials.access_token,
|
||||||
|
refresh_token: credentials.refresh_token,
|
||||||
|
expires_at: credentials.expires_at
|
||||||
|
}
|
||||||
|
const soraName = shouldCreateOpenAI ? `${accountName} (Sora)` : accountName
|
||||||
|
const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId)
|
||||||
|
await adminAPI.accounts.create({
|
||||||
|
name: soraName,
|
||||||
|
notes: form.notes,
|
||||||
|
platform: 'sora',
|
||||||
|
type: 'oauth',
|
||||||
|
credentials: soraCredentials,
|
||||||
|
extra: soraExtra,
|
||||||
|
proxy_id: form.proxy_id,
|
||||||
|
concurrency: form.concurrency,
|
||||||
|
priority: form.priority,
|
||||||
|
rate_multiplier: form.rate_multiplier,
|
||||||
|
group_ids: form.group_ids,
|
||||||
|
expires_at: form.expires_at,
|
||||||
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
successCount++
|
successCount++
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
failedCount++
|
failedCount++
|
||||||
@@ -3210,14 +3306,99 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
|||||||
appStore.showWarning(
|
appStore.showWarning(
|
||||||
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
|
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
|
||||||
)
|
)
|
||||||
openaiOAuth.error.value = errors.join('\n')
|
oauthClient.error.value = errors.join('\n')
|
||||||
emit('created')
|
emit('created')
|
||||||
} else {
|
} else {
|
||||||
openaiOAuth.error.value = errors.join('\n')
|
oauthClient.error.value = errors.join('\n')
|
||||||
appStore.showError(t('admin.accounts.oauth.batchFailed'))
|
appStore.showError(t('admin.accounts.oauth.batchFailed'))
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
openaiOAuth.loading.value = false
|
oauthClient.loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sora 手动 ST 批量验证和创建
|
||||||
|
const handleSoraValidateST = async (sessionTokenInput: string) => {
|
||||||
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
|
if (!sessionTokenInput.trim()) return
|
||||||
|
|
||||||
|
const sessionTokens = sessionTokenInput
|
||||||
|
.split('\n')
|
||||||
|
.map((st) => st.trim())
|
||||||
|
.filter((st) => st)
|
||||||
|
|
||||||
|
if (sessionTokens.length === 0) {
|
||||||
|
oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterSessionToken')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthClient.loading.value = true
|
||||||
|
oauthClient.error.value = ''
|
||||||
|
|
||||||
|
let successCount = 0
|
||||||
|
let failedCount = 0
|
||||||
|
const errors: string[] = []
|
||||||
|
|
||||||
|
try {
|
||||||
|
for (let i = 0; i < sessionTokens.length; i++) {
|
||||||
|
try {
|
||||||
|
const tokenInfo = await oauthClient.validateSessionToken(sessionTokens[i], form.proxy_id)
|
||||||
|
if (!tokenInfo) {
|
||||||
|
failedCount++
|
||||||
|
errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`)
|
||||||
|
oauthClient.error.value = ''
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
|
credentials.session_token = sessionTokens[i]
|
||||||
|
const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record<string, unknown> | undefined
|
||||||
|
const soraExtra = buildSoraExtra(oauthExtra)
|
||||||
|
|
||||||
|
const accountName = sessionTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||||
|
await adminAPI.accounts.create({
|
||||||
|
name: accountName,
|
||||||
|
notes: form.notes,
|
||||||
|
platform: 'sora',
|
||||||
|
type: 'oauth',
|
||||||
|
credentials,
|
||||||
|
extra: soraExtra,
|
||||||
|
proxy_id: form.proxy_id,
|
||||||
|
concurrency: form.concurrency,
|
||||||
|
priority: form.priority,
|
||||||
|
rate_multiplier: form.rate_multiplier,
|
||||||
|
group_ids: form.group_ids,
|
||||||
|
expires_at: form.expires_at,
|
||||||
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
|
})
|
||||||
|
successCount++
|
||||||
|
} catch (error: any) {
|
||||||
|
failedCount++
|
||||||
|
const errMsg = error.response?.data?.detail || error.message || 'Unknown error'
|
||||||
|
errors.push(`#${i + 1}: ${errMsg}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (successCount > 0 && failedCount === 0) {
|
||||||
|
appStore.showSuccess(
|
||||||
|
sessionTokens.length > 1
|
||||||
|
? t('admin.accounts.oauth.batchSuccess', { count: successCount })
|
||||||
|
: t('admin.accounts.accountCreated')
|
||||||
|
)
|
||||||
|
emit('created')
|
||||||
|
handleClose()
|
||||||
|
} else if (successCount > 0 && failedCount > 0) {
|
||||||
|
appStore.showWarning(
|
||||||
|
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
|
||||||
|
)
|
||||||
|
oauthClient.error.value = errors.join('\n')
|
||||||
|
emit('created')
|
||||||
|
} else {
|
||||||
|
oauthClient.error.value = errors.join('\n')
|
||||||
|
appStore.showError(t('admin.accounts.oauth.batchFailed'))
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
oauthClient.loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3462,6 +3643,7 @@ const handleExchangeCode = async () => {
|
|||||||
|
|
||||||
switch (form.platform) {
|
switch (form.platform) {
|
||||||
case 'openai':
|
case 'openai':
|
||||||
|
case 'sora':
|
||||||
return handleOpenAIExchange(authCode)
|
return handleOpenAIExchange(authCode)
|
||||||
case 'gemini':
|
case 'gemini':
|
||||||
return handleGeminiExchange(authCode)
|
return handleGeminiExchange(authCode)
|
||||||
|
|||||||
@@ -48,6 +48,17 @@
|
|||||||
t(getOAuthKey('refreshTokenAuth'))
|
t(getOAuthKey('refreshTokenAuth'))
|
||||||
}}</span>
|
}}</span>
|
||||||
</label>
|
</label>
|
||||||
|
<label v-if="showSessionTokenOption" class="flex cursor-pointer items-center gap-2">
|
||||||
|
<input
|
||||||
|
v-model="inputMethod"
|
||||||
|
type="radio"
|
||||||
|
value="session_token"
|
||||||
|
class="text-blue-600 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
<span class="text-sm text-blue-900 dark:text-blue-200">{{
|
||||||
|
t(getOAuthKey('sessionTokenAuth'))
|
||||||
|
}}</span>
|
||||||
|
</label>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -135,6 +146,87 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Session Token Input (Sora) -->
|
||||||
|
<div v-if="inputMethod === 'session_token'" class="space-y-4">
|
||||||
|
<div
|
||||||
|
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
|
||||||
|
>
|
||||||
|
<p class="mb-3 text-sm text-blue-700 dark:text-blue-300">
|
||||||
|
{{ t(getOAuthKey('sessionTokenDesc')) }}
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div class="mb-4">
|
||||||
|
<label
|
||||||
|
class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
<Icon name="key" size="sm" class="text-blue-500" />
|
||||||
|
Session Token
|
||||||
|
<span
|
||||||
|
v-if="parsedSessionTokenCount > 1"
|
||||||
|
class="rounded-full bg-blue-500 px-2 py-0.5 text-xs text-white"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.oauth.keysCount', { count: parsedSessionTokenCount }) }}
|
||||||
|
</span>
|
||||||
|
</label>
|
||||||
|
<textarea
|
||||||
|
v-model="sessionTokenInput"
|
||||||
|
rows="3"
|
||||||
|
class="input w-full resize-y font-mono text-sm"
|
||||||
|
:placeholder="t(getOAuthKey('sessionTokenPlaceholder'))"
|
||||||
|
></textarea>
|
||||||
|
<p
|
||||||
|
v-if="parsedSessionTokenCount > 1"
|
||||||
|
class="mt-1 text-xs text-blue-600 dark:text-blue-400"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedSessionTokenCount }) }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="error"
|
||||||
|
class="mb-4 rounded-lg border border-red-200 bg-red-50 p-3 dark:border-red-700 dark:bg-red-900/30"
|
||||||
|
>
|
||||||
|
<p class="whitespace-pre-line text-sm text-red-600 dark:text-red-400">
|
||||||
|
{{ error }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn btn-primary w-full"
|
||||||
|
:disabled="loading || !sessionTokenInput.trim()"
|
||||||
|
@click="handleValidateSessionToken"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
v-if="loading"
|
||||||
|
class="-ml-1 mr-2 h-4 w-4 animate-spin"
|
||||||
|
fill="none"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
>
|
||||||
|
<circle
|
||||||
|
class="opacity-25"
|
||||||
|
cx="12"
|
||||||
|
cy="12"
|
||||||
|
r="10"
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="4"
|
||||||
|
></circle>
|
||||||
|
<path
|
||||||
|
class="opacity-75"
|
||||||
|
fill="currentColor"
|
||||||
|
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||||
|
></path>
|
||||||
|
</svg>
|
||||||
|
<Icon v-else name="sparkles" size="sm" class="mr-2" />
|
||||||
|
{{
|
||||||
|
loading
|
||||||
|
? t(getOAuthKey('validating'))
|
||||||
|
: t(getOAuthKey('validateAndCreate'))
|
||||||
|
}}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Cookie Auto-Auth Form -->
|
<!-- Cookie Auto-Auth Form -->
|
||||||
<div v-if="inputMethod === 'cookie'" class="space-y-4">
|
<div v-if="inputMethod === 'cookie'" class="space-y-4">
|
||||||
<div
|
<div
|
||||||
@@ -525,6 +617,7 @@ interface Props {
|
|||||||
methodLabel?: string
|
methodLabel?: string
|
||||||
showCookieOption?: boolean // Whether to show cookie auto-auth option
|
showCookieOption?: boolean // Whether to show cookie auto-auth option
|
||||||
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
|
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
|
||||||
|
showSessionTokenOption?: boolean // Whether to show session token input option (Sora only)
|
||||||
platform?: AccountPlatform // Platform type for different UI/text
|
platform?: AccountPlatform // Platform type for different UI/text
|
||||||
showProjectId?: boolean // New prop to control project ID visibility
|
showProjectId?: boolean // New prop to control project ID visibility
|
||||||
}
|
}
|
||||||
@@ -540,6 +633,7 @@ const props = withDefaults(defineProps<Props>(), {
|
|||||||
methodLabel: 'Authorization Method',
|
methodLabel: 'Authorization Method',
|
||||||
showCookieOption: true,
|
showCookieOption: true,
|
||||||
showRefreshTokenOption: false,
|
showRefreshTokenOption: false,
|
||||||
|
showSessionTokenOption: false,
|
||||||
platform: 'anthropic',
|
platform: 'anthropic',
|
||||||
showProjectId: true
|
showProjectId: true
|
||||||
})
|
})
|
||||||
@@ -549,6 +643,7 @@ const emit = defineEmits<{
|
|||||||
'exchange-code': [code: string]
|
'exchange-code': [code: string]
|
||||||
'cookie-auth': [sessionKey: string]
|
'cookie-auth': [sessionKey: string]
|
||||||
'validate-refresh-token': [refreshToken: string]
|
'validate-refresh-token': [refreshToken: string]
|
||||||
|
'validate-session-token': [sessionToken: string]
|
||||||
'update:inputMethod': [method: AuthInputMethod]
|
'update:inputMethod': [method: AuthInputMethod]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
@@ -587,12 +682,13 @@ const inputMethod = ref<AuthInputMethod>(props.showCookieOption ? 'manual' : 'ma
|
|||||||
const authCodeInput = ref('')
|
const authCodeInput = ref('')
|
||||||
const sessionKeyInput = ref('')
|
const sessionKeyInput = ref('')
|
||||||
const refreshTokenInput = ref('')
|
const refreshTokenInput = ref('')
|
||||||
|
const sessionTokenInput = ref('')
|
||||||
const showHelpDialog = ref(false)
|
const showHelpDialog = ref(false)
|
||||||
const oauthState = ref('')
|
const oauthState = ref('')
|
||||||
const projectId = ref('')
|
const projectId = ref('')
|
||||||
|
|
||||||
// Computed: show method selection when either cookie or refresh token option is enabled
|
// Computed: show method selection when either cookie or refresh token option is enabled
|
||||||
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption)
|
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption)
|
||||||
|
|
||||||
// Clipboard
|
// Clipboard
|
||||||
const { copied, copyToClipboard } = useClipboard()
|
const { copied, copyToClipboard } = useClipboard()
|
||||||
@@ -613,6 +709,13 @@ const parsedRefreshTokenCount = computed(() => {
|
|||||||
.filter((rt) => rt).length
|
.filter((rt) => rt).length
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const parsedSessionTokenCount = computed(() => {
|
||||||
|
return sessionTokenInput.value
|
||||||
|
.split('\n')
|
||||||
|
.map((st) => st.trim())
|
||||||
|
.filter((st) => st).length
|
||||||
|
})
|
||||||
|
|
||||||
// Watchers
|
// Watchers
|
||||||
watch(inputMethod, (newVal) => {
|
watch(inputMethod, (newVal) => {
|
||||||
emit('update:inputMethod', newVal)
|
emit('update:inputMethod', newVal)
|
||||||
@@ -631,7 +734,7 @@ watch(authCodeInput, (newVal) => {
|
|||||||
const url = new URL(trimmed)
|
const url = new URL(trimmed)
|
||||||
const code = url.searchParams.get('code')
|
const code = url.searchParams.get('code')
|
||||||
const stateParam = url.searchParams.get('state')
|
const stateParam = url.searchParams.get('state')
|
||||||
if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) {
|
if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) {
|
||||||
oauthState.value = stateParam
|
oauthState.value = stateParam
|
||||||
}
|
}
|
||||||
if (code && code !== trimmed) {
|
if (code && code !== trimmed) {
|
||||||
@@ -642,7 +745,7 @@ watch(authCodeInput, (newVal) => {
|
|||||||
// If URL parsing fails, try regex extraction
|
// If URL parsing fails, try regex extraction
|
||||||
const match = trimmed.match(/[?&]code=([^&]+)/)
|
const match = trimmed.match(/[?&]code=([^&]+)/)
|
||||||
const stateMatch = trimmed.match(/[?&]state=([^&]+)/)
|
const stateMatch = trimmed.match(/[?&]state=([^&]+)/)
|
||||||
if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) {
|
if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) {
|
||||||
oauthState.value = stateMatch[1]
|
oauthState.value = stateMatch[1]
|
||||||
}
|
}
|
||||||
if (match && match[1] && match[1] !== trimmed) {
|
if (match && match[1] && match[1] !== trimmed) {
|
||||||
@@ -680,6 +783,12 @@ const handleValidateRefreshToken = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleValidateSessionToken = () => {
|
||||||
|
if (sessionTokenInput.value.trim()) {
|
||||||
|
emit('validate-session-token', sessionTokenInput.value.trim())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Expose methods and state
|
// Expose methods and state
|
||||||
defineExpose({
|
defineExpose({
|
||||||
authCode: authCodeInput,
|
authCode: authCodeInput,
|
||||||
@@ -687,6 +796,7 @@ defineExpose({
|
|||||||
projectId,
|
projectId,
|
||||||
sessionKey: sessionKeyInput,
|
sessionKey: sessionKeyInput,
|
||||||
refreshToken: refreshTokenInput,
|
refreshToken: refreshTokenInput,
|
||||||
|
sessionToken: sessionTokenInput,
|
||||||
inputMethod,
|
inputMethod,
|
||||||
reset: () => {
|
reset: () => {
|
||||||
authCodeInput.value = ''
|
authCodeInput.value = ''
|
||||||
@@ -694,6 +804,7 @@ defineExpose({
|
|||||||
projectId.value = ''
|
projectId.value = ''
|
||||||
sessionKeyInput.value = ''
|
sessionKeyInput.value = ''
|
||||||
refreshTokenInput.value = ''
|
refreshTokenInput.value = ''
|
||||||
|
sessionTokenInput.value = ''
|
||||||
inputMethod.value = 'manual'
|
inputMethod.value = 'manual'
|
||||||
showHelpDialog.value = false
|
showHelpDialog.value = false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
<div
|
<div
|
||||||
:class="[
|
:class="[
|
||||||
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
|
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
|
||||||
isOpenAI
|
isOpenAILike
|
||||||
? 'from-green-500 to-green-600'
|
? 'from-green-500 to-green-600'
|
||||||
: isGemini
|
: isGemini
|
||||||
? 'from-blue-500 to-blue-600'
|
? 'from-blue-500 to-blue-600'
|
||||||
@@ -33,6 +33,8 @@
|
|||||||
{{
|
{{
|
||||||
isOpenAI
|
isOpenAI
|
||||||
? t('admin.accounts.openaiAccount')
|
? t('admin.accounts.openaiAccount')
|
||||||
|
: isSora
|
||||||
|
? t('admin.accounts.soraAccount')
|
||||||
: isGemini
|
: isGemini
|
||||||
? t('admin.accounts.geminiAccount')
|
? t('admin.accounts.geminiAccount')
|
||||||
: isAntigravity
|
: isAntigravity
|
||||||
@@ -128,7 +130,7 @@
|
|||||||
:show-cookie-option="isAnthropic"
|
:show-cookie-option="isAnthropic"
|
||||||
:allow-multiple="false"
|
:allow-multiple="false"
|
||||||
:method-label="t('admin.accounts.inputMethod')"
|
:method-label="t('admin.accounts.inputMethod')"
|
||||||
:platform="isOpenAI ? 'openai' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
:platform="isOpenAI ? 'openai' : isSora ? 'sora' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
||||||
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
||||||
@generate-url="handleGenerateUrl"
|
@generate-url="handleGenerateUrl"
|
||||||
@cookie-auth="handleCookieAuth"
|
@cookie-auth="handleCookieAuth"
|
||||||
@@ -224,7 +226,8 @@ const { t } = useI18n()
|
|||||||
|
|
||||||
// OAuth composables
|
// OAuth composables
|
||||||
const claudeOAuth = useAccountOAuth()
|
const claudeOAuth = useAccountOAuth()
|
||||||
const openaiOAuth = useOpenAIOAuth()
|
const openaiOAuth = useOpenAIOAuth({ platform: 'openai' })
|
||||||
|
const soraOAuth = useOpenAIOAuth({ platform: 'sora' })
|
||||||
const geminiOAuth = useGeminiOAuth()
|
const geminiOAuth = useGeminiOAuth()
|
||||||
const antigravityOAuth = useAntigravityOAuth()
|
const antigravityOAuth = useAntigravityOAuth()
|
||||||
|
|
||||||
@@ -237,31 +240,34 @@ const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_as
|
|||||||
|
|
||||||
// Computed - check platform
|
// Computed - check platform
|
||||||
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
||||||
|
const isSora = computed(() => props.account?.platform === 'sora')
|
||||||
|
const isOpenAILike = computed(() => isOpenAI.value || isSora.value)
|
||||||
const isGemini = computed(() => props.account?.platform === 'gemini')
|
const isGemini = computed(() => props.account?.platform === 'gemini')
|
||||||
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
||||||
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
||||||
|
const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth))
|
||||||
|
|
||||||
// Computed - current OAuth state based on platform
|
// Computed - current OAuth state based on platform
|
||||||
const currentAuthUrl = computed(() => {
|
const currentAuthUrl = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.authUrl.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value
|
||||||
if (isGemini.value) return geminiOAuth.authUrl.value
|
if (isGemini.value) return geminiOAuth.authUrl.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
||||||
return claudeOAuth.authUrl.value
|
return claudeOAuth.authUrl.value
|
||||||
})
|
})
|
||||||
const currentSessionId = computed(() => {
|
const currentSessionId = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.sessionId.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value
|
||||||
if (isGemini.value) return geminiOAuth.sessionId.value
|
if (isGemini.value) return geminiOAuth.sessionId.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
||||||
return claudeOAuth.sessionId.value
|
return claudeOAuth.sessionId.value
|
||||||
})
|
})
|
||||||
const currentLoading = computed(() => {
|
const currentLoading = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.loading.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value
|
||||||
if (isGemini.value) return geminiOAuth.loading.value
|
if (isGemini.value) return geminiOAuth.loading.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.loading.value
|
if (isAntigravity.value) return antigravityOAuth.loading.value
|
||||||
return claudeOAuth.loading.value
|
return claudeOAuth.loading.value
|
||||||
})
|
})
|
||||||
const currentError = computed(() => {
|
const currentError = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.error.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value
|
||||||
if (isGemini.value) return geminiOAuth.error.value
|
if (isGemini.value) return geminiOAuth.error.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.error.value
|
if (isAntigravity.value) return antigravityOAuth.error.value
|
||||||
return claudeOAuth.error.value
|
return claudeOAuth.error.value
|
||||||
@@ -269,8 +275,8 @@ const currentError = computed(() => {
|
|||||||
|
|
||||||
// Computed
|
// Computed
|
||||||
const isManualInputMethod = computed(() => {
|
const isManualInputMethod = computed(() => {
|
||||||
// OpenAI/Gemini/Antigravity always use manual input (no cookie auth option)
|
// OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option)
|
||||||
return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
||||||
})
|
})
|
||||||
|
|
||||||
const canExchangeCode = computed(() => {
|
const canExchangeCode = computed(() => {
|
||||||
@@ -313,6 +319,7 @@ const resetState = () => {
|
|||||||
geminiOAuthType.value = 'code_assist'
|
geminiOAuthType.value = 'code_assist'
|
||||||
claudeOAuth.resetState()
|
claudeOAuth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
@@ -325,8 +332,8 @@ const handleClose = () => {
|
|||||||
const handleGenerateUrl = async () => {
|
const handleGenerateUrl = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
|
|
||||||
if (isOpenAI.value) {
|
if (isOpenAILike.value) {
|
||||||
await openaiOAuth.generateAuthUrl(props.account.proxy_id)
|
await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id)
|
||||||
} else if (isGemini.value) {
|
} else if (isGemini.value) {
|
||||||
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
||||||
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
||||||
@@ -345,21 +352,29 @@ const handleExchangeCode = async () => {
|
|||||||
const authCode = oauthFlowRef.value?.authCode || ''
|
const authCode = oauthFlowRef.value?.authCode || ''
|
||||||
if (!authCode.trim()) return
|
if (!authCode.trim()) return
|
||||||
|
|
||||||
if (isOpenAI.value) {
|
if (isOpenAILike.value) {
|
||||||
// OpenAI OAuth flow
|
// OpenAI OAuth flow
|
||||||
const sessionId = openaiOAuth.sessionId.value
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
|
const sessionId = oauthClient.sessionId.value
|
||||||
if (!sessionId) return
|
if (!sessionId) return
|
||||||
|
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
|
||||||
|
if (!stateToUse) {
|
||||||
|
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
|
||||||
|
appStore.showError(oauthClient.error.value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
const tokenInfo = await openaiOAuth.exchangeAuthCode(
|
const tokenInfo = await oauthClient.exchangeAuthCode(
|
||||||
authCode.trim(),
|
authCode.trim(),
|
||||||
sessionId,
|
sessionId,
|
||||||
|
stateToUse,
|
||||||
props.account.proxy_id
|
props.account.proxy_id
|
||||||
)
|
)
|
||||||
if (!tokenInfo) return
|
if (!tokenInfo) return
|
||||||
|
|
||||||
// Build credentials and extra info
|
// Build credentials and extra info
|
||||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
const extra = openaiOAuth.buildExtraInfo(tokenInfo)
|
const extra = oauthClient.buildExtraInfo(tokenInfo)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Update account with new credentials
|
// Update account with new credentials
|
||||||
@@ -376,8 +391,8 @@ const handleExchangeCode = async () => {
|
|||||||
emit('reauthorized')
|
emit('reauthorized')
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
appStore.showError(openaiOAuth.error.value)
|
appStore.showError(oauthClient.error.value)
|
||||||
}
|
}
|
||||||
} else if (isGemini.value) {
|
} else if (isGemini.value) {
|
||||||
const sessionId = geminiOAuth.sessionId.value
|
const sessionId = geminiOAuth.sessionId.value
|
||||||
@@ -490,7 +505,7 @@ const handleExchangeCode = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handleCookieAuth = async (sessionKey: string) => {
|
const handleCookieAuth = async (sessionKey: string) => {
|
||||||
if (!props.account || isOpenAI.value) return
|
if (!props.account || isOpenAILike.value) return
|
||||||
|
|
||||||
claudeOAuth.loading.value = true
|
claudeOAuth.loading.value = true
|
||||||
claudeOAuth.error.value = ''
|
claudeOAuth.error.value = ''
|
||||||
|
|||||||
@@ -238,6 +238,11 @@ const loadAvailableModels = async () => {
|
|||||||
availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') ||
|
availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') ||
|
||||||
availableModels.value.find((m) => m.id === 'gemini-3-pro-preview')
|
availableModels.value.find((m) => m.id === 'gemini-3-pro-preview')
|
||||||
selectedModelId.value = preferred?.id || availableModels.value[0].id
|
selectedModelId.value = preferred?.id || availableModels.value[0].id
|
||||||
|
} else if (props.account.platform === 'sora') {
|
||||||
|
const preferred =
|
||||||
|
availableModels.value.find((m) => m.id === 'gpt-image') ||
|
||||||
|
availableModels.value.find((m) => !m.id.startsWith('prompt-enhance'))
|
||||||
|
selectedModelId.value = preferred?.id || availableModels.value[0].id
|
||||||
} else {
|
} else {
|
||||||
// Try to select Sonnet as default, otherwise use first model
|
// Try to select Sonnet as default, otherwise use first model
|
||||||
const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet'))
|
const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet'))
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
<div
|
<div
|
||||||
:class="[
|
:class="[
|
||||||
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
|
'flex h-10 w-10 items-center justify-center rounded-lg bg-gradient-to-br',
|
||||||
isOpenAI
|
isOpenAILike
|
||||||
? 'from-green-500 to-green-600'
|
? 'from-green-500 to-green-600'
|
||||||
: isGemini
|
: isGemini
|
||||||
? 'from-blue-500 to-blue-600'
|
? 'from-blue-500 to-blue-600'
|
||||||
@@ -33,6 +33,8 @@
|
|||||||
{{
|
{{
|
||||||
isOpenAI
|
isOpenAI
|
||||||
? t('admin.accounts.openaiAccount')
|
? t('admin.accounts.openaiAccount')
|
||||||
|
: isSora
|
||||||
|
? t('admin.accounts.soraAccount')
|
||||||
: isGemini
|
: isGemini
|
||||||
? t('admin.accounts.geminiAccount')
|
? t('admin.accounts.geminiAccount')
|
||||||
: isAntigravity
|
: isAntigravity
|
||||||
@@ -128,7 +130,7 @@
|
|||||||
:show-cookie-option="isAnthropic"
|
:show-cookie-option="isAnthropic"
|
||||||
:allow-multiple="false"
|
:allow-multiple="false"
|
||||||
:method-label="t('admin.accounts.inputMethod')"
|
:method-label="t('admin.accounts.inputMethod')"
|
||||||
:platform="isOpenAI ? 'openai' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
:platform="isOpenAI ? 'openai' : isSora ? 'sora' : isGemini ? 'gemini' : isAntigravity ? 'antigravity' : 'anthropic'"
|
||||||
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
:show-project-id="isGemini && geminiOAuthType === 'code_assist'"
|
||||||
@generate-url="handleGenerateUrl"
|
@generate-url="handleGenerateUrl"
|
||||||
@cookie-auth="handleCookieAuth"
|
@cookie-auth="handleCookieAuth"
|
||||||
@@ -224,7 +226,8 @@ const { t } = useI18n()
|
|||||||
|
|
||||||
// OAuth composables
|
// OAuth composables
|
||||||
const claudeOAuth = useAccountOAuth()
|
const claudeOAuth = useAccountOAuth()
|
||||||
const openaiOAuth = useOpenAIOAuth()
|
const openaiOAuth = useOpenAIOAuth({ platform: 'openai' })
|
||||||
|
const soraOAuth = useOpenAIOAuth({ platform: 'sora' })
|
||||||
const geminiOAuth = useGeminiOAuth()
|
const geminiOAuth = useGeminiOAuth()
|
||||||
const antigravityOAuth = useAntigravityOAuth()
|
const antigravityOAuth = useAntigravityOAuth()
|
||||||
|
|
||||||
@@ -237,31 +240,34 @@ const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_as
|
|||||||
|
|
||||||
// Computed - check platform
|
// Computed - check platform
|
||||||
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
const isOpenAI = computed(() => props.account?.platform === 'openai')
|
||||||
|
const isSora = computed(() => props.account?.platform === 'sora')
|
||||||
|
const isOpenAILike = computed(() => isOpenAI.value || isSora.value)
|
||||||
const isGemini = computed(() => props.account?.platform === 'gemini')
|
const isGemini = computed(() => props.account?.platform === 'gemini')
|
||||||
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
const isAnthropic = computed(() => props.account?.platform === 'anthropic')
|
||||||
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
const isAntigravity = computed(() => props.account?.platform === 'antigravity')
|
||||||
|
const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth))
|
||||||
|
|
||||||
// Computed - current OAuth state based on platform
|
// Computed - current OAuth state based on platform
|
||||||
const currentAuthUrl = computed(() => {
|
const currentAuthUrl = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.authUrl.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value
|
||||||
if (isGemini.value) return geminiOAuth.authUrl.value
|
if (isGemini.value) return geminiOAuth.authUrl.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
if (isAntigravity.value) return antigravityOAuth.authUrl.value
|
||||||
return claudeOAuth.authUrl.value
|
return claudeOAuth.authUrl.value
|
||||||
})
|
})
|
||||||
const currentSessionId = computed(() => {
|
const currentSessionId = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.sessionId.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value
|
||||||
if (isGemini.value) return geminiOAuth.sessionId.value
|
if (isGemini.value) return geminiOAuth.sessionId.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
if (isAntigravity.value) return antigravityOAuth.sessionId.value
|
||||||
return claudeOAuth.sessionId.value
|
return claudeOAuth.sessionId.value
|
||||||
})
|
})
|
||||||
const currentLoading = computed(() => {
|
const currentLoading = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.loading.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value
|
||||||
if (isGemini.value) return geminiOAuth.loading.value
|
if (isGemini.value) return geminiOAuth.loading.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.loading.value
|
if (isAntigravity.value) return antigravityOAuth.loading.value
|
||||||
return claudeOAuth.loading.value
|
return claudeOAuth.loading.value
|
||||||
})
|
})
|
||||||
const currentError = computed(() => {
|
const currentError = computed(() => {
|
||||||
if (isOpenAI.value) return openaiOAuth.error.value
|
if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value
|
||||||
if (isGemini.value) return geminiOAuth.error.value
|
if (isGemini.value) return geminiOAuth.error.value
|
||||||
if (isAntigravity.value) return antigravityOAuth.error.value
|
if (isAntigravity.value) return antigravityOAuth.error.value
|
||||||
return claudeOAuth.error.value
|
return claudeOAuth.error.value
|
||||||
@@ -269,8 +275,8 @@ const currentError = computed(() => {
|
|||||||
|
|
||||||
// Computed
|
// Computed
|
||||||
const isManualInputMethod = computed(() => {
|
const isManualInputMethod = computed(() => {
|
||||||
// OpenAI/Gemini/Antigravity always use manual input (no cookie auth option)
|
// OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option)
|
||||||
return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual'
|
||||||
})
|
})
|
||||||
|
|
||||||
const canExchangeCode = computed(() => {
|
const canExchangeCode = computed(() => {
|
||||||
@@ -313,6 +319,7 @@ const resetState = () => {
|
|||||||
geminiOAuthType.value = 'code_assist'
|
geminiOAuthType.value = 'code_assist'
|
||||||
claudeOAuth.resetState()
|
claudeOAuth.resetState()
|
||||||
openaiOAuth.resetState()
|
openaiOAuth.resetState()
|
||||||
|
soraOAuth.resetState()
|
||||||
geminiOAuth.resetState()
|
geminiOAuth.resetState()
|
||||||
antigravityOAuth.resetState()
|
antigravityOAuth.resetState()
|
||||||
oauthFlowRef.value?.reset()
|
oauthFlowRef.value?.reset()
|
||||||
@@ -325,8 +332,8 @@ const handleClose = () => {
|
|||||||
const handleGenerateUrl = async () => {
|
const handleGenerateUrl = async () => {
|
||||||
if (!props.account) return
|
if (!props.account) return
|
||||||
|
|
||||||
if (isOpenAI.value) {
|
if (isOpenAILike.value) {
|
||||||
await openaiOAuth.generateAuthUrl(props.account.proxy_id)
|
await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id)
|
||||||
} else if (isGemini.value) {
|
} else if (isGemini.value) {
|
||||||
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
const creds = (props.account.credentials || {}) as Record<string, unknown>
|
||||||
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined
|
||||||
@@ -345,21 +352,29 @@ const handleExchangeCode = async () => {
|
|||||||
const authCode = oauthFlowRef.value?.authCode || ''
|
const authCode = oauthFlowRef.value?.authCode || ''
|
||||||
if (!authCode.trim()) return
|
if (!authCode.trim()) return
|
||||||
|
|
||||||
if (isOpenAI.value) {
|
if (isOpenAILike.value) {
|
||||||
// OpenAI OAuth flow
|
// OpenAI OAuth flow
|
||||||
const sessionId = openaiOAuth.sessionId.value
|
const oauthClient = activeOpenAIOAuth.value
|
||||||
|
const sessionId = oauthClient.sessionId.value
|
||||||
if (!sessionId) return
|
if (!sessionId) return
|
||||||
|
const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim()
|
||||||
|
if (!stateToUse) {
|
||||||
|
oauthClient.error.value = t('admin.accounts.oauth.authFailed')
|
||||||
|
appStore.showError(oauthClient.error.value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
const tokenInfo = await openaiOAuth.exchangeAuthCode(
|
const tokenInfo = await oauthClient.exchangeAuthCode(
|
||||||
authCode.trim(),
|
authCode.trim(),
|
||||||
sessionId,
|
sessionId,
|
||||||
|
stateToUse,
|
||||||
props.account.proxy_id
|
props.account.proxy_id
|
||||||
)
|
)
|
||||||
if (!tokenInfo) return
|
if (!tokenInfo) return
|
||||||
|
|
||||||
// Build credentials and extra info
|
// Build credentials and extra info
|
||||||
const credentials = openaiOAuth.buildCredentials(tokenInfo)
|
const credentials = oauthClient.buildCredentials(tokenInfo)
|
||||||
const extra = openaiOAuth.buildExtraInfo(tokenInfo)
|
const extra = oauthClient.buildExtraInfo(tokenInfo)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Update account with new credentials
|
// Update account with new credentials
|
||||||
@@ -376,8 +391,8 @@ const handleExchangeCode = async () => {
|
|||||||
emit('reauthorized', updatedAccount)
|
emit('reauthorized', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
appStore.showError(openaiOAuth.error.value)
|
appStore.showError(oauthClient.error.value)
|
||||||
}
|
}
|
||||||
} else if (isGemini.value) {
|
} else if (isGemini.value) {
|
||||||
const sessionId = geminiOAuth.sessionId.value
|
const sessionId = geminiOAuth.sessionId.value
|
||||||
@@ -490,7 +505,7 @@ const handleExchangeCode = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handleCookieAuth = async (sessionKey: string) => {
|
const handleCookieAuth = async (sessionKey: string) => {
|
||||||
if (!props.account || isOpenAI.value) return
|
if (!props.account || isOpenAILike.value) return
|
||||||
|
|
||||||
claudeOAuth.loading.value = true
|
claudeOAuth.loading.value = true
|
||||||
claudeOAuth.error.value = ''
|
claudeOAuth.error.value = ''
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
|
|||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
|
|
||||||
export type AddMethod = 'oauth' | 'setup-token'
|
export type AddMethod = 'oauth' | 'setup-token'
|
||||||
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token'
|
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token'
|
||||||
|
|
||||||
export interface OAuthState {
|
export interface OAuthState {
|
||||||
authUrl: string
|
authUrl: string
|
||||||
|
|||||||
@@ -19,12 +19,21 @@ export interface OpenAITokenInfo {
|
|||||||
[key: string]: unknown
|
[key: string]: unknown
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useOpenAIOAuth() {
|
export type OpenAIOAuthPlatform = 'openai' | 'sora'
|
||||||
|
|
||||||
|
interface UseOpenAIOAuthOptions {
|
||||||
|
platform?: OpenAIOAuthPlatform
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
|
const oauthPlatform = options?.platform ?? 'openai'
|
||||||
|
const endpointPrefix = oauthPlatform === 'sora' ? '/admin/sora' : '/admin/openai'
|
||||||
|
|
||||||
// State
|
// State
|
||||||
const authUrl = ref('')
|
const authUrl = ref('')
|
||||||
const sessionId = ref('')
|
const sessionId = ref('')
|
||||||
|
const oauthState = ref('')
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const error = ref('')
|
const error = ref('')
|
||||||
|
|
||||||
@@ -32,6 +41,7 @@ export function useOpenAIOAuth() {
|
|||||||
const resetState = () => {
|
const resetState = () => {
|
||||||
authUrl.value = ''
|
authUrl.value = ''
|
||||||
sessionId.value = ''
|
sessionId.value = ''
|
||||||
|
oauthState.value = ''
|
||||||
loading.value = false
|
loading.value = false
|
||||||
error.value = ''
|
error.value = ''
|
||||||
}
|
}
|
||||||
@@ -44,6 +54,7 @@ export function useOpenAIOAuth() {
|
|||||||
loading.value = true
|
loading.value = true
|
||||||
authUrl.value = ''
|
authUrl.value = ''
|
||||||
sessionId.value = ''
|
sessionId.value = ''
|
||||||
|
oauthState.value = ''
|
||||||
error.value = ''
|
error.value = ''
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -56,11 +67,17 @@ export function useOpenAIOAuth() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const response = await adminAPI.accounts.generateAuthUrl(
|
const response = await adminAPI.accounts.generateAuthUrl(
|
||||||
'/admin/openai/generate-auth-url',
|
`${endpointPrefix}/generate-auth-url`,
|
||||||
payload
|
payload
|
||||||
)
|
)
|
||||||
authUrl.value = response.auth_url
|
authUrl.value = response.auth_url
|
||||||
sessionId.value = response.session_id
|
sessionId.value = response.session_id
|
||||||
|
try {
|
||||||
|
const parsed = new URL(response.auth_url)
|
||||||
|
oauthState.value = parsed.searchParams.get('state') || ''
|
||||||
|
} catch {
|
||||||
|
oauthState.value = ''
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || 'Failed to generate OpenAI auth URL'
|
error.value = err.response?.data?.detail || 'Failed to generate OpenAI auth URL'
|
||||||
@@ -75,10 +92,11 @@ export function useOpenAIOAuth() {
|
|||||||
const exchangeAuthCode = async (
|
const exchangeAuthCode = async (
|
||||||
code: string,
|
code: string,
|
||||||
currentSessionId: string,
|
currentSessionId: string,
|
||||||
|
state: string,
|
||||||
proxyId?: number | null
|
proxyId?: number | null
|
||||||
): Promise<OpenAITokenInfo | null> => {
|
): Promise<OpenAITokenInfo | null> => {
|
||||||
if (!code.trim() || !currentSessionId) {
|
if (!code.trim() || !currentSessionId || !state.trim()) {
|
||||||
error.value = 'Missing auth code or session ID'
|
error.value = 'Missing auth code, session ID, or state'
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,15 +104,16 @@ export function useOpenAIOAuth() {
|
|||||||
error.value = ''
|
error.value = ''
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const payload: { session_id: string; code: string; proxy_id?: number } = {
|
const payload: { session_id: string; code: string; state: string; proxy_id?: number } = {
|
||||||
session_id: currentSessionId,
|
session_id: currentSessionId,
|
||||||
code: code.trim()
|
code: code.trim(),
|
||||||
|
state: state.trim()
|
||||||
}
|
}
|
||||||
if (proxyId) {
|
if (proxyId) {
|
||||||
payload.proxy_id = proxyId
|
payload.proxy_id = proxyId
|
||||||
}
|
}
|
||||||
|
|
||||||
const tokenInfo = await adminAPI.accounts.exchangeCode('/admin/openai/exchange-code', payload)
|
const tokenInfo = await adminAPI.accounts.exchangeCode(`${endpointPrefix}/exchange-code`, payload)
|
||||||
return tokenInfo as OpenAITokenInfo
|
return tokenInfo as OpenAITokenInfo
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || 'Failed to exchange OpenAI auth code'
|
error.value = err.response?.data?.detail || 'Failed to exchange OpenAI auth code'
|
||||||
@@ -120,7 +139,11 @@ export function useOpenAIOAuth() {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// Use dedicated refresh-token endpoint
|
// Use dedicated refresh-token endpoint
|
||||||
const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(refreshToken.trim(), proxyId)
|
const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(
|
||||||
|
refreshToken.trim(),
|
||||||
|
proxyId,
|
||||||
|
`${endpointPrefix}/refresh-token`
|
||||||
|
)
|
||||||
return tokenInfo as OpenAITokenInfo
|
return tokenInfo as OpenAITokenInfo
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || 'Failed to validate refresh token'
|
error.value = err.response?.data?.detail || 'Failed to validate refresh token'
|
||||||
@@ -131,6 +154,33 @@ export function useOpenAIOAuth() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate Sora session token and get access token
|
||||||
|
const validateSessionToken = async (
|
||||||
|
sessionToken: string,
|
||||||
|
proxyId?: number | null
|
||||||
|
): Promise<OpenAITokenInfo | null> => {
|
||||||
|
if (!sessionToken.trim()) {
|
||||||
|
error.value = 'Missing session token'
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
loading.value = true
|
||||||
|
error.value = ''
|
||||||
|
try {
|
||||||
|
const tokenInfo = await adminAPI.accounts.validateSoraSessionToken(
|
||||||
|
sessionToken.trim(),
|
||||||
|
proxyId,
|
||||||
|
`${endpointPrefix}/st2at`
|
||||||
|
)
|
||||||
|
return tokenInfo as OpenAITokenInfo
|
||||||
|
} catch (err: any) {
|
||||||
|
error.value = err.response?.data?.detail || 'Failed to validate session token'
|
||||||
|
appStore.showError(error.value)
|
||||||
|
return null
|
||||||
|
} finally {
|
||||||
|
loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build credentials for OpenAI OAuth account
|
// Build credentials for OpenAI OAuth account
|
||||||
const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => {
|
const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => {
|
||||||
const creds: Record<string, unknown> = {
|
const creds: Record<string, unknown> = {
|
||||||
@@ -172,6 +222,7 @@ export function useOpenAIOAuth() {
|
|||||||
// State
|
// State
|
||||||
authUrl,
|
authUrl,
|
||||||
sessionId,
|
sessionId,
|
||||||
|
oauthState,
|
||||||
loading,
|
loading,
|
||||||
error,
|
error,
|
||||||
// Methods
|
// Methods
|
||||||
@@ -179,6 +230,7 @@ export function useOpenAIOAuth() {
|
|||||||
generateAuthUrl,
|
generateAuthUrl,
|
||||||
exchangeAuthCode,
|
exchangeAuthCode,
|
||||||
validateRefreshToken,
|
validateRefreshToken,
|
||||||
|
validateSessionToken,
|
||||||
buildCredentials,
|
buildCredentials,
|
||||||
buildExtraInfo
|
buildExtraInfo
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1740,9 +1740,13 @@ export default {
|
|||||||
refreshTokenAuth: 'Manual RT Input',
|
refreshTokenAuth: 'Manual RT Input',
|
||||||
refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.',
|
refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.',
|
||||||
refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line',
|
refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line',
|
||||||
|
sessionTokenAuth: 'Manual ST Input',
|
||||||
|
sessionTokenDesc: 'Enter your existing Sora Session Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.',
|
||||||
|
sessionTokenPlaceholder: 'Paste your Sora Session Token...\nSupports multiple, one per line',
|
||||||
validating: 'Validating...',
|
validating: 'Validating...',
|
||||||
validateAndCreate: 'Validate & Create Account',
|
validateAndCreate: 'Validate & Create Account',
|
||||||
pleaseEnterRefreshToken: 'Please enter Refresh Token'
|
pleaseEnterRefreshToken: 'Please enter Refresh Token',
|
||||||
|
pleaseEnterSessionToken: 'Please enter Session Token'
|
||||||
},
|
},
|
||||||
// Gemini specific
|
// Gemini specific
|
||||||
gemini: {
|
gemini: {
|
||||||
@@ -1963,6 +1967,7 @@ export default {
|
|||||||
reAuthorizeAccount: 'Re-Authorize Account',
|
reAuthorizeAccount: 'Re-Authorize Account',
|
||||||
claudeCodeAccount: 'Claude Code Account',
|
claudeCodeAccount: 'Claude Code Account',
|
||||||
openaiAccount: 'OpenAI Account',
|
openaiAccount: 'OpenAI Account',
|
||||||
|
soraAccount: 'Sora Account',
|
||||||
geminiAccount: 'Gemini Account',
|
geminiAccount: 'Gemini Account',
|
||||||
antigravityAccount: 'Antigravity Account',
|
antigravityAccount: 'Antigravity Account',
|
||||||
inputMethod: 'Input Method',
|
inputMethod: 'Input Method',
|
||||||
|
|||||||
@@ -1879,9 +1879,13 @@ export default {
|
|||||||
refreshTokenAuth: '手动输入 RT',
|
refreshTokenAuth: '手动输入 RT',
|
||||||
refreshTokenDesc: '输入您已有的 OpenAI Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。',
|
refreshTokenDesc: '输入您已有的 OpenAI Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。',
|
||||||
refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个,每行一个',
|
refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个,每行一个',
|
||||||
|
sessionTokenAuth: '手动输入 ST',
|
||||||
|
sessionTokenDesc: '输入您已有的 Sora Session Token,支持批量输入(每行一个),系统将自动验证并创建账号。',
|
||||||
|
sessionTokenPlaceholder: '粘贴您的 Sora Session Token...\n支持多个,每行一个',
|
||||||
validating: '验证中...',
|
validating: '验证中...',
|
||||||
validateAndCreate: '验证并创建账号',
|
validateAndCreate: '验证并创建账号',
|
||||||
pleaseEnterRefreshToken: '请输入 Refresh Token'
|
pleaseEnterRefreshToken: '请输入 Refresh Token',
|
||||||
|
pleaseEnterSessionToken: '请输入 Session Token'
|
||||||
},
|
},
|
||||||
// Gemini specific
|
// Gemini specific
|
||||||
gemini: {
|
gemini: {
|
||||||
@@ -2097,6 +2101,7 @@ export default {
|
|||||||
reAuthorizeAccount: '重新授权账号',
|
reAuthorizeAccount: '重新授权账号',
|
||||||
claudeCodeAccount: 'Claude Code 账号',
|
claudeCodeAccount: 'Claude Code 账号',
|
||||||
openaiAccount: 'OpenAI 账号',
|
openaiAccount: 'OpenAI 账号',
|
||||||
|
soraAccount: 'Sora 账号',
|
||||||
geminiAccount: 'Gemini 账号',
|
geminiAccount: 'Gemini 账号',
|
||||||
antigravityAccount: 'Antigravity 账号',
|
antigravityAccount: 'Antigravity 账号',
|
||||||
inputMethod: '输入方式',
|
inputMethod: '输入方式',
|
||||||
|
|||||||
Reference in New Issue
Block a user