feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径
- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力 - 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程 - 强化 Sora token 恢复、转发日志与网关路由隔离行为 - 补充后端服务层与路由层相关测试覆盖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,11 +27,13 @@ import (
|
||||
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||
)
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
@@ -502,8 +504,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -512,7 +515,10 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body)))
|
||||
if isCloudflareChallengeResponse(resp.StatusCode, body) {
|
||||
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
||||
}
|
||||
|
||||
// 解析 /me 响应,提取用户信息
|
||||
@@ -531,10 +537,129 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: info})
|
||||
}
|
||||
|
||||
// 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
|
||||
subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
|
||||
if err == nil {
|
||||
subReq.Header.Set("Authorization", "Bearer "+authToken)
|
||||
subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
subReq.Header.Set("Accept", "application/json")
|
||||
|
||||
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||
if subErr != nil {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
|
||||
} else {
|
||||
subBody, _ := io.ReadAll(subResp.Body)
|
||||
_ = subResp.Body.Close()
|
||||
if subResp.StatusCode == http.StatusOK {
|
||||
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
|
||||
}
|
||||
} else {
|
||||
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseSoraSubscriptionSummary(body []byte) string {
|
||||
var subResp struct {
|
||||
Data []struct {
|
||||
Plan struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
} `json:"plan"`
|
||||
EndTS string `json:"end_ts"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &subResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
if len(subResp.Data) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
first := subResp.Data[0]
|
||||
parts := make([]string, 0, 3)
|
||||
if first.Plan.Title != "" {
|
||||
parts = append(parts, first.Plan.Title)
|
||||
}
|
||||
if first.Plan.ID != "" {
|
||||
parts = append(parts, first.Plan.ID)
|
||||
}
|
||||
if first.EndTS != "" {
|
||||
parts = append(parts, "end="+first.EndTS)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "Subscription: " + strings.Join(parts, " | ")
|
||||
}
|
||||
|
||||
func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
|
||||
if s == nil || s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return s.cfg.Gateway.TLSFingerprint.Enabled && !s.cfg.Sora.Client.DisableTLSFingerprint
|
||||
}
|
||||
|
||||
func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusForbidden {
|
||||
return false
|
||||
}
|
||||
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
|
||||
return strings.Contains(preview, "window._cf_chl_opt") ||
|
||||
strings.Contains(preview, "just a moment") ||
|
||||
strings.Contains(preview, "enable javascript and cookies to continue")
|
||||
}
|
||||
|
||||
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
rayID := extractCloudflareRayID(headers, body)
|
||||
if rayID == "" {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||
}
|
||||
|
||||
func extractCloudflareRayID(headers http.Header, body []byte) string {
|
||||
if headers != nil {
|
||||
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
}
|
||||
|
||||
preview := truncateSoraErrorBody(body, 8192)
|
||||
matches := cloudflareRayPattern.FindStringSubmatch(preview)
|
||||
if len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func truncateSoraErrorBody(body []byte, max int) string {
|
||||
if max <= 0 {
|
||||
max = 512
|
||||
}
|
||||
raw := strings.TrimSpace(string(body))
|
||||
if len(raw) <= max {
|
||||
return raw
|
||||
}
|
||||
return raw[:max] + "...(truncated)"
|
||||
}
|
||||
|
||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
||||
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
|
||||
193
backend/internal/service/account_test_service_sora_test.go
Normal file
193
backend/internal/service/account_test_service_sora_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type queuedHTTPUpstream struct {
|
||||
responses []*http.Response
|
||||
requests []*http.Request
|
||||
tlsFlags []bool
|
||||
}
|
||||
|
||||
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected Do call")
|
||||
}
|
||||
|
||||
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
u.requests = append(u.requests, req)
|
||||
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
|
||||
if len(u.responses) == 0 {
|
||||
return nil, fmt.Errorf("no mocked response")
|
||||
}
|
||||
resp := u.responses[0]
|
||||
u.responses = u.responses[1:]
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func newJSONResponse(status int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
|
||||
resp := newJSONResponse(status, body)
|
||||
resp.Header.Set(key, value)
|
||||
return resp
|
||||
}
|
||||
|
||||
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||
return c, rec
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
TLSFingerprint: config.TLSFingerprintConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
DisableTLSFingerprint: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, upstream.requests, 2)
|
||||
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
|
||||
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
|
||||
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
|
||||
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
|
||||
require.Equal(t, []bool{true, true}, upstream.tlsFlags)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"test_start"`)
|
||||
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
|
||||
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, upstream.requests, 2)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Sora connection OK - User: demo-user")
|
||||
require.Contains(t, body, "Subscription check returned 403")
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
require.Contains(t, body, "Cloudflare challenge")
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
type OpenAIOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
|
||||
}
|
||||
|
||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||
|
||||
@@ -2,13 +2,20 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
type OpenAIExchangeCodeInput struct {
|
||||
SessionID string
|
||||
Code string
|
||||
State string
|
||||
RedirectURI string
|
||||
ProxyID *int64
|
||||
}
|
||||
@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
if !ok {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
||||
}
|
||||
if input.State == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
|
||||
}
|
||||
|
||||
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
||||
proxyURL := session.ProxyURL
|
||||
@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||
}
|
||||
|
||||
// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
|
||||
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
||||
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
||||
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
||||
if strings.TrimSpace(sessionToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
||||
}
|
||||
|
||||
refreshToken := account.GetOpenAIRefreshToken()
|
||||
proxyURL, err := s.resolveProxyURL(ctx, proxyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
|
||||
}
|
||||
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
|
||||
client := newOpenAIOAuthHTTPClient(proxyURL)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var sessionResp struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
Expires string `json:"expires"`
|
||||
User struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
} `json:"user"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &sessionResp); err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(sessionResp.AccessToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Hour).Unix()
|
||||
if strings.TrimSpace(sessionResp.Expires) != "" {
|
||||
if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
|
||||
expiresAt = parsed.Unix()
|
||||
}
|
||||
}
|
||||
expiresIn := expiresAt - time.Now().Unix()
|
||||
if expiresIn < 0 {
|
||||
expiresIn = 0
|
||||
}
|
||||
|
||||
return &OpenAITokenInfo{
|
||||
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
||||
ExpiresIn: expiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
Email: strings.TrimSpace(sessionResp.User.Email),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
|
||||
}
|
||||
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if refreshToken == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||
}
|
||||
@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
clientID := account.GetCredential("client_id")
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials builds credentials map from token info
|
||||
@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
func (s *OpenAIOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
|
||||
if proxyID == nil {
|
||||
return "", nil
|
||||
}
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err != nil {
|
||||
return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||
}
|
||||
if proxy == nil {
|
||||
return "", nil
|
||||
}
|
||||
return proxy.URL(), nil
|
||||
}
|
||||
|
||||
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
|
||||
transport := &http.Transport{}
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openaiOAuthClientNoopStub struct{}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
require.Equal(t, "demo@example.com", info.Email)
|
||||
require.Greater(t, info.ExpiresAt, int64(0))
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
_, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing access token")
|
||||
}
|
||||
102
backend/internal/service/openai_oauth_service_state_test.go
Normal file
102
backend/internal/service/openai_oauth_service_state_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openaiOAuthClientStateStub struct {
|
||||
exchangeCalled int32
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
atomic.AddInt32(&s.exchangeCalled, 1)
|
||||
return &openai.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
|
||||
client := &openaiOAuthClientStateStub{}
|
||||
svc := NewOpenAIOAuthService(nil, client)
|
||||
defer svc.Stop()
|
||||
|
||||
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||
State: "expected-state",
|
||||
CodeVerifier: "verifier",
|
||||
RedirectURI: openai.DefaultRedirectURI,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||
SessionID: "sid",
|
||||
Code: "auth-code",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "oauth state is required")
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
|
||||
client := &openaiOAuthClientStateStub{}
|
||||
svc := NewOpenAIOAuthService(nil, client)
|
||||
defer svc.Stop()
|
||||
|
||||
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||
State: "expected-state",
|
||||
CodeVerifier: "verifier",
|
||||
RedirectURI: openai.DefaultRedirectURI,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||
SessionID: "sid",
|
||||
Code: "auth-code",
|
||||
State: "wrong-state",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid oauth state")
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
|
||||
client := &openaiOAuthClientStateStub{}
|
||||
svc := NewOpenAIOAuthService(nil, client)
|
||||
defer svc.Stop()
|
||||
|
||||
svc.sessionStore.Set("sid", &openai.OAuthSession{
|
||||
State: "expected-state",
|
||||
CodeVerifier: "verifier",
|
||||
RedirectURI: openai.DefaultRedirectURI,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
|
||||
SessionID: "sid",
|
||||
Code: "auth-code",
|
||||
State: "expected-state",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "at", info.AccessToken)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
|
||||
|
||||
_, ok := svc.sessionStore.Get("sid")
|
||||
require.False(t, ok)
|
||||
}
|
||||
@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
if account.Platform == PlatformSora {
|
||||
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
||||
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||
refreshFailed = true
|
||||
} else if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
if account.Platform == PlatformSora {
|
||||
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
|
||||
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
||||
refreshFailed = true
|
||||
} else if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
|
||||
@@ -17,12 +17,15 @@ import (
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/crypto/sha3"
|
||||
@@ -34,6 +37,11 @@ const (
|
||||
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
|
||||
)
|
||||
|
||||
var (
|
||||
soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||
soraOAuthTokenURL = "https://auth.openai.com/oauth/token"
|
||||
)
|
||||
|
||||
const (
|
||||
soraPowMaxIteration = 500000
|
||||
)
|
||||
@@ -96,6 +104,7 @@ type SoraClient interface {
|
||||
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
|
||||
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
|
||||
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
|
||||
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
|
||||
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
|
||||
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
|
||||
}
|
||||
@@ -157,26 +166,94 @@ func (e *SoraUpstreamError) Error() string {
|
||||
|
||||
// SoraDirectClient 直连 Sora 实现
|
||||
type SoraDirectClient struct {
|
||||
cfg *config.Config
|
||||
httpUpstream HTTPUpstream
|
||||
tokenProvider *OpenAITokenProvider
|
||||
cfg *config.Config
|
||||
httpUpstream HTTPUpstream
|
||||
tokenProvider *OpenAITokenProvider
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewSoraDirectClient 创建 Sora 直连客户端
|
||||
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
|
||||
baseURL := ""
|
||||
if cfg != nil {
|
||||
rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/")
|
||||
baseURL = normalizeSoraBaseURL(rawBaseURL)
|
||||
if rawBaseURL != "" && baseURL != rawBaseURL {
|
||||
log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL))
|
||||
}
|
||||
}
|
||||
return &SoraDirectClient{
|
||||
cfg: cfg,
|
||||
httpUpstream: httpUpstream,
|
||||
tokenProvider: tokenProvider,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.accountRepo = accountRepo
|
||||
c.soraAccountRepo = soraAccountRepo
|
||||
}
|
||||
|
||||
// Enabled 判断是否启用 Sora 直连
|
||||
func (c *SoraDirectClient) Enabled() bool {
|
||||
if c == nil || c.cfg == nil {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
|
||||
if strings.TrimSpace(c.baseURL) != "" {
|
||||
return true
|
||||
}
|
||||
if c.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != ""
|
||||
}
|
||||
|
||||
// PreflightCheck 在创建任务前执行账号能力预检。
|
||||
// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。
|
||||
func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
|
||||
if modelCfg.Type != "video" {
|
||||
return nil
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||
headers.Set("Accept", "application/json")
|
||||
body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
|
||||
if err != nil {
|
||||
var upstreamErr *SoraUpstreamError
|
||||
if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "当前账号未开通 Sora2 能力或无可用配额",
|
||||
Headers: upstreamErr.Headers,
|
||||
Body: upstreamErr.Body,
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool()
|
||||
remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining")
|
||||
if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) {
|
||||
msg := "当前账号 Sora2 可用配额不足"
|
||||
if requestedModel != "" {
|
||||
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
|
||||
}
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Message: msg,
|
||||
Headers: http.Header{},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
||||
@@ -347,6 +424,45 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(expansionLevel) == "" {
|
||||
expansionLevel = "medium"
|
||||
}
|
||||
if durationS <= 0 {
|
||||
durationS = 10
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"prompt": prompt,
|
||||
"expansion_level": expansionLevel,
|
||||
"duration_s": durationS,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("Accept", "application/json")
|
||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||
|
||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String())
|
||||
if enhancedPrompt == "" {
|
||||
return "", errors.New("enhance_prompt response missing enhanced_prompt")
|
||||
}
|
||||
return enhancedPrompt, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
|
||||
if err != nil {
|
||||
@@ -512,9 +628,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) buildURL(endpoint string) string {
|
||||
base := ""
|
||||
if c != nil && c.cfg != nil {
|
||||
base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/")
|
||||
base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/")
|
||||
if base == "" && c != nil && c.cfg != nil {
|
||||
base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)
|
||||
c.baseURL = base
|
||||
}
|
||||
if base == "" {
|
||||
return endpoint
|
||||
@@ -540,14 +657,257 @@ func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account)
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if c.tokenProvider != nil {
|
||||
return c.tokenProvider.GetAccessToken(ctx, account)
|
||||
|
||||
allowProvider := c.allowOpenAITokenProvider(account)
|
||||
var providerErr error
|
||||
if allowProvider && c.tokenProvider != nil {
|
||||
token, err := c.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err == nil && strings.TrimSpace(token) != "" {
|
||||
c.logTokenSource(account, "openai_token_provider")
|
||||
return token, nil
|
||||
}
|
||||
providerErr = err
|
||||
if err != nil && c.debugEnabled() {
|
||||
c.debugLogf(
|
||||
"token_provider_failed account_id=%d platform=%s err=%s",
|
||||
account.ID,
|
||||
account.Platform,
|
||||
logredact.RedactText(err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if token == "" {
|
||||
return "", errors.New("access_token not found")
|
||||
if token != "" {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
|
||||
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
|
||||
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
|
||||
c.logTokenSource(account, "refresh_token_recovered")
|
||||
return refreshed, nil
|
||||
}
|
||||
if refreshErr != nil && c.debugEnabled() {
|
||||
c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error()))
|
||||
}
|
||||
}
|
||||
c.logTokenSource(account, "account_credentials")
|
||||
return token, nil
|
||||
}
|
||||
return token, nil
|
||||
|
||||
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
|
||||
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
|
||||
c.logTokenSource(account, "session_or_refresh_recovered")
|
||||
return recovered, nil
|
||||
}
|
||||
if recoverErr != nil && c.debugEnabled() {
|
||||
c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error()))
|
||||
}
|
||||
if providerErr != nil {
|
||||
return "", providerErr
|
||||
}
|
||||
if c.tokenProvider != nil && !allowProvider {
|
||||
c.logTokenSource(account, "account_credentials(provider_disabled)")
|
||||
}
|
||||
return "", errors.New("access_token not found")
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
|
||||
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
|
||||
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
|
||||
if err == nil && strings.TrimSpace(accessToken) != "" {
|
||||
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
|
||||
c.logTokenRecover(account, "session_token", reason, true, nil)
|
||||
return accessToken, nil
|
||||
}
|
||||
c.logTokenRecover(account, "session_token", reason, false, err)
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
||||
if refreshToken == "" {
|
||||
return "", errors.New("session_token/refresh_token not found")
|
||||
}
|
||||
accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken)
|
||||
if err != nil {
|
||||
c.logTokenRecover(account, "refresh_token", reason, false, err)
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("refreshed access_token is empty")
|
||||
}
|
||||
c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "")
|
||||
c.logTokenRecover(account, "refresh_token", reason, true, nil)
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
|
||||
headers := http.Header{}
|
||||
headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
|
||||
headers.Set("Accept", "application/json")
|
||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||
headers.Set("User-Agent", c.defaultUserAgent())
|
||||
body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("session exchange missing accessToken")
|
||||
}
|
||||
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
|
||||
return accessToken, expiresAt, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) {
|
||||
clientIDs := []string{
|
||||
strings.TrimSpace(account.GetCredential("client_id")),
|
||||
openaioauth.SoraClientID,
|
||||
openaioauth.ClientID,
|
||||
}
|
||||
tried := make(map[string]struct{}, len(clientIDs))
|
||||
var lastErr error
|
||||
|
||||
for _, clientID := range clientIDs {
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := tried[clientID]; ok {
|
||||
continue
|
||||
}
|
||||
tried[clientID] = struct{}{}
|
||||
|
||||
payload := map[string]any{
|
||||
"client_id": clientID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
|
||||
}
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
headers := http.Header{}
|
||||
headers.Set("Accept", "application/json")
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("User-Agent", c.defaultUserAgent())
|
||||
|
||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
continue
|
||||
}
|
||||
accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String())
|
||||
if accessToken == "" {
|
||||
lastErr = errors.New("oauth refresh response missing access_token")
|
||||
continue
|
||||
}
|
||||
newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String())
|
||||
expiresIn := gjson.GetBytes(respBody, "expires_in").Int()
|
||||
expiresAt := ""
|
||||
if expiresIn > 0 {
|
||||
expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
|
||||
}
|
||||
return accessToken, newRefreshToken, expiresAt, nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return "", "", "", lastErr
|
||||
}
|
||||
return "", "", "", errors.New("no available client_id for refresh_token exchange")
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
if strings.TrimSpace(accessToken) != "" {
|
||||
account.Credentials["access_token"] = accessToken
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) != "" {
|
||||
account.Credentials["refresh_token"] = refreshToken
|
||||
}
|
||||
if strings.TrimSpace(expiresAt) != "" {
|
||||
account.Credentials["expires_at"] = expiresAt
|
||||
}
|
||||
if strings.TrimSpace(sessionToken) != "" {
|
||||
account.Credentials["session_token"] = sessionToken
|
||||
}
|
||||
|
||||
if c.accountRepo != nil {
|
||||
if err := c.accountRepo.Update(ctx, account); err != nil {
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
|
||||
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
updates := make(map[string]any)
|
||||
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
|
||||
updates["access_token"] = accessToken
|
||||
updates["refresh_token"] = refreshToken
|
||||
}
|
||||
if strings.TrimSpace(sessionToken) != "" {
|
||||
updates["session_token"] = sessionToken
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
|
||||
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) {
|
||||
if !c.debugEnabled() || account == nil {
|
||||
return
|
||||
}
|
||||
if success {
|
||||
c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
|
||||
return
|
||||
}
|
||||
c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error()))
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool {
|
||||
if c == nil || c.tokenProvider == nil {
|
||||
return false
|
||||
}
|
||||
if account != nil && account.Platform == PlatformSora {
|
||||
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) logTokenSource(account *Account, source string) {
|
||||
if !c.debugEnabled() || account == nil {
|
||||
return
|
||||
}
|
||||
c.debugLogf(
|
||||
"token_selected account_id=%d platform=%s account_type=%s source=%s",
|
||||
account.ID,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
source,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
|
||||
@@ -600,7 +960,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
||||
}
|
||||
|
||||
attempts := maxRetries + 1
|
||||
authRecovered := false
|
||||
authRecoverExtraAttemptGranted := false
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= attempts; attempt++ {
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf(
|
||||
"request_start method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t headers=%s",
|
||||
method,
|
||||
sanitizeSoraLogURL(urlStr),
|
||||
attempt,
|
||||
attempts,
|
||||
timeout,
|
||||
len(bodyBytes),
|
||||
account != nil && account.ProxyID != nil && account.Proxy != nil,
|
||||
formatSoraHeaders(headers),
|
||||
)
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
if bodyBytes != nil {
|
||||
reader = bytes.NewReader(bodyBytes)
|
||||
@@ -618,7 +995,21 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
||||
}
|
||||
resp, err := c.doHTTP(req, proxyURL, account)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf(
|
||||
"request_transport_error method=%s url=%s attempt=%d/%d err=%s",
|
||||
method,
|
||||
sanitizeSoraLogURL(urlStr),
|
||||
attempt,
|
||||
attempts,
|
||||
logredact.RedactText(err.Error()),
|
||||
)
|
||||
}
|
||||
if attempt < attempts && allowRetry {
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("request_retry_scheduled method=%s url=%s reason=transport_error next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), attempt+1, attempts)
|
||||
}
|
||||
c.sleepRetry(attempt)
|
||||
continue
|
||||
}
|
||||
@@ -632,12 +1023,53 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
||||
}
|
||||
|
||||
if c.cfg != nil && c.cfg.Sora.Client.Debug {
|
||||
log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start))
|
||||
c.debugLogf(
|
||||
"response_received method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s",
|
||||
method,
|
||||
sanitizeSoraLogURL(urlStr),
|
||||
attempt,
|
||||
attempts,
|
||||
resp.StatusCode,
|
||||
time.Since(start),
|
||||
len(respBody),
|
||||
formatSoraHeaders(resp.Header),
|
||||
)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody)
|
||||
if !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil {
|
||||
if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" {
|
||||
headers.Set("Authorization", "Bearer "+recovered)
|
||||
authRecovered = true
|
||||
if attempt == attempts && !authRecoverExtraAttemptGranted {
|
||||
attempts++
|
||||
authRecoverExtraAttemptGranted = true
|
||||
}
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("request_retry_with_recovered_token method=%s url=%s status=%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode)
|
||||
}
|
||||
continue
|
||||
} else if recoverErr != nil && c.debugEnabled() {
|
||||
c.debugLogf("request_recover_token_failed method=%s url=%s status=%d err=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error()))
|
||||
}
|
||||
}
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf(
|
||||
"response_non_success method=%s url=%s attempt=%d/%d status=%d body=%s",
|
||||
method,
|
||||
sanitizeSoraLogURL(urlStr),
|
||||
attempt,
|
||||
attempts,
|
||||
resp.StatusCode,
|
||||
summarizeSoraResponseBody(respBody, 512),
|
||||
)
|
||||
}
|
||||
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr)
|
||||
lastErr = upstreamErr
|
||||
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("request_retry_scheduled method=%s url=%s reason=status_%d next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
|
||||
}
|
||||
c.sleepRetry(attempt)
|
||||
continue
|
||||
}
|
||||
@@ -645,9 +1077,34 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
||||
}
|
||||
return respBody, resp.Header, nil
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
return nil, nil, errors.New("upstream retries exhausted")
|
||||
}
|
||||
|
||||
func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
|
||||
switch statusCode {
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
parsed, err := url.Parse(strings.TrimSpace(rawURL))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||
return false
|
||||
}
|
||||
// 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。
|
||||
path := strings.ToLower(strings.TrimSpace(parsed.Path))
|
||||
if path == "/api/auth/session" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||
enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
|
||||
if c.httpUpstream != nil {
|
||||
@@ -670,9 +1127,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) {
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error {
|
||||
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error {
|
||||
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
msg = sanitizeUpstreamErrorMessage(msg)
|
||||
if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") {
|
||||
if hint := soraBaseURLNotFoundHint(requestURL); hint != "" {
|
||||
msg = strings.TrimSpace(msg + " " + hint)
|
||||
}
|
||||
}
|
||||
if msg == "" {
|
||||
msg = truncateForLog(body, 256)
|
||||
}
|
||||
@@ -684,6 +1146,45 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSoraBaseURL(raw string) string {
|
||||
trimmed := strings.TrimRight(strings.TrimSpace(raw), "/")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return trimmed
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||
return trimmed
|
||||
}
|
||||
pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/")
|
||||
switch pathVal {
|
||||
case "", "/":
|
||||
parsed.Path = "/backend"
|
||||
case "/backend-api":
|
||||
parsed.Path = "/backend"
|
||||
}
|
||||
return strings.TrimRight(parsed.String(), "/")
|
||||
}
|
||||
|
||||
func soraBaseURLNotFoundHint(requestURL string) string {
|
||||
parsed, err := url.Parse(strings.TrimSpace(requestURL))
|
||||
if err != nil || parsed.Host == "" {
|
||||
return ""
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
if host != "sora.chatgpt.com" && host != "chatgpt.com" {
|
||||
return ""
|
||||
}
|
||||
pathVal := strings.TrimSpace(parsed.Path)
|
||||
if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" {
|
||||
return ""
|
||||
}
|
||||
return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||
reqID := uuid.NewString()
|
||||
userAgent := soraRandChoice(soraDesktopUserAgents)
|
||||
@@ -901,3 +1402,70 @@ func sanitizeSoraLogURL(raw string) string {
|
||||
parsed.RawQuery = q.Encode()
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) debugEnabled() bool {
|
||||
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) debugLogf(format string, args ...any) {
|
||||
if !c.debugEnabled() {
|
||||
return
|
||||
}
|
||||
log.Printf("[SoraClient] "+format, args...)
|
||||
}
|
||||
|
||||
func formatSoraHeaders(headers http.Header) string {
|
||||
if len(headers) == 0 {
|
||||
return "{}"
|
||||
}
|
||||
keys := make([]string, 0, len(headers))
|
||||
for key := range headers {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
values := headers.Values(key)
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
val := strings.Join(values, ",")
|
||||
if isSensitiveHeader(key) {
|
||||
out[key] = "***"
|
||||
continue
|
||||
}
|
||||
out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160)
|
||||
}
|
||||
encoded, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
func isSensitiveHeader(key string) bool {
|
||||
k := strings.ToLower(strings.TrimSpace(key))
|
||||
switch k {
|
||||
case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeSoraResponseBody(body []byte, maxLen int) string {
|
||||
if len(body) == 0 {
|
||||
return ""
|
||||
}
|
||||
var text string
|
||||
if json.Valid(body) {
|
||||
text = logredact.RedactJSON(body)
|
||||
} else {
|
||||
text = logredact.RedactText(string(body))
|
||||
}
|
||||
text = strings.TrimSpace(text)
|
||||
if maxLen <= 0 || len(text) <= maxLen {
|
||||
return text
|
||||
}
|
||||
return text[:maxLen] + "...(truncated)"
|
||||
}
|
||||
|
||||
@@ -4,9 +4,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
|
||||
require.Equal(t, "completed", status.Status)
|
||||
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
|
||||
}
|
||||
|
||||
func TestNormalizeSoraBaseURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
raw: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "append_backend_for_sora_host",
|
||||
raw: "https://sora.chatgpt.com",
|
||||
want: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
{
|
||||
name: "convert_backend_api_to_backend",
|
||||
raw: "https://sora.chatgpt.com/backend-api",
|
||||
want: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
{
|
||||
name: "keep_backend",
|
||||
raw: "https://sora.chatgpt.com/backend",
|
||||
want: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
{
|
||||
name: "keep_custom_host",
|
||||
raw: "https://example.com/custom-path",
|
||||
want: "https://example.com/custom-path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeSoraBaseURL(tt.raw)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.chatgpt.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, nil)
|
||||
require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||
|
||||
err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
|
||||
var upstreamErr *SoraUpstreamError
|
||||
require.ErrorAs(t, err, &upstreamErr)
|
||||
require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
|
||||
|
||||
errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
|
||||
require.ErrorAs(t, errNoHint, &upstreamErr)
|
||||
require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
|
||||
}
|
||||
|
||||
func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
|
||||
t.Parallel()
|
||||
headers := http.Header{}
|
||||
headers.Set("Authorization", "Bearer secret-token")
|
||||
headers.Set("openai-sentinel-token", "sentinel-secret")
|
||||
headers.Set("X-Test", "ok")
|
||||
|
||||
out := formatSoraHeaders(headers)
|
||||
require.Contains(t, out, `"Authorization":"***"`)
|
||||
require.Contains(t, out, `Sentinel-Token":"***"`)
|
||||
require.Contains(t, out, `"X-Test":"ok"`)
|
||||
require.NotContains(t, out, "secret-token")
|
||||
require.NotContains(t, out, "sentinel-secret")
|
||||
}
|
||||
|
||||
func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
|
||||
out := summarizeSoraResponseBody(body, 512)
|
||||
require.Contains(t, out, `"access_token":"***"`)
|
||||
require.NotContains(t, out, "abc123")
|
||||
}
|
||||
|
||||
func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(strings.Repeat("x", 100))
|
||||
out := summarizeSoraResponseBody(body, 10)
|
||||
require.Contains(t, out, "(truncated)")
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache := newOpenAITokenCacheStub()
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, provider)
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "sora-credential-token",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := client.getAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "sora-credential-token", token)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache := newOpenAITokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "sora-credential-token",
|
||||
},
|
||||
}
|
||||
cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.chatgpt.com/backend",
|
||||
UseOpenAITokenProvider: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, provider)
|
||||
|
||||
token, err := client.getAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "provider-token", token)
|
||||
require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"accessToken": "session-access-token",
|
||||
"expires": "2099-01-01T00:00:00Z",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := soraSessionAuthURL
|
||||
soraSessionAuthURL = server.URL
|
||||
defer func() { soraSessionAuthURL = origin }()
|
||||
|
||||
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"session_token": "session-token",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := client.getAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "session-access-token", token)
|
||||
require.Equal(t, "session-access-token", account.GetCredential("access_token"))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
require.Equal(t, "/oauth/token", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "refresh-access-token",
|
||||
"refresh_token": "refresh-token-new",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := soraOAuthTokenURL
|
||||
soraOAuthTokenURL = server.URL + "/oauth/token"
|
||||
defer func() { soraOAuthTokenURL = origin }()
|
||||
|
||||
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "refresh-token-old",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := client.getAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refresh-access-token", token)
|
||||
require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
|
||||
require.NotNil(t, account.GetCredentialAsTime("expires_at"))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Equal(t, "/nf/check", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"rate_limit_and_credit_balance": map[string]any{
|
||||
"estimated_num_videos_remaining": 0,
|
||||
"rate_limit_reached": true,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: server.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, nil)
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "ok",
|
||||
"expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
|
||||
require.Error(t, err)
|
||||
var upstreamErr *SoraUpstreamError
|
||||
require.ErrorAs(t, err, &upstreamErr)
|
||||
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
|
||||
}
|
||||
|
||||
func TestShouldAttemptSoraTokenRecover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
|
||||
require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
|
||||
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
|
||||
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
|
||||
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
|
||||
}
|
||||
|
||||
@@ -61,6 +61,10 @@ type SoraGatewayService struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type soraPreflightChecker interface {
|
||||
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
|
||||
}
|
||||
|
||||
func NewSoraGatewayService(
|
||||
soraClient SoraClient,
|
||||
mediaStorage *SoraMediaStorage,
|
||||
@@ -112,11 +116,6 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
|
||||
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
||||
}
|
||||
if modelCfg.Type == "prompt_enhance" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
|
||||
return nil, fmt.Errorf("prompt-enhance not supported")
|
||||
}
|
||||
|
||||
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||
@@ -131,6 +130,41 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if checker, ok := s.soraClient.(soraPreflightChecker); ok {
|
||||
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||
}
|
||||
}
|
||||
|
||||
if modelCfg.Type == "prompt_enhance" {
|
||||
enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
|
||||
if err != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||
}
|
||||
content := strings.TrimSpace(enhancedPrompt)
|
||||
if content == "" {
|
||||
content = prompt
|
||||
}
|
||||
var firstTokenMs *int
|
||||
if clientStream {
|
||||
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||
if streamErr != nil {
|
||||
return nil, streamErr
|
||||
}
|
||||
firstTokenMs = ms
|
||||
} else if c != nil {
|
||||
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var imageData []byte
|
||||
imageFilename := ""
|
||||
@@ -267,7 +301,7 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
|
||||
|
||||
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 402, 403, 429, 529:
|
||||
case 401, 402, 403, 404, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
@@ -460,7 +494,7 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||
}
|
||||
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
||||
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
|
||||
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body}
|
||||
}
|
||||
msg := upstreamErr.Message
|
||||
if override := soraProErrorMessage(model, msg); override != "" {
|
||||
|
||||
@@ -18,6 +18,8 @@ type stubSoraClientForPoll struct {
|
||||
videoStatus *SoraVideoTaskStatus
|
||||
imageCalls int
|
||||
videoCalls int
|
||||
enhanced string
|
||||
enhanceErr error
|
||||
}
|
||||
|
||||
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
||||
@@ -30,6 +32,12 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
|
||||
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
if s.enhanced != "" {
|
||||
return s.enhanced, s.enhanceErr
|
||||
}
|
||||
return "enhanced prompt", s.enhanceErr
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||
s.imageCalls++
|
||||
return s.imageStatus, nil
|
||||
@@ -62,6 +70,33 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
|
||||
require.Equal(t, 1, client.imageCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
enhanced: "cinematic prompt",
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
}
|
||||
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "prompt", result.MediaType)
|
||||
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
@@ -178,6 +213,7 @@ func TestSoraProErrorMessage(t *testing.T) {
|
||||
func TestShouldFailoverUpstreamError(t *testing.T) {
|
||||
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
|
||||
require.True(t, svc.shouldFailoverUpstreamError(401))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(404))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(429))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(500))
|
||||
require.True(t, svc.shouldFailoverUpstreamError(502))
|
||||
|
||||
@@ -17,6 +17,9 @@ type SoraModelConfig struct {
|
||||
Model string
|
||||
Size string
|
||||
RequirePro bool
|
||||
// Prompt-enhance 专用参数
|
||||
ExpansionLevel string
|
||||
DurationS int
|
||||
}
|
||||
|
||||
var soraModelConfigs = map[string]SoraModelConfig{
|
||||
@@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{
|
||||
RequirePro: true,
|
||||
},
|
||||
"prompt-enhance-short-10s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-short-15s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-short-20s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 20,
|
||||
},
|
||||
"prompt-enhance-medium-10s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-medium-15s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-medium-20s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 20,
|
||||
},
|
||||
"prompt-enhance-long-10s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-long-15s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-long-20s": {
|
||||
Type: "prompt_enhance",
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 20,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -43,10 +43,13 @@ func NewTokenRefreshService(
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
||||
|
||||
// 注册平台特定的刷新器
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
NewOpenAITokenRefresher(openaiOAuthService, accountRepo),
|
||||
openAIRefresher,
|
||||
NewGeminiTokenRefresher(geminiOAuthService),
|
||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
||||
}
|
||||
|
||||
@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
||||
syncLinkedSora bool
|
||||
}
|
||||
|
||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||
@@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||
r.soraAccountRepo = repo
|
||||
}
|
||||
|
||||
// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
|
||||
func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
|
||||
r.syncLinkedSora = enabled
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 openai 平台的 oauth 类型账号
|
||||
// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
|
||||
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||
return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
|
||||
account.Type == AccountTypeOAuth
|
||||
return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
@@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
||||
}
|
||||
|
||||
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
||||
if r.accountRepo != nil {
|
||||
if r.accountRepo != nil && r.syncLinkedSora {
|
||||
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
||||
}
|
||||
|
||||
|
||||
@@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
|
||||
refresher := &OpenAITokenRefresher{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
accType string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "openai oauth - can refresh",
|
||||
platform: PlatformOpenAI,
|
||||
accType: AccountTypeOAuth,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "sora oauth - cannot refresh directly",
|
||||
platform: PlatformSora,
|
||||
accType: AccountTypeOAuth,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "openai apikey - cannot refresh",
|
||||
platform: PlatformOpenAI,
|
||||
accType: AccountTypeAPIKey,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: tt.platform,
|
||||
Type: tt.accType,
|
||||
}
|
||||
require.Equal(t, tt.want, refresher.CanRefresh(account))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||
return NewSoraMediaStorage(cfg)
|
||||
}
|
||||
|
||||
func ProvideSoraDirectClient(
|
||||
cfg *config.Config,
|
||||
httpUpstream HTTPUpstream,
|
||||
tokenProvider *OpenAITokenProvider,
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
) *SoraDirectClient {
|
||||
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
|
||||
client.SetAccountRepositories(accountRepo, soraAccountRepo)
|
||||
return client
|
||||
}
|
||||
|
||||
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
|
||||
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
||||
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||
@@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewGatewayService,
|
||||
ProvideSoraMediaStorage,
|
||||
ProvideSoraMediaCleanupService,
|
||||
NewSoraDirectClient,
|
||||
ProvideSoraDirectClient,
|
||||
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
||||
NewSoraGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
|
||||
Reference in New Issue
Block a user