feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径
- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力 - 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程 - 强化 Sora token 恢复、转发日志与网关路由隔离行为 - 补充后端服务层与路由层相关测试覆盖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
if strings.TrimSpace(clientID) != "" {
|
||||
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
|
||||
}
|
||||
|
||||
clientIDs := []string{
|
||||
openai.ClientID,
|
||||
openai.SoraClientID,
|
||||
}
|
||||
seen := make(map[string]struct{}, len(clientIDs))
|
||||
var lastErr error
|
||||
for _, clientID := range clientIDs {
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[clientID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[clientID] = struct{}{}
|
||||
|
||||
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
if err == nil {
|
||||
return tokenResp, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("client_id", clientID)
|
||||
formData.Set("scope", openai.RefreshScopes)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID == openai.ClientID {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "invalid_grant")
|
||||
return
|
||||
}
|
||||
if clientID == openai.SoraClientID {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
|
||||
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||
const customClientID = "custom-client-id"
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID != customClientID {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
|
||||
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||
require.Equal(s.T(), "at-custom", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
|
||||
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
Reference in New Issue
Block a user