feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径
- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力 - 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程 - 强化 Sora token 恢复、转发日志与网关路由隔离行为 - 补充后端服务层与路由层相关测试覆盖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,9 +4,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
|
||||
require.Equal(t, "completed", status.Status)
|
||||
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
|
||||
}
|
||||
|
||||
func TestNormalizeSoraBaseURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
raw: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "append_backend_for_sora_host",
|
||||
raw: "https://sora.chatgpt.com",
|
||||
want: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
{
|
||||
name: "convert_backend_api_to_backend",
|
||||
raw: "https://sora.chatgpt.com/backend-api",
|
||||
want: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
{
|
||||
name: "keep_backend",
|
||||
raw: "https://sora.chatgpt.com/backend",
|
||||
want: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
{
|
||||
name: "keep_custom_host",
|
||||
raw: "https://example.com/custom-path",
|
||||
want: "https://example.com/custom-path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeSoraBaseURL(tt.raw)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.chatgpt.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, nil)
|
||||
require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||
|
||||
err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
|
||||
var upstreamErr *SoraUpstreamError
|
||||
require.ErrorAs(t, err, &upstreamErr)
|
||||
require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
|
||||
|
||||
errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
|
||||
require.ErrorAs(t, errNoHint, &upstreamErr)
|
||||
require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
|
||||
}
|
||||
|
||||
func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
|
||||
t.Parallel()
|
||||
headers := http.Header{}
|
||||
headers.Set("Authorization", "Bearer secret-token")
|
||||
headers.Set("openai-sentinel-token", "sentinel-secret")
|
||||
headers.Set("X-Test", "ok")
|
||||
|
||||
out := formatSoraHeaders(headers)
|
||||
require.Contains(t, out, `"Authorization":"***"`)
|
||||
require.Contains(t, out, `Sentinel-Token":"***"`)
|
||||
require.Contains(t, out, `"X-Test":"ok"`)
|
||||
require.NotContains(t, out, "secret-token")
|
||||
require.NotContains(t, out, "sentinel-secret")
|
||||
}
|
||||
|
||||
func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
|
||||
out := summarizeSoraResponseBody(body, 512)
|
||||
require.Contains(t, out, `"access_token":"***"`)
|
||||
require.NotContains(t, out, "abc123")
|
||||
}
|
||||
|
||||
func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(strings.Repeat("x", 100))
|
||||
out := summarizeSoraResponseBody(body, 10)
|
||||
require.Contains(t, out, "(truncated)")
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache := newOpenAITokenCacheStub()
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, provider)
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "sora-credential-token",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := client.getAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "sora-credential-token", token)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache := newOpenAITokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "sora-credential-token",
|
||||
},
|
||||
}
|
||||
cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.chatgpt.com/backend",
|
||||
UseOpenAITokenProvider: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, provider)
|
||||
|
||||
token, err := client.getAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "provider-token", token)
|
||||
require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"accessToken": "session-access-token",
|
||||
"expires": "2099-01-01T00:00:00Z",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := soraSessionAuthURL
|
||||
soraSessionAuthURL = server.URL
|
||||
defer func() { soraSessionAuthURL = origin }()
|
||||
|
||||
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"session_token": "session-token",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := client.getAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "session-access-token", token)
|
||||
require.Equal(t, "session-access-token", account.GetCredential("access_token"))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
require.Equal(t, "/oauth/token", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "refresh-access-token",
|
||||
"refresh_token": "refresh-token-new",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := soraOAuthTokenURL
|
||||
soraOAuthTokenURL = server.URL + "/oauth/token"
|
||||
defer func() { soraOAuthTokenURL = origin }()
|
||||
|
||||
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "refresh-token-old",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := client.getAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refresh-access-token", token)
|
||||
require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
|
||||
require.NotNil(t, account.GetCredentialAsTime("expires_at"))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Equal(t, "/nf/check", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"rate_limit_and_credit_balance": map[string]any{
|
||||
"estimated_num_videos_remaining": 0,
|
||||
"rate_limit_reached": true,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: server.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, nil)
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "ok",
|
||||
"expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
|
||||
require.Error(t, err)
|
||||
var upstreamErr *SoraUpstreamError
|
||||
require.ErrorAs(t, err, &upstreamErr)
|
||||
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
|
||||
}
|
||||
|
||||
func TestShouldAttemptSoraTokenRecover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
|
||||
require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
|
||||
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
|
||||
require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
|
||||
require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user