Files
sub2api/backend/internal/service/sora_client_test.go
yangjianbo 900cce20a1 feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径
- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力
- 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程
- 强化 Sora token 恢复、转发日志与网关路由隔离行为
- 补充后端服务层与路由层相关测试覆盖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 08:02:56 +08:00

362 lines
11 KiB
Go

//go:build unit
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestSoraDirectClient_DoRequestSuccess checks that doRequest returns the
// upstream body and no error when a stub server replies 200 with JSON.
func TestSoraDirectClient_DoRequestSuccess(t *testing.T) {
	okHandler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"ok":true}`))
	}
	upstream := httptest.NewServer(http.HandlerFunc(okHandler))
	defer upstream.Close()

	cfg := &config.Config{}
	cfg.Sora.Client.BaseURL = upstream.URL
	client := NewSoraDirectClient(cfg, nil, nil)

	account := &Account{ID: 1}
	respBody, _, err := client.doRequest(context.Background(), account, http.MethodGet, upstream.URL, http.Header{}, nil, false)
	require.NoError(t, err)
	require.Contains(t, string(respBody), "ok")
}
// TestSoraDirectClient_BuildBaseHeaders checks the result of buildBaseHeaders
// when extra headers are configured: Authorization and User-Agent reflect the
// call arguments, the custom X-Test header is present, and the configured
// openai-sentinel-token header is absent.
func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) {
	configured := map[string]string{
		"X-Test":                "yes",
		"Authorization":         "should-ignore",
		"openai-sentinel-token": "skip",
	}
	cfg := &config.Config{}
	cfg.Sora.Client.Headers = configured

	got := NewSoraDirectClient(cfg, nil, nil).buildBaseHeaders("token-123", "UA")

	require.Equal(t, "Bearer token-123", got.Get("Authorization"))
	require.Equal(t, "UA", got.Get("User-Agent"))
	require.Equal(t, "yes", got.Get("X-Test"))
	require.Empty(t, got.Get("openai-sentinel-token"))
}
// TestSoraDirectClient_GetImageTaskFallbackLimit checks that GetImageTask
// still resolves task-1 when the stub server only lists it for limit=2 while
// the client is configured with RecentTaskLimit 1 and RecentTaskLimitMax 2.
// NOTE(review): presumably the client retries with the larger limit — confirm
// against GetImageTask.
func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
	const (
		emptyList = `{"task_responses":[]}`
		withTask  = `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1,"generations":[{"url":"https://example.com/a.png"}]}]}`
	)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Only limit=2 yields the task; every other limit gets an empty list.
		payload := emptyList
		if r.URL.Query().Get("limit") == "2" {
			payload = withTask
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(payload))
	}))
	defer upstream.Close()

	cfg := &config.Config{}
	cfg.Sora.Client.BaseURL = upstream.URL
	cfg.Sora.Client.RecentTaskLimit = 1
	cfg.Sora.Client.RecentTaskLimitMax = 2
	client := NewSoraDirectClient(cfg, nil, nil)

	creds := map[string]any{"access_token": "token"}
	result, err := client.GetImageTask(context.Background(), &Account{Credentials: creds}, "task-1")
	require.NoError(t, err)
	require.Equal(t, "completed", result.Status)
	require.Equal(t, []string{"https://example.com/a.png"}, result.URLs)
}
func TestNormalizeSoraBaseURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw string
want string
}{
{
name: "empty",
raw: "",
want: "",
},
{
name: "append_backend_for_sora_host",
raw: "https://sora.chatgpt.com",
want: "https://sora.chatgpt.com/backend",
},
{
name: "convert_backend_api_to_backend",
raw: "https://sora.chatgpt.com/backend-api",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_backend",
raw: "https://sora.chatgpt.com/backend",
want: "https://sora.chatgpt.com/backend",
},
{
name: "keep_custom_host",
raw: "https://example.com/custom-path",
want: "https://example.com/custom-path",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeSoraBaseURL(tt.raw)
require.Equal(t, tt.want, got)
})
}
}
// TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL checks that buildURL
// yields a URL under /backend when the configured base URL is the bare
// sora.chatgpt.com host.
func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
	t.Parallel()

	cfg := &config.Config{}
	cfg.Sora.Client.BaseURL = "https://sora.chatgpt.com"

	built := NewSoraDirectClient(cfg, nil, nil).buildURL("/video_gen")
	require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", built)
}
// TestSoraDirectClient_BuildUpstreamError_NotFoundHint checks that a 404 for
// a request URL without the /backend segment carries the base_url hint in its
// message, while a 404 for a URL that includes /backend does not.
func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
	t.Parallel()

	const hint = "请检查 sora.client.base_url"
	notFoundBody := []byte(`{"error":{"message":"Not found"}}`)
	client := NewSoraDirectClient(&config.Config{}, nil, nil)

	// URL without /backend: the hint is expected.
	var hinted *SoraUpstreamError
	withHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, notFoundBody, "https://sora.chatgpt.com/video_gen")
	require.ErrorAs(t, withHint, &hinted)
	require.Contains(t, hinted.Message, hint)

	// URL already under /backend: no hint.
	var plain *SoraUpstreamError
	withoutHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, notFoundBody, "https://sora.chatgpt.com/backend/video_gen")
	require.ErrorAs(t, withoutHint, &plain)
	require.NotContains(t, plain.Message, hint)
}
// TestFormatSoraHeaders_RedactsSensitive checks that formatSoraHeaders masks
// the Authorization and sentinel-token values with "***", leaves X-Test
// readable, and never emits the raw secret values.
func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
	t.Parallel()

	input := make(http.Header)
	for _, kv := range [][2]string{
		{"Authorization", "Bearer secret-token"},
		{"openai-sentinel-token", "sentinel-secret"},
		{"X-Test", "ok"},
	} {
		input.Set(kv[0], kv[1])
	}

	formatted := formatSoraHeaders(input)

	for _, fragment := range []string{
		`"Authorization":"***"`,
		`Sentinel-Token":"***"`,
		`"X-Test":"ok"`,
	} {
		require.Contains(t, formatted, fragment)
	}
	for _, secret := range []string{"secret-token", "sentinel-secret"} {
		require.NotContains(t, formatted, secret)
	}
}
// TestSummarizeSoraResponseBody_RedactsJSON checks that the summary of a JSON
// body replaces the access_token value with "***" and omits the raw token.
func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
	t.Parallel()

	const payload = `{"error":{"message":"bad"},"access_token":"abc123"}`
	summary := summarizeSoraResponseBody([]byte(payload), 512)

	require.Contains(t, summary, `"access_token":"***"`)
	require.NotContains(t, summary, "abc123")
}
func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
t.Parallel()
body := []byte(strings.Repeat("x", 100))
out := summarizeSoraResponseBody(body, 10)
require.Contains(t, out, "(truncated)")
}
// TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials checks that,
// without UseOpenAITokenProvider set, getAccessToken returns the access_token
// stored in the account credentials and the cache stub's getCalled counter
// stays at zero.
func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
	t.Parallel()

	tokenCache := newOpenAITokenCacheStub()
	tokenProvider := NewOpenAITokenProvider(nil, tokenCache, nil)

	cfg := &config.Config{}
	cfg.Sora.Client.BaseURL = "https://sora.chatgpt.com/backend"
	client := NewSoraDirectClient(cfg, nil, tokenProvider)

	creds := map[string]any{"access_token": "sora-credential-token"}
	account := &Account{
		ID:          1,
		Platform:    PlatformSora,
		Type:        AccountTypeOAuth,
		Credentials: creds,
	}

	got, err := client.getAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "sora-credential-token", got)
	require.Equal(t, int32(0), atomic.LoadInt32(&tokenCache.getCalled))
}
func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
t.Parallel()
cache := newOpenAITokenCacheStub()
account := &Account{
ID: 2,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "sora-credential-token",
},
}
cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
provider := NewOpenAITokenProvider(nil, cache, nil)
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
UseOpenAITokenProvider: true,
},
},
}
client := NewSoraDirectClient(cfg, nil, provider)
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "provider-token", token)
require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
}
func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"accessToken": "session-access-token",
"expires": "2099-01-01T00:00:00Z",
})
}))
defer server.Close()
origin := soraSessionAuthURL
soraSessionAuthURL = server.URL
defer func() { soraSessionAuthURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 10,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"session_token": "session-token",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "session-access-token", token)
require.Equal(t, "session-access-token", account.GetCredential("access_token"))
}
func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/oauth/token", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "refresh-access-token",
"refresh_token": "refresh-token-new",
"expires_in": 3600,
})
}))
defer server.Close()
origin := soraOAuthTokenURL
soraOAuthTokenURL = server.URL + "/oauth/token"
defer func() { soraOAuthTokenURL = origin }()
client := NewSoraDirectClient(&config.Config{}, nil, nil)
account := &Account{
ID: 11,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "refresh-token-old",
},
}
token, err := client.getAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refresh-access-token", token)
require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
require.NotNil(t, account.GetCredentialAsTime("expires_at"))
}
func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Equal(t, "/nf/check", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"rate_limit_and_credit_balance": map[string]any{
"estimated_num_videos_remaining": 0,
"rate_limit_reached": true,
},
})
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{
ID: 12,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ok",
"expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
},
}
err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
require.Error(t, err)
var upstreamErr *SoraUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
}
// TestShouldAttemptSoraTokenRecover checks shouldAttemptSoraTokenRecover
// against a table of status/URL pairs: 401 and 403 on backend URLs expect
// true, while 401 on the session or OAuth token endpoints and 429 on a
// backend URL expect false.
func TestShouldAttemptSoraTokenRecover(t *testing.T) {
	t.Parallel()

	cases := []struct {
		status int
		rawURL string
		want   bool
	}{
		{http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen", true},
		{http.StatusForbidden, "https://chatgpt.com/backend/video_gen", true},
		{http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session", false},
		{http.StatusUnauthorized, "https://auth.openai.com/oauth/token", false},
		{http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen", false},
	}

	for _, tc := range cases {
		require.Equal(t, tc.want, shouldAttemptSoraTokenRecover(tc.status, tc.rawURL))
	}
}