feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径

- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力
- 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程
- 强化 Sora token 恢复、转发日志与网关路由隔离行为
- 补充后端服务层与路由层相关测试覆盖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-19 08:02:56 +08:00
parent 36bb327024
commit 900cce20a1
39 changed files with 2561 additions and 283 deletions

View File

@@ -1333,6 +1333,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {

View File

@@ -2,6 +2,7 @@ package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService
}
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string
if req.ProxyID != nil {
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
}
}
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil {
response.ErrorFrom(c, err)
return
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI account
// ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return
}
// Ensure account is OpenAI platform
if !account.IsOpenAI() {
response.BadRequest(c, "Account is not an OpenAI account")
platform := oauthPlatformFromPath(c)
if account.Platform != platform {
response.BadRequest(c, "Account platform does not match OAuth endpoint")
return
}
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount))
}
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
name = "OpenAI OAuth Account"
if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account"
}
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
Platform: "openai",
Platform: platform,
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,

View File

@@ -212,6 +212,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverBody []byte
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
@@ -224,7 +225,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
return
}
account := selection.Account
@@ -287,14 +288,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
lastFailoverBody = failoverErr.ResponseBody
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverBody = failoverErr.ResponseBody
switchCount++
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
reqLog.Warn("sora.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.String("upstream_error_code", upstreamErrCode),
zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
@@ -360,17 +366,32 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseBody []byte, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byte) (int, string, string) {
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if upstreamMessage != "" {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", upstreamMessage
case 429:
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
}
}
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 404:
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
}
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
@@ -382,6 +403,41 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
}
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !gjson.Valid(trimmed) {
return "", truncateSoraErrorMessage(trimmed, 256)
}
code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
if code == "" {
code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
}
message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
}
return code, truncateSoraErrorMessage(message, 512)
}
func truncateSoraErrorMessage(s string, maxLen int) string {
if maxLen <= 0 {
return ""
}
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...(truncated)"
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
flusher, ok := c.Writer.(http.Flusher)

View File

@@ -43,6 +43,9 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
}