From fe1d46a8eaddfcc85a4237d1140c38255cb623d8 Mon Sep 17 00:00:00 2001 From: kyx236 Date: Thu, 12 Feb 2026 03:47:06 +0800 Subject: [PATCH 01/28] feat(admin): Add group filtering for account listings - Add groupID parameter to ListAccounts and ListWithFilters methods - Implement account filtering by group ID in repository query - Add group query parameter parsing in account handler - Update all ListAccounts/ListWithFilters call sites with groupID parameter - Add group filter UI component to AccountTableFilters - Add i18n translations for group filter label in English and Chinese - Update API contract and test stubs to reflect new signature - Enable filtering accounts by their assigned groups in admin panel --- backend/internal/handler/admin/account_data.go | 2 +- backend/internal/handler/admin/account_handler.go | 9 +++++++-- .../internal/handler/admin/admin_service_stub_test.go | 2 +- backend/internal/repository/account_repo.go | 7 +++++-- .../internal/repository/account_repo_integration_test.go | 4 ++-- backend/internal/server/api_contract_test.go | 2 +- backend/internal/service/account_service.go | 2 +- backend/internal/service/account_service_delete_test.go | 2 +- backend/internal/service/admin_service.go | 6 +++--- backend/internal/service/admin_service_search_test.go | 4 ++-- backend/internal/service/gateway_multiplatform_test.go | 2 +- backend/internal/service/gemini_multiplatform_test.go | 2 +- backend/internal/service/ops_concurrency.go | 2 +- frontend/src/api/admin/accounts.ts | 1 + frontend/src/components/account/AccountGroupsCell.vue | 2 +- .../src/components/admin/account/AccountTableFilters.vue | 7 ++++++- frontend/src/i18n/locales/en.ts | 3 ++- frontend/src/i18n/locales/zh.ts | 3 ++- frontend/src/views/admin/AccountsView.vue | 3 ++- 19 files changed, 41 insertions(+), 24 deletions(-) diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index b5d1dd0a..34397696 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search) + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0) if err != nil { return nil, err } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 85400c6f..0fae04ac 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -156,7 +156,12 @@ func (h *AccountHandler) List(c *gin.Context) { search = search[:100] } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) + var groupID int64 + if groupIDStr := c.Query("group"); groupIDStr != "" { + groupID, _ = strconv.ParseInt(groupIDStr, 10, 64) + } + + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID) if err != nil { response.ErrorFrom(c, err) return @@ -1429,7 +1434,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "") + allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index cbbfe942..d44c99ea 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p return s.apiKeys, int64(len(s.apiKeys)), nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d28ae042..e3e70213 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", "", "") + return r.ListWithFilters(ctx, params, "", "", "", "", 0) } -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { @@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati if search != "" { q = q.Where(dbaccount.NameContainsFold(search)) } + if groupID > 0 { + q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID))) + } total, err := q.Count(ctx) if err != nil { diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index a054b6d6..4f9d0152 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { tt.setup(client) - accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search) + accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0) s.Require().NoError(err) s.Require().Len(accounts, tt.wantCount) if tt.validate != nil { @@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Equal(group.ID, got.Groups[0].ID) - accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc") + accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0) s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(accounts, 1) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index fa6806ae..e6bb4f53 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -936,7 +936,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination return nil, nil, errors.New("not implemented") } -func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 6c0cca31..f192fba4 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -32,7 +32,7 @@ type AccountRepository interface { Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 25bd0576..a420d46b 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -75,7 +75,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination panic("unexpected List call") } -func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 06354e1e..1f6e91e5 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -39,7 +39,7 @@ type AdminService interface { UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -1021,9 +1021,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates [] } // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) if err != nil { return nil, 0, err } diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index d661b710..ff58fd01 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct { listWithFiltersErr error } -func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { s.listWithFiltersCalls++ s.listWithFiltersParams = params s.listWithFiltersPlatform = platform @@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0) require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index b4b93ace..09fda60e 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -87,7 +87,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 080352ba..6b1fcecc 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -74,7 +74,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index f6541d08..92b37e73 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ Page: page, PageSize: opsAccountsPageSize, - }, platformFilter, "", "", "") + }, platformFilter, "", "", "", 0) if err != nil { return nil, err } diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 4cb1a6f2..e1299595 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -32,6 +32,7 @@ export async function list( platform?: string type?: string status?: string + group?: string search?: string }, options?: { diff --git a/frontend/src/components/account/AccountGroupsCell.vue b/frontend/src/components/account/AccountGroupsCell.vue index 512383a5..37771275 100644 --- a/frontend/src/components/account/AccountGroupsCell.vue +++ b/frontend/src/components/account/AccountGroupsCell.vue @@ -41,7 +41,7 @@ >
- {{ t('admin.accounts.allGroups', { count: groups.length }) }} + {{ t('admin.accounts.groupCountTotal', { count: groups.length }) }}
+
@@ -2146,6 +2186,8 @@ const maxSessions = ref(null) const sessionIdleTimeout = ref(null) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) +const cacheTTLOverrideEnabled = ref(false) +const cacheTTLOverrideTarget = ref('5m') // Gemini tier selection (used as fallback when auto-detection is unavailable/fails) const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') @@ -2597,6 +2639,8 @@ const resetForm = () => { sessionIdleTimeout.value = null tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false + cacheTTLOverrideEnabled.value = false + cacheTTLOverrideTarget.value = '5m' antigravityAccountType.value = 'oauth' upstreamBaseUrl.value = '' upstreamApiKey.value = '' @@ -3174,6 +3218,12 @@ const handleAnthropicExchange = async (authCode: string) => { extra.session_id_masking_enabled = true } + // Add cache TTL override settings + if (cacheTTLOverrideEnabled.value) { + extra.cache_ttl_override_enabled = true + extra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } + const credentials = { ...tokenInfo, ...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {}) @@ -3267,6 +3317,12 @@ const handleCookieAuth = async (sessionKey: string) => { extra.session_id_masking_enabled = true } + // Add cache TTL override settings + if (cacheTTLOverrideEnabled.value) { + extra.cache_ttl_override_enabled = true + extra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } + const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name // Merge interceptWarmupRequests into credentials diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 60575f56..ed243276 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -904,6 +904,46 @@
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.hint') }} +

+
+ +
+
+ + +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.targetHint') }} +

+
+
@@ -1102,6 +1142,8 @@ const maxSessions = ref(null) const sessionIdleTimeout = ref(null) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) +const cacheTTLOverrideEnabled = ref(false) +const cacheTTLOverrideTarget = ref('5m') // Computed: current preset mappings based on platform const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic')) @@ -1489,6 +1531,8 @@ function loadQuotaControlSettings(account: Account) { sessionIdleTimeout.value = null tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false + cacheTTLOverrideEnabled.value = false + cacheTTLOverrideTarget.value = '5m' // Only applies to Anthropic OAuth/SetupToken accounts if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { @@ -1517,6 +1561,12 @@ function loadQuotaControlSettings(account: Account) { if (account.session_id_masking_enabled === true) { sessionIdMaskingEnabled.value = true } + + // Load cache TTL override setting + if (account.cache_ttl_override_enabled === true) { + cacheTTLOverrideEnabled.value = true + cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m' + } } function formatTempUnschedKeywords(value: unknown) { @@ -1723,6 +1773,15 @@ const handleSubmit = async () => { delete newExtra.session_id_masking_enabled } + // Cache TTL override setting + if (cacheTTLOverrideEnabled.value) { + newExtra.cache_ttl_override_enabled = true + newExtra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } else { + delete newExtra.cache_ttl_override_enabled + delete newExtra.cache_ttl_override_target + } + updatePayload.extra = newExtra } diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 2c39807e..a6420f1c 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -71,6 +71,7 @@ {{ formatCacheTokens(row.cache_creation_tokens) }} 1h + R
@@ -182,6 +183,13 @@ {{ tokenTooltipData.cache_creation_tokens.toLocaleString() }} +
+ + {{ t('usage.cacheTtlOverriddenLabel') }} + R-{{ tokenTooltipData.cache_creation_1h_tokens > 0 ? '5m' : '1H' }} + + {{ tokenTooltipData.cache_creation_1h_tokens > 0 ? t('usage.cacheTtlOverridden1h') : t('usage.cacheTtlOverridden5m') }} +
{{ t('admin.usage.cacheReadTokens') }} {{ tokenTooltipData.cache_read_tokens.toLocaleString() }} diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 5e00c3b2..90515478 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -576,6 +576,10 @@ export default { description: 'View and analyze your API usage history', costDetails: 'Cost Breakdown', tokenDetails: 'Token Breakdown', + cacheTtlOverriddenHint: 'Cache TTL Override enabled', + cacheTtlOverriddenLabel: 'TTL Override', + cacheTtlOverridden5m: 'Billed as 5m', + cacheTtlOverridden1h: 'Billed as 1h', totalRequests: 'Total Requests', totalTokens: 'Total Tokens', totalCost: 'Total Cost', @@ -1595,6 +1599,12 @@ export default { sessionIdMasking: { label: 'Session ID Masking', hint: 'When enabled, fixes the session ID in metadata.user_id for 15 minutes, making upstream think requests come from the same session' + }, + cacheTTLOverride: { + label: 'Cache TTL Override', + hint: 'Force all cache creation tokens to be billed as the selected TTL tier (5m or 1h)', + target: 'Target TTL', + targetHint: 'Select the TTL tier for billing' } }, expired: 'Expired', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index a78de03e..78bdfaa2 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -582,6 +582,10 @@ export default { description: '查看和分析您的 API 使用历史', costDetails: '成本明细', tokenDetails: 'Token 明细', + cacheTtlOverriddenHint: '缓存 TTL Override 已启用', + cacheTtlOverriddenLabel: 'TTL 替换', + cacheTtlOverridden5m: '按 5m 计费', + cacheTtlOverridden1h: '按 1h 计费', totalRequests: '总请求数', totalTokens: '总 Token', totalCost: '总消费', @@ -1741,6 +1745,12 @@ export default { sessionIdMasking: { label: '会话 ID 伪装', hint: '启用后将在 15 分钟内固定 metadata.user_id 中的 session ID,使上游认为请求来自同一会话' + }, + cacheTTLOverride: { + label: '缓存 TTL 强制替换', + hint: '将所有缓存创建 token 强制按指定的 TTL 类型(5分钟或1小时)计费', + target: '目标 TTL', + targetHint: '选择计费使用的 TTL 类型' } }, expired: '已过期', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 189d8af1..bed331b3 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -614,6 +614,10 @@ export interface Account { // 启用后将在15分钟内固定 metadata.user_id 中的 session ID session_id_masking_enabled?: boolean | null + // 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效) + cache_ttl_override_enabled?: boolean | null + cache_ttl_override_target?: string | null + // 运行时状态(仅当启用对应限制时返回) current_window_cost?: number | null // 当前窗口费用 active_sessions?: number | null // 当前活跃会话数 @@ -827,6 +831,9 @@ export interface UsageLog { // User-Agent user_agent: string | null + // Cache TTL Override + cache_ttl_overridden: boolean + created_at: string user?: User diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue index eb827d7a..73ea0897 100644 --- a/frontend/src/views/user/UsageView.vue +++ b/frontend/src/views/user/UsageView.vue @@ -234,6 +234,7 @@ formatCacheTokens(row.cache_creation_tokens) }} 1h + R
@@ -375,6 +376,13 @@ {{ tokenTooltipData.cache_creation_tokens.toLocaleString() }} +
+ + {{ t('usage.cacheTtlOverriddenLabel') }} + R-{{ tokenTooltipData.cache_creation_1h_tokens > 0 ? '5m' : '1H' }} + + {{ tokenTooltipData.cache_creation_1h_tokens > 0 ? t('usage.cacheTtlOverridden1h') : t('usage.cacheTtlOverridden5m') }} +
{{ t('admin.usage.cacheReadTokens') }} {{ tokenTooltipData.cache_read_tokens.toLocaleString() }} From b41fa5e15f14332784213cdd33aaaf09c8f633cc Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 18 Feb 2026 17:06:37 +0800 Subject: [PATCH 14/28] =?UTF-8?q?feat:=20=E5=89=8D=E7=AB=AF=E6=96=B0?= =?UTF-8?q?=E5=A2=9Esonnet4.6=E5=BF=AB=E6=8D=B7=E6=98=A0=E5=B0=84=E6=8C=89?= =?UTF-8?q?=E9=92=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/pkg/claude/constants.go | 6 ++++++ frontend/src/components/account/BulkEditAccountModal.vue | 8 ++++++++ frontend/src/composables/useModelWhitelist.ts | 2 ++ 3 files changed, 16 insertions(+) diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index eecee11e..9775bf7b 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -77,6 +77,12 @@ var DefaultModels = []Model{ DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-06T00:00:00Z", }, + { + ID: "claude-sonnet-4-6", + Type: "model", + DisplayName: "Claude Sonnet 4.6", + CreatedAt: "2026-02-18T00:00:00Z", + }, { ID: "claude-sonnet-4-5-20250929", Type: "model", diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 912eabb3..67de5697 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -708,6 +708,7 @@ const groupIds = ref([]) // All models list (combined Anthropic + OpenAI) const allModels = [ { value: 'claude-opus-4-6', label: 'Claude Opus 4.6' }, + { value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' }, { value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' }, { value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' }, { value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' }, @@ -754,6 +755,13 @@ const presetMappings = [ color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, + { + label: 'Sonnet 4.6', + from: 'claude-sonnet-4-6', + to: 'claude-sonnet-4-6', + color: + 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' + }, { label: 'Opus->Sonnet', from: 'claude-opus-4-5-20251101', diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 0ef80431..98c668f0 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -39,6 +39,7 @@ export const claudeModels = [ 'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001', 'claude-opus-4-5-20251101', 'claude-opus-4-6', + 'claude-sonnet-4-6', 'claude-2.1', 'claude-2.0', 'claude-instant-1.2' ] @@ -233,6 +234,7 @@ export const allModels = allModelsList.map(m => ({ value: m, label: m })) const anthropicPresetMappings = [ { label: 'Sonnet 4', from: 'claude-sonnet-4-20250514', to: 'claude-sonnet-4-20250514', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' }, { label: 'Sonnet 4.5', from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4-5-20250929', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, + { label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, { label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Haiku 3.5', from: 'claude-3-5-haiku-20241022', to: 'claude-3-5-haiku-20241022', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' }, From 074bd0dfdace3f523125e08609bc7c9ee3f6134d Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 18 Feb 2026 18:41:30 +0800 Subject: [PATCH 15/28] =?UTF-8?q?fix:=20=E4=B8=B4=E6=97=B6=E7=A7=BB?= =?UTF-8?q?=E9=99=A4context-1m-2025-08-07=E4=BB=A5=E7=A1=AE=E4=BF=9D?= =?UTF-8?q?=E9=81=BF=E5=85=8Dsonnet1m=E8=A7=A6=E5=8F=91429?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/pkg/claude/constants.go | 1 + backend/internal/service/gateway_beta_test.go | 69 +++++++++++++++++++ backend/internal/service/gateway_service.go | 26 +++++-- 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 9775bf7b..423ad925 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -10,6 +10,7 @@ const ( BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaTokenCounting = "token-counting-2024-11-01" + BetaContext1M = "context-1m-2025-08-07" ) // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index dd58c183..d7108c8d 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -21,3 +21,72 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) { ) require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got) } + +func TestStripBetaToken(t *testing.T) { + tests := []struct { + name string + header string + token string + want string + }{ + { + name: "token in middle", + header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token at start", + header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token at end", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token not present", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "empty header", + header: "", + token: "context-1m-2025-08-07", + want: "", + }, + { + name: "with spaces", + header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "only token", + header: "context-1m-2025-08-07", + token: "context-1m-2025-08-07", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripBetaToken(tt.header, tt.token) + require.Equal(t, tt.want, got) + }) + } +} + +func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) { + required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} + incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20" + drop := map[string]struct{}{"context-1m-2025-08-07": {}} + + got := mergeAnthropicBetaDropping(required, incoming, drop) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) + require.NotContains(t, got, "context-1m-2025-08-07") +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index ea803215..4d1dbad0 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3553,12 +3553,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - drop := map[string]struct{}{claude.BetaClaudeCode: {}} + drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}} req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) + req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M)) } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) @@ -3712,6 +3712,23 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str return strings.Join(out, ",") } +// stripBetaToken removes a single beta token from a comma-separated header value. +// It short-circuits when the token is not present to avoid unnecessary allocations. +func stripBetaToken(header, token string) string { + if !strings.Contains(header, token) { + return header + } + out := make([]string, 0, 8) + for _, p := range strings.Split(header, ",") { + p = strings.TrimSpace(p) + if p == "" || p == token { + continue + } + out = append(out, p) + } + return strings.Join(out, ",") +} + // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // This mirrors opencode-anthropic-auth behavior: do not trust downstream // headers when using Claude Code-scoped OAuth credentials. @@ -5236,7 +5253,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con incomingBeta := req.Header.Get("anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) + drop := map[string]struct{}{claude.BetaContext1M: {}} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { clientBetaHeader := req.Header.Get("anthropic-beta") if clientBetaHeader == "" { @@ -5246,7 +5264,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", beta) + req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M)) } } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { From 36bb327024525f8dfb7139665f30f680f2ab3e4f Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 18 Feb 2026 20:52:35 +0800 Subject: [PATCH 16/28] =?UTF-8?q?fix:=20=E6=9B=B4=E6=96=B0=20ListWithFilte?= =?UTF-8?q?rs=20=E6=96=B9=E6=B3=95=E4=BB=A5=E6=94=AF=E6=8C=81=20groupID=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/handler/sora_gateway_handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 3cae5cdd..04a58e49 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -88,7 +88,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { From 900cce20a1d5453aecfa3a2d59bcd662fa3ec9f7 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 19 Feb 2026 08:02:56 +0800 Subject: [PATCH 17/28] =?UTF-8?q?feat(sora):=20=E5=AF=B9=E9=BD=90=20Sora?= =?UTF-8?q?=20OAuth=20=E6=B5=81=E7=A8=8B=E5=B9=B6=E9=9A=94=E7=A6=BB?= =?UTF-8?q?=E7=BD=91=E5=85=B3=E8=AF=B7=E6=B1=82=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力 - 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程 - 强化 Sora token 恢复、转发日志与网关路由隔离行为 - 补充后端服务层与路由层相关测试覆盖 Co-Authored-By: Claude Opus 4.6 --- backend/internal/config/config.go | 27 +- .../internal/handler/admin/account_handler.go | 6 + .../handler/admin/openai_oauth_handler.go | 79 ++- .../internal/handler/sora_gateway_handler.go | 66 +- .../handler/sora_gateway_handler_test.go | 3 + backend/internal/pkg/openai/oauth.go | 2 + .../repository/openai_oauth_service.go | 40 +- .../repository/openai_oauth_service_test.go | 54 ++ backend/internal/server/routes/admin.go | 15 + backend/internal/server/routes/gateway.go | 21 +- .../internal/service/account_test_service.go | 129 +++- .../service/account_test_service_sora_test.go | 193 ++++++ backend/internal/service/oauth_service.go | 1 + .../internal/service/openai_oauth_service.go | 131 +++- .../openai_oauth_service_sora_session_test.go | 69 ++ .../openai_oauth_service_state_test.go | 102 +++ .../internal/service/openai_token_provider.go | 12 +- backend/internal/service/sora_client.go | 600 +++++++++++++++++- backend/internal/service/sora_client_test.go | 274 ++++++++ .../internal/service/sora_gateway_service.go | 48 +- .../service/sora_gateway_service_test.go | 36 ++ backend/internal/service/sora_models.go | 39 +- .../internal/service/token_refresh_service.go | 5 +- backend/internal/service/token_refresher.go | 13 +- .../internal/service/token_refresher_test.go | 40 ++ backend/internal/service/wire.go | 14 +- backend/internal/web/embed_on.go | 2 + backend/internal/web/embed_test.go | 2 + deploy/config.example.yaml | 11 + frontend/src/api/admin/accounts.ts | 30 +- .../components/account/CreateAccountModal.vue | 454 +++++++++---- .../account/OAuthAuthorizationFlow.vue | 131 +++- .../components/account/ReAuthAccountModal.vue | 53 +- .../admin/account/AccountTestModal.vue | 5 + .../admin/account/ReAuthAccountModal.vue | 53 +- frontend/src/composables/useAccountOAuth.ts | 2 +- frontend/src/composables/useOpenAIOAuth.ts | 68 +- frontend/src/i18n/locales/en.ts | 7 +- frontend/src/i18n/locales/zh.ts | 7 +- 39 files changed, 2561 insertions(+), 283 deletions(-) create mode 100644 backend/internal/service/account_test_service_sora_test.go create mode 100644 backend/internal/service/openai_oauth_service_sora_session_test.go create mode 100644 backend/internal/service/openai_oauth_service_state_test.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index b9f31ba9..8efcb550 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -162,6 +162,8 @@ type TokenRefreshConfig struct { MaxRetries int `mapstructure:"max_retries"` // 重试退避基础时间(秒) RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` + // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭) + SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"` } type PricingConfig struct { @@ -269,17 +271,18 @@ type SoraConfig struct { // SoraClientConfig 直连 Sora 客户端配置 type SoraClientConfig struct { - BaseURL string `mapstructure:"base_url"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - MaxRetries int `mapstructure:"max_retries"` - PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` - MaxPollAttempts int `mapstructure:"max_poll_attempts"` - RecentTaskLimit int `mapstructure:"recent_task_limit"` - RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` - Debug bool `mapstructure:"debug"` - Headers map[string]string `mapstructure:"headers"` - UserAgent string `mapstructure:"user_agent"` - DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` + BaseURL string `mapstructure:"base_url"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + MaxRetries int `mapstructure:"max_retries"` + PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` + MaxPollAttempts int `mapstructure:"max_poll_attempts"` + RecentTaskLimit int `mapstructure:"recent_task_limit"` + RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` + Debug bool `mapstructure:"debug"` + UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` + Headers map[string]string `mapstructure:"headers"` + UserAgent string `mapstructure:"user_agent"` + DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` } // SoraStorageConfig 媒体存储配置 @@ -1116,6 +1119,7 @@ func setDefaults() { viper.SetDefault("sora.client.recent_task_limit", 50) viper.SetDefault("sora.client.recent_task_limit_max", 200) viper.SetDefault("sora.client.debug", false) + viper.SetDefault("sora.client.use_openai_token_provider", false) viper.SetDefault("sora.client.headers", map[string]string{}) viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") viper.SetDefault("sora.client.disable_tls_fingerprint", false) @@ -1137,6 +1141,7 @@ func setDefaults() { viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 7533c70e..79f90b8e 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -1333,6 +1333,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { return } + // Handle Sora accounts + if account.Platform == service.PlatformSora { + response.Success(c, service.DefaultSoraModels(nil)) + return + } + // Handle Claude/Anthropic accounts // For OAuth and Setup-Token accounts: return default models if account.IsOAuth() { diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index ed86fea9..cf43f89e 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct { adminService service.AdminService } +func oauthPlatformFromPath(c *gin.Context) string { + if strings.Contains(c.FullPath(), "/admin/sora/") { + return service.PlatformSora + } + return service.PlatformOpenAI +} + // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { return &OpenAIOAuthHandler{ @@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { type OpenAIExchangeCodeRequest struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` } @@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token type OpenAIRefreshTokenRequest struct { - RefreshToken string `json:"refresh_token" binding:"required"` + RefreshToken string `json:"refresh_token"` + RT string `json:"rt"` + ClientID string `json:"client_id"` ProxyID *int64 `json:"proxy_id"` } // RefreshToken refreshes an OpenAI OAuth token // POST /api/v1/admin/openai/refresh-token +// POST /api/v1/admin/sora/rt2at func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { var req OpenAIRefreshTokenRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "Invalid request: "+err.Error()) return } + refreshToken := strings.TrimSpace(req.RefreshToken) + if refreshToken == "" { + refreshToken = strings.TrimSpace(req.RT) + } + if refreshToken == "" { + response.BadRequest(c, "refresh_token is required") + return + } var proxyURL string if req.ProxyID != nil { @@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { } } - tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) + tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID)) if err != nil { response.ErrorFrom(c, err) return @@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { response.Success(c, tokenInfo) } -// RefreshAccountToken refreshes token for a specific OpenAI account +// ExchangeSoraSessionToken exchanges Sora session token to access token +// POST /api/v1/admin/sora/st2at +func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) { + var req struct { + SessionToken string `json:"session_token"` + ST string `json:"st"` + ProxyID *int64 `json:"proxy_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + sessionToken := strings.TrimSpace(req.SessionToken) + if sessionToken == "" { + sessionToken = strings.TrimSpace(req.ST) + } + if sessionToken == "" { + response.BadRequest(c, "session_token is required") + return + } + + tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, tokenInfo) +} + +// RefreshAccountToken refreshes token for a specific OpenAI/Sora account // POST /api/v1/admin/openai/accounts/:id/refresh +// POST /api/v1/admin/sora/accounts/:id/refresh func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { return } - // Ensure account is OpenAI platform - if !account.IsOpenAI() { - response.BadRequest(c, "Account is not an OpenAI account") + platform := oauthPlatformFromPath(c) + if account.Platform != platform { + response.BadRequest(c, "Account platform does not match OAuth endpoint") return } @@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { response.Success(c, dto.AccountFromService(updatedAccount)) } -// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info +// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info // POST /api/v1/admin/openai/create-from-oauth +// POST /api/v1/admin/sora/create-from-oauth func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { var req struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` Name string `json:"name"` @@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { // Build credentials from token info credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) + platform := oauthPlatformFromPath(c) + // Use email as default name if not provided name := req.Name if name == "" && tokenInfo.Email != "" { name = tokenInfo.Email } if name == "" { - name = "OpenAI OAuth Account" + if platform == service.PlatformSora { + name = "Sora OAuth Account" + } else { + name = "OpenAI OAuth Account" + } } // Create account account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ Name: name, - Platform: "openai", + Platform: platform, Type: "oauth", Credentials: credentials, ProxyID: req.ProxyID, diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 80932899..9c9f53b1 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -212,6 +212,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 + var lastFailoverBody []byte for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") @@ -224,7 +225,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted) return } account := selection.Account @@ -287,14 +288,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { failedAccountIDs[account.ID] = struct{}{} if switchCount >= maxAccountSwitches { lastFailoverStatus = failoverErr.StatusCode - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + lastFailoverBody = failoverErr.ResponseBody + h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted) return } lastFailoverStatus = failoverErr.StatusCode + lastFailoverBody = failoverErr.ResponseBody switchCount++ + upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) reqLog.Warn("sora.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), + zap.String("upstream_error_code", upstreamErrCode), + zap.String("upstream_error_message", upstreamErrMsg), zap.Int("switch_count", switchCount), zap.Int("max_switches", maxAccountSwitches), ) @@ -360,17 +366,32 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { - status, errType, errMsg := h.mapUpstreamError(statusCode) +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseBody []byte, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode, responseBody) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } -func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { +func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byte) (int, string, string) { + upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) + if upstreamMessage != "" { + switch statusCode { + case 401, 403, 404, 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", upstreamMessage + case 429: + return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage + } + } + switch statusCode { case 401: return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" case 403: return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 404: + if strings.EqualFold(upstreamCode, "unsupported_country_code") { + return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator" + } + return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator" case 429: return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" case 529: @@ -382,6 +403,41 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri } } +func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) { + trimmed := strings.TrimSpace(string(body)) + if trimmed == "" { + return "", "" + } + if !gjson.Valid(trimmed) { + return "", truncateSoraErrorMessage(trimmed, 256) + } + code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String()) + if code == "" { + code = strings.TrimSpace(gjson.Get(trimmed, "code").String()) + } + message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String()) + if message == "" { + message = strings.TrimSpace(gjson.Get(trimmed, "message").String()) + } + if message == "" { + message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String()) + } + if message == "" { + message = strings.TrimSpace(gjson.Get(trimmed, "detail").String()) + } + return code, truncateSoraErrorMessage(message, 512) +} + +func truncateSoraErrorMessage(s string, maxLen int) string { + if maxLen <= 0 { + return "" + } + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "...(truncated)" +} + func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { if streamStarted { flusher, ok := c.Writer.(http.Flusher) diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 04a58e49..39e2eed6 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -43,6 +43,9 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) { return "task-video", nil } +func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) { + return "enhanced prompt", nil +} func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) { return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil } diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index bb120b57..e3b931be 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -17,6 +17,8 @@ import ( const ( // OAuth Client ID for OpenAI (Codex CLI official) ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + // OAuth Client ID for Sora mobile flow (aligned with sora2api) + SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" // OAuth endpoints AuthorizeURL = "https://auth.openai.com/oauth/authorize" diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 394d3a1a..088e7d7f 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie } func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + if strings.TrimSpace(clientID) != "" { + return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID)) + } + + clientIDs := []string{ + openai.ClientID, + openai.SoraClientID, + } + seen := make(map[string]struct{}, len(clientIDs)) + var lastErr error + for _, clientID := range clientIDs { + clientID = strings.TrimSpace(clientID) + if clientID == "" { + continue + } + if _, ok := seen[clientID]; ok { + continue + } + seen[clientID] = struct{}{} + + tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) + if err == nil { + return tokenResp, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed") +} + +func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { client := createOpenAIReqClient(proxyURL) formData := url.Values{} formData.Set("grant_type", "refresh_token") formData.Set("refresh_token", refreshToken) - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("scope", openai.RefreshScopes) var tokenResp openai.TokenResponse diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index f9df08c8..5938272a 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { require.Equal(s.T(), "rt2", resp.RefreshToken) } +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID == openai.ClientID { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, "invalid_grant") + return + } + if clientID == openai.SoraClientID { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + + resp, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.NoError(s.T(), err, "RefreshToken") + require.Equal(s.T(), "at-sora", resp.AccessToken) + require.Equal(s.T(), "rt-sora", resp.RefreshToken) + require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs) +} + +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { + const customClientID = "custom-client-id" + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID != customClientID { + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") + require.Equal(s.T(), "at-custom", resp.AccessToken) + require.Equal(s.T(), "rt-custom", resp.RefreshToken) + require.Equal(s.T(), []string{customClientID}, seenClientIDs) +} + func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 57d54a54..7341f85b 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -34,6 +34,8 @@ func RegisterAdminRoutes( // OpenAI OAuth registerOpenAIOAuthRoutes(admin, h) + // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立) + registerSoraOAuthRoutes(admin, h) // Gemini OAuth registerGeminiOAuthRoutes(admin, h) @@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + sora := admin.Group("/sora") + { + sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL) + sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode) + sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken) + sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken) + sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth) + } +} + func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { gemini := admin.Group("/gemini") { diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 32f34e0c..69881e70 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -1,6 +1,8 @@ package routes import ( + "net/http" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -41,16 +43,15 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) - } - - // Sora Chat Completions - soraGateway := r.Group("/v1") - soraGateway.Use(soraBodyLimit) - soraGateway.Use(clientRequestID) - soraGateway.Use(opsErrorLogger) - soraGateway.Use(gin.HandlerFunc(apiKeyAuth)) - { - soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions) + // 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口 + gateway.POST("/chat/completions", func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.", + }, + }) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 093f7d4d..67c9ef0c 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -27,11 +27,13 @@ import ( // sseDataPrefix matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var sseDataPrefix = regexp.MustCompile(`^data:\s*`) +var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 + soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" ) // TestEvent represents a SSE event for account testing @@ -502,8 +504,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } + enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint() - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -512,7 +515,10 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * body, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body))) + if isCloudflareChallengeResponse(resp.StatusCode, body) { + return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body)) + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) } // 解析 /me 响应,提取用户信息 @@ -531,10 +537,129 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * s.sendEvent(c, TestEvent{Type: "content", Text: info}) } + // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试) + subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil) + if err == nil { + subReq.Header.Set("Authorization", "Bearer "+authToken) + subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + subReq.Header.Set("Accept", "application/json") + + subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) + if subErr != nil { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())}) + } else { + subBody, _ := io.ReadAll(subResp.Body) + _ = subResp.Body.Close() + if subResp.StatusCode == http.StatusOK { + if summary := parseSoraSubscriptionSummary(subBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"}) + } + } else { + if isCloudflareChallengeResponse(subResp.StatusCode, subBody) { + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)}) + } + } + } + } + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) return nil } +func parseSoraSubscriptionSummary(body []byte) string { + var subResp struct { + Data []struct { + Plan struct { + ID string `json:"id"` + Title string `json:"title"` + } `json:"plan"` + EndTS string `json:"end_ts"` + } `json:"data"` + } + if err := json.Unmarshal(body, &subResp); err != nil { + return "" + } + if len(subResp.Data) == 0 { + return "" + } + + first := subResp.Data[0] + parts := make([]string, 0, 3) + if first.Plan.Title != "" { + parts = append(parts, first.Plan.Title) + } + if first.Plan.ID != "" { + parts = append(parts, first.Plan.ID) + } + if first.EndTS != "" { + parts = append(parts, "end="+first.EndTS) + } + if len(parts) == 0 { + return "" + } + return "Subscription: " + strings.Join(parts, " | ") +} + +func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { + if s == nil || s.cfg == nil { + return false + } + return s.cfg.Gateway.TLSFingerprint.Enabled && !s.cfg.Sora.Client.DisableTLSFingerprint +} + +func isCloudflareChallengeResponse(statusCode int, body []byte) bool { + if statusCode != http.StatusForbidden { + return false + } + preview := strings.ToLower(truncateSoraErrorBody(body, 4096)) + return strings.Contains(preview, "window._cf_chl_opt") || + strings.Contains(preview, "just a moment") || + strings.Contains(preview, "enable javascript and cookies to continue") +} + +func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { + rayID := extractCloudflareRayID(headers, body) + if rayID == "" { + return base + } + return fmt.Sprintf("%s (cf-ray: %s)", base, rayID) +} + +func extractCloudflareRayID(headers http.Header, body []byte) string { + if headers != nil { + rayID := strings.TrimSpace(headers.Get("cf-ray")) + if rayID != "" { + return rayID + } + rayID = strings.TrimSpace(headers.Get("Cf-Ray")) + if rayID != "" { + return rayID + } + } + + preview := truncateSoraErrorBody(body, 8192) + matches := cloudflareRayPattern.FindStringSubmatch(preview) + if len(matches) >= 2 { + return strings.TrimSpace(matches[1]) + } + return "" +} + +func truncateSoraErrorBody(body []byte, max int) string { + if max <= 0 { + max = 512 + } + raw := strings.TrimSpace(string(body)) + if len(raw) <= max { + return raw + } + return raw[:max] + "...(truncated)" +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go new file mode 100644 index 00000000..fbbc8ff1 --- /dev/null +++ b/backend/internal/service/account_test_service_sora_test.go @@ -0,0 +1,193 @@ +package service + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type queuedHTTPUpstream struct { + responses []*http.Response + requests []*http.Request + tlsFlags []bool +} + +func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, fmt.Errorf("unexpected Do call") +} + +func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) { + u.requests = append(u.requests, req) + u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint) + if len(u.responses) == 0 { + return nil, fmt.Errorf("no mocked response") + } + resp := u.responses[0] + u.responses = u.responses[1:] + return resp, nil +} + +func newJSONResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func newJSONResponseWithHeader(status int, body, key, value string) *http.Response { + resp := newJSONResponse(status, body) + resp.Header.Set(key, value) + return resp +} + +func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + return c, rec +} + +func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + TLSFingerprint: config.TLSFingerprintConfig{ + Enabled: true, + }, + }, + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + DisableTLSFingerprint: false, + }, + }, + }, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 2) + require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) + require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) + require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) + require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) + require.Equal(t, []bool{true, true}, upstream.tlsFlags) + + body := rec.Body.String() + require.Contains(t, body, `"type":"test_start"`) + require.Contains(t, body, "Sora connection OK - Email: demo@example.com") + require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 2) + body := rec.Body.String() + require.Contains(t, body, "Sora connection OK - User: demo-user") + require.Contains(t, body, "Subscription check returned 403") + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d") + body := rec.Body.String() + require.Contains(t, body, `"type":"error"`) + require.Contains(t, body, "Cloudflare challenge") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + body := rec.Body.String() + require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") + require.Contains(t, body, `"type":"test_complete","success":true`) +} diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index e247e654..6f6261d8 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -14,6 +14,7 @@ import ( type OpenAIOAuthClient interface { ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) + RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) } // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index ca7470b9..087ad4ec 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -2,13 +2,20 @@ package service import ( "context" + "crypto/subtle" + "encoding/json" + "io" "net/http" + "net/url" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) +var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 type OpenAIExchangeCodeInput struct { SessionID string Code string + State string RedirectURI string ProxyID *int64 } @@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if !ok { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") } + if input.State == "" { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required") + } + if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state") + } // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL proxyURL := session.ProxyURL @@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // RefreshToken refreshes an OpenAI OAuth token func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { - tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id. +func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) { + tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) if err != nil { return nil, err } @@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri return tokenInfo, nil } -// RefreshAccountToken refreshes token for an OpenAI account -func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { - if !account.IsOpenAI() { - return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") +// ExchangeSoraSessionToken exchanges Sora session_token to access_token. +func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { + if strings.TrimSpace(sessionToken) == "" { + return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") } - refreshToken := account.GetOpenAIRefreshToken() + proxyURL, err := s.resolveProxyURL(ctx, proxyID) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil) + if err != nil { + return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err) + } + req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + + client := newOpenAIOAuthHTTPClient(proxyURL) + resp, err := client.Do(req) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if resp.StatusCode != http.StatusOK { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var sessionResp struct { + AccessToken string `json:"accessToken"` + Expires string `json:"expires"` + User struct { + Email string `json:"email"` + Name string `json:"name"` + } `json:"user"` + } + if err := json.Unmarshal(body, &sessionResp); err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err) + } + if strings.TrimSpace(sessionResp.AccessToken) == "" { + return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token") + } + + expiresAt := time.Now().Add(time.Hour).Unix() + if strings.TrimSpace(sessionResp.Expires) != "" { + if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil { + expiresAt = parsed.Unix() + } + } + expiresIn := expiresAt - time.Now().Unix() + if expiresIn < 0 { + expiresIn = 0 + } + + return &OpenAITokenInfo{ + AccessToken: strings.TrimSpace(sessionResp.AccessToken), + ExpiresIn: expiresIn, + ExpiresAt: expiresAt, + Email: strings.TrimSpace(sessionResp.User.Email), + }, nil +} + +// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account +func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { + if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account") + } + if account.Type != AccountTypeOAuth { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") + } + + refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } @@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A } } - return s.RefreshToken(ctx, refreshToken, proxyURL) + clientID := account.GetCredential("client_id") + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } // BuildAccountCredentials builds credentials map from token info @@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) func (s *OpenAIOAuthService) Stop() { s.sessionStore.Stop() } + +func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) { + if proxyID == nil { + return "", nil + } + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err != nil { + return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err) + } + if proxy == nil { + return "", nil + } + return proxy.URL(), nil +} + +func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { + transport := &http.Transport{} + if strings.TrimSpace(proxyURL) != "" { + if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" { + transport.Proxy = http.ProxyURL(parsed) + } + } + return &http.Client{ + Timeout: 120 * time.Second, + Transport: transport, + } +} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go new file mode 100644 index 00000000..fb76f6c1 --- /dev/null +++ b/backend/internal/service/openai_oauth_service_sora_session_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientNoopStub struct{} + +func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at-token", info.AccessToken) + require.Equal(t, "demo@example.com", info.Email) + require.Greater(t, info.ExpiresAt, int64(0)) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "missing access token") +} diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go new file mode 100644 index 00000000..0a2a195f --- /dev/null +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -0,0 +1,102 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientStateStub struct { + exchangeCalled int32 +} + +func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.exchangeCalled, 1) + return &openai.TokenResponse{ + AccessToken: "at", + RefreshToken: "rt", + ExpiresIn: 3600, + }, nil +} + +func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return s.RefreshToken(ctx, refreshToken, proxyURL) +} + +func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "oauth state is required") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "wrong-state", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid oauth state") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "expected-state", + }) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at", info.AccessToken) + require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) + + _, ok := svc.sessionStore.Get("sid") + require.False(t, ok) +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 3842f0a4..a8a6b96c 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) p.metrics.refreshFailure.Add(1) refreshFailed = true // 无法刷新,标记失败 @@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) p.metrics.refreshFailure.Add(1) refreshFailed = true diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index de097d5e..38be7a04 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -17,12 +17,15 @@ import ( "net/textproto" "net/url" "path" + "sort" "strconv" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/google/uuid" "github.com/tidwall/gjson" "golang.org/x/crypto/sha3" @@ -34,6 +37,11 @@ const ( soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" ) +var ( + soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + soraOAuthTokenURL = "https://auth.openai.com/oauth/token" +) + const ( soraPowMaxIteration = 500000 ) @@ -96,6 +104,7 @@ type SoraClient interface { UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) } @@ -157,26 +166,94 @@ func (e *SoraUpstreamError) Error() string { // SoraDirectClient 直连 Sora 实现 type SoraDirectClient struct { - cfg *config.Config - httpUpstream HTTPUpstream - tokenProvider *OpenAITokenProvider + cfg *config.Config + httpUpstream HTTPUpstream + tokenProvider *OpenAITokenProvider + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository + baseURL string } // NewSoraDirectClient 创建 Sora 直连客户端 func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient { + baseURL := "" + if cfg != nil { + rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/") + baseURL = normalizeSoraBaseURL(rawBaseURL) + if rawBaseURL != "" && baseURL != rawBaseURL { + log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL)) + } + } return &SoraDirectClient{ cfg: cfg, httpUpstream: httpUpstream, tokenProvider: tokenProvider, + baseURL: baseURL, } } +func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) { + if c == nil { + return + } + c.accountRepo = accountRepo + c.soraAccountRepo = soraAccountRepo +} + // Enabled 判断是否启用 Sora 直连 func (c *SoraDirectClient) Enabled() bool { - if c == nil || c.cfg == nil { + if c == nil { return false } - return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != "" + if strings.TrimSpace(c.baseURL) != "" { + return true + } + if c.cfg == nil { + return false + } + return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != "" +} + +// PreflightCheck 在创建任务前执行账号能力预检。 +// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。 +func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error { + if modelCfg.Type != "video" { + return nil + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Accept", "application/json") + body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false) + if err != nil { + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound { + return &SoraUpstreamError{ + StatusCode: http.StatusForbidden, + Message: "当前账号未开通 Sora2 能力或无可用配额", + Headers: upstreamErr.Headers, + Body: upstreamErr.Body, + } + } + return err + } + + rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool() + remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining") + if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) { + msg := "当前账号 Sora2 可用配额不足" + if requestedModel != "" { + msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel) + } + return &SoraUpstreamError{ + StatusCode: http.StatusTooManyRequests, + Message: msg, + Headers: http.Header{}, + } + } + return nil } func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { @@ -347,6 +424,45 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account return taskID, nil } +func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + if strings.TrimSpace(expansionLevel) == "" { + expansionLevel = "medium" + } + if durationS <= 0 { + durationS = 10 + } + + payload := map[string]any{ + "prompt": prompt, + "expansion_level": expansionLevel, + "duration_s": durationS, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Content-Type", "application/json") + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false) + if err != nil { + return "", err + } + enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String()) + if enhancedPrompt == "" { + return "", errors.New("enhance_prompt response missing enhanced_prompt") + } + return enhancedPrompt, nil +} + func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit()) if err != nil { @@ -512,9 +628,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t } func (c *SoraDirectClient) buildURL(endpoint string) string { - base := "" - if c != nil && c.cfg != nil { - base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/") + base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/") + if base == "" && c != nil && c.cfg != nil { + base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL) + c.baseURL = base } if base == "" { return endpoint @@ -540,14 +657,257 @@ func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) if account == nil { return "", errors.New("account is nil") } - if c.tokenProvider != nil { - return c.tokenProvider.GetAccessToken(ctx, account) + + allowProvider := c.allowOpenAITokenProvider(account) + var providerErr error + if allowProvider && c.tokenProvider != nil { + token, err := c.tokenProvider.GetAccessToken(ctx, account) + if err == nil && strings.TrimSpace(token) != "" { + c.logTokenSource(account, "openai_token_provider") + return token, nil + } + providerErr = err + if err != nil && c.debugEnabled() { + c.debugLogf( + "token_provider_failed account_id=%d platform=%s err=%s", + account.ID, + account.Platform, + logredact.RedactText(err.Error()), + ) + } } token := strings.TrimSpace(account.GetCredential("access_token")) - if token == "" { - return "", errors.New("access_token not found") + if token != "" { + expiresAt := account.GetCredentialAsTime("expires_at") + if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute { + refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring") + if refreshErr == nil && strings.TrimSpace(refreshed) != "" { + c.logTokenSource(account, "refresh_token_recovered") + return refreshed, nil + } + if refreshErr != nil && c.debugEnabled() { + c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error())) + } + } + c.logTokenSource(account, "account_credentials") + return token, nil } - return token, nil + + recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing") + if recoverErr == nil && strings.TrimSpace(recovered) != "" { + c.logTokenSource(account, "session_or_refresh_recovered") + return recovered, nil + } + if recoverErr != nil && c.debugEnabled() { + c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error())) + } + if providerErr != nil { + return "", providerErr + } + if c.tokenProvider != nil && !allowProvider { + c.logTokenSource(account, "account_credentials(provider_disabled)") + } + return "", errors.New("access_token not found") +} + +func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + + if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" { + accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken) + if err == nil && strings.TrimSpace(accessToken) != "" { + c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken) + c.logTokenRecover(account, "session_token", reason, true, nil) + return accessToken, nil + } + c.logTokenRecover(account, "session_token", reason, false, err) + } + + refreshToken := strings.TrimSpace(account.GetCredential("refresh_token")) + if refreshToken == "" { + return "", errors.New("session_token/refresh_token not found") + } + accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken) + if err != nil { + c.logTokenRecover(account, "refresh_token", reason, false, err) + return "", err + } + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("refreshed access_token is empty") + } + c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "") + c.logTokenRecover(account, "refresh_token", reason, true, nil) + return accessToken, nil +} + +func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) { + headers := http.Header{} + headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken) + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + headers.Set("User-Agent", c.defaultUserAgent()) + body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false) + if err != nil { + return "", "", err + } + accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String()) + if accessToken == "" { + return "", "", errors.New("session exchange missing accessToken") + } + expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String()) + return accessToken, expiresAt, nil +} + +func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) { + clientIDs := []string{ + strings.TrimSpace(account.GetCredential("client_id")), + openaioauth.SoraClientID, + openaioauth.ClientID, + } + tried := make(map[string]struct{}, len(clientIDs)) + var lastErr error + + for _, clientID := range clientIDs { + if clientID == "" { + continue + } + if _, ok := tried[clientID]; ok { + continue + } + tried[clientID] = struct{}{} + + payload := map[string]any{ + "client_id": clientID, + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", + } + bodyBytes, err := json.Marshal(payload) + if err != nil { + return "", "", "", err + } + headers := http.Header{} + headers.Set("Accept", "application/json") + headers.Set("Content-Type", "application/json") + headers.Set("User-Agent", c.defaultUserAgent()) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false) + if err != nil { + lastErr = err + if c.debugEnabled() { + c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error())) + } + continue + } + accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String()) + if accessToken == "" { + lastErr = errors.New("oauth refresh response missing access_token") + continue + } + newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String()) + expiresIn := gjson.GetBytes(respBody, "expires_in").Int() + expiresAt := "" + if expiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) + } + return accessToken, newRefreshToken, expiresAt, nil + } + + if lastErr != nil { + return "", "", "", lastErr + } + return "", "", "", errors.New("no available client_id for refresh_token exchange") +} + +func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) { + if account == nil { + return + } + if account.Credentials == nil { + account.Credentials = make(map[string]any) + } + if strings.TrimSpace(accessToken) != "" { + account.Credentials["access_token"] = accessToken + } + if strings.TrimSpace(refreshToken) != "" { + account.Credentials["refresh_token"] = refreshToken + } + if strings.TrimSpace(expiresAt) != "" { + account.Credentials["expires_at"] = expiresAt + } + if strings.TrimSpace(sessionToken) != "" { + account.Credentials["session_token"] = sessionToken + } + + if c.accountRepo != nil { + if err := c.accountRepo.Update(ctx, account); err != nil { + if c.debugEnabled() { + c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } + } + } + c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken) +} + +func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) { + if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 { + return + } + updates := make(map[string]any) + if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" { + updates["access_token"] = accessToken + updates["refresh_token"] = refreshToken + } + if strings.TrimSpace(sessionToken) != "" { + updates["session_token"] = sessionToken + } + if len(updates) == 0 { + return + } + if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() { + c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } +} + +func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) { + if !c.debugEnabled() || account == nil { + return + } + if success { + c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) + return + } + if err == nil { + c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) + return + } + c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error())) +} + +func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool { + if c == nil || c.tokenProvider == nil { + return false + } + if account != nil && account.Platform == PlatformSora { + return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider + } + return true +} + +func (c *SoraDirectClient) logTokenSource(account *Account, source string) { + if !c.debugEnabled() || account == nil { + return + } + c.debugLogf( + "token_selected account_id=%d platform=%s account_type=%s source=%s", + account.ID, + account.Platform, + account.Type, + source, + ) } func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header { @@ -600,7 +960,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } attempts := maxRetries + 1 + authRecovered := false + authRecoverExtraAttemptGranted := false + var lastErr error for attempt := 1; attempt <= attempts; attempt++ { + if c.debugEnabled() { + c.debugLogf( + "request_start method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t headers=%s", + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + timeout, + len(bodyBytes), + account != nil && account.ProxyID != nil && account.Proxy != nil, + formatSoraHeaders(headers), + ) + } + var reader io.Reader if bodyBytes != nil { reader = bytes.NewReader(bodyBytes) @@ -618,7 +995,21 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } resp, err := c.doHTTP(req, proxyURL, account) if err != nil { + lastErr = err + if c.debugEnabled() { + c.debugLogf( + "request_transport_error method=%s url=%s attempt=%d/%d err=%s", + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + logredact.RedactText(err.Error()), + ) + } if attempt < attempts && allowRetry { + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled method=%s url=%s reason=transport_error next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), attempt+1, attempts) + } c.sleepRetry(attempt) continue } @@ -632,12 +1023,53 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } if c.cfg != nil && c.cfg.Sora.Client.Debug { - log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start)) + c.debugLogf( + "response_received method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s", + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + resp.StatusCode, + time.Since(start), + len(respBody), + formatSoraHeaders(resp.Header), + ) } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody) + if !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil { + if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" { + headers.Set("Authorization", "Bearer "+recovered) + authRecovered = true + if attempt == attempts && !authRecoverExtraAttemptGranted { + attempts++ + authRecoverExtraAttemptGranted = true + } + if c.debugEnabled() { + c.debugLogf("request_retry_with_recovered_token method=%s url=%s status=%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode) + } + continue + } else if recoverErr != nil && c.debugEnabled() { + c.debugLogf("request_recover_token_failed method=%s url=%s status=%d err=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error())) + } + } + if c.debugEnabled() { + c.debugLogf( + "response_non_success method=%s url=%s attempt=%d/%d status=%d body=%s", + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + resp.StatusCode, + summarizeSoraResponseBody(respBody, 512), + ) + } + upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr) + lastErr = upstreamErr if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) { + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled method=%s url=%s reason=status_%d next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts) + } c.sleepRetry(attempt) continue } @@ -645,9 +1077,34 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } return respBody, resp.Header, nil } + if lastErr != nil { + return nil, nil, lastErr + } return nil, nil, errors.New("upstream retries exhausted") } +func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool { + switch statusCode { + case http.StatusUnauthorized, http.StatusForbidden: + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return false + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return false + } + // 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。 + path := strings.ToLower(strings.TrimSpace(parsed.Path)) + if path == "/api/auth/session" { + return false + } + return true + default: + return false + } +} + func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint if c.httpUpstream != nil { @@ -670,9 +1127,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) { time.Sleep(backoff) } -func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error { +func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error { msg := strings.TrimSpace(extractUpstreamErrorMessage(body)) msg = sanitizeUpstreamErrorMessage(msg) + if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") { + if hint := soraBaseURLNotFoundHint(requestURL); hint != "" { + msg = strings.TrimSpace(msg + " " + hint) + } + } if msg == "" { msg = truncateForLog(body, 256) } @@ -684,6 +1146,45 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b } } +func normalizeSoraBaseURL(raw string) string { + trimmed := strings.TrimRight(strings.TrimSpace(raw), "/") + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return trimmed + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return trimmed + } + pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/") + switch pathVal { + case "", "/": + parsed.Path = "/backend" + case "/backend-api": + parsed.Path = "/backend" + } + return strings.TrimRight(parsed.String(), "/") +} + +func soraBaseURLNotFoundHint(requestURL string) string { + parsed, err := url.Parse(strings.TrimSpace(requestURL)) + if err != nil || parsed.Host == "" { + return "" + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return "" + } + pathVal := strings.TrimSpace(parsed.Path) + if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" { + return "" + } + return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)" +} + func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) { reqID := uuid.NewString() userAgent := soraRandChoice(soraDesktopUserAgents) @@ -901,3 +1402,70 @@ func sanitizeSoraLogURL(raw string) string { parsed.RawQuery = q.Encode() return parsed.String() } + +func (c *SoraDirectClient) debugEnabled() bool { + return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug +} + +func (c *SoraDirectClient) debugLogf(format string, args ...any) { + if !c.debugEnabled() { + return + } + log.Printf("[SoraClient] "+format, args...) +} + +func formatSoraHeaders(headers http.Header) string { + if len(headers) == 0 { + return "{}" + } + keys := make([]string, 0, len(headers)) + for key := range headers { + keys = append(keys, key) + } + sort.Strings(keys) + out := make(map[string]string, len(keys)) + for _, key := range keys { + values := headers.Values(key) + if len(values) == 0 { + continue + } + val := strings.Join(values, ",") + if isSensitiveHeader(key) { + out[key] = "***" + continue + } + out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160) + } + encoded, err := json.Marshal(out) + if err != nil { + return "{}" + } + return string(encoded) +} + +func isSensitiveHeader(key string) bool { + k := strings.ToLower(strings.TrimSpace(key)) + switch k { + case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key": + return true + default: + return false + } +} + +func summarizeSoraResponseBody(body []byte, maxLen int) string { + if len(body) == 0 { + return "" + } + var text string + if json.Valid(body) { + text = logredact.RedactJSON(body) + } else { + text = logredact.RedactText(string(body)) + } + text = strings.TrimSpace(text) + if maxLen <= 0 || len(text) <= maxLen { + return text + } + return text[:maxLen] + "...(truncated)" +} diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index a6bf71cd..3e88c9f9 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -4,9 +4,13 @@ package service import ( "context" + "encoding/json" "net/http" "net/http/httptest" + "strings" + "sync/atomic" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" @@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) { require.Equal(t, "completed", status.Status) require.Equal(t, []string{"https://example.com/a.png"}, status.URLs) } + +func TestNormalizeSoraBaseURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + raw string + want string + }{ + { + name: "empty", + raw: "", + want: "", + }, + { + name: "append_backend_for_sora_host", + raw: "https://sora.chatgpt.com", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "convert_backend_api_to_backend", + raw: "https://sora.chatgpt.com/backend-api", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "keep_backend", + raw: "https://sora.chatgpt.com/backend", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "keep_custom_host", + raw: "https://example.com/custom-path", + want: "https://example.com/custom-path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeSoraBaseURL(tt.raw) + require.Equal(t, tt.want, got) + }) + } +} + +func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) { + t.Parallel() + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com", + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen")) +} + +func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) { + t.Parallel() + client := NewSoraDirectClient(&config.Config{}, nil, nil) + + err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen") + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url") + + errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen") + require.ErrorAs(t, errNoHint, &upstreamErr) + require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url") +} + +func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) { + t.Parallel() + headers := http.Header{} + headers.Set("Authorization", "Bearer secret-token") + headers.Set("openai-sentinel-token", "sentinel-secret") + headers.Set("X-Test", "ok") + + out := formatSoraHeaders(headers) + require.Contains(t, out, `"Authorization":"***"`) + require.Contains(t, out, `Sentinel-Token":"***"`) + require.Contains(t, out, `"X-Test":"ok"`) + require.NotContains(t, out, "secret-token") + require.NotContains(t, out, "sentinel-secret") +} + +func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) { + t.Parallel() + body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`) + out := summarizeSoraResponseBody(body, 512) + require.Contains(t, out, `"access_token":"***"`) + require.NotContains(t, out, "abc123") +} + +func TestSummarizeSoraResponseBody_Truncates(t *testing.T) { + t.Parallel() + body := []byte(strings.Repeat("x", 100)) + out := summarizeSoraResponseBody(body, 10) + require.Contains(t, out, "(truncated)") +} + +func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) { + t.Parallel() + cache := newOpenAITokenCacheStub() + provider := NewOpenAITokenProvider(nil, cache, nil) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, nil, provider) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "sora-credential-token", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "sora-credential-token", token) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled)) +} + +func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) { + t.Parallel() + cache := newOpenAITokenCacheStub() + account := &Account{ + ID: 2, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "sora-credential-token", + }, + } + cache.tokens[OpenAITokenCacheKey(account)] = "provider-token" + provider := NewOpenAITokenProvider(nil, cache, nil) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + UseOpenAITokenProvider: true, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, provider) + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "provider-token", token) + require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0)) +} + +func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "accessToken": "session-access-token", + "expires": "2099-01-01T00:00:00Z", + }) + })) + defer server.Close() + + origin := soraSessionAuthURL + soraSessionAuthURL = server.URL + defer func() { soraSessionAuthURL = origin }() + + client := NewSoraDirectClient(&config.Config{}, nil, nil) + account := &Account{ + ID: 10, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "session_token": "session-token", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "session-access-token", token) + require.Equal(t, "session-access-token", account.GetCredential("access_token")) +} + +func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/oauth/token", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "refresh-access-token", + "refresh_token": "refresh-token-new", + "expires_in": 3600, + }) + })) + defer server.Close() + + origin := soraOAuthTokenURL + soraOAuthTokenURL = server.URL + "/oauth/token" + defer func() { soraOAuthTokenURL = origin }() + + client := NewSoraDirectClient(&config.Config{}, nil, nil) + account := &Account{ + ID: 11, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "refresh-token-old", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "refresh-access-token", token) + require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token")) + require.NotNil(t, account.GetCredentialAsTime("expires_at")) +} + +func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Equal(t, "/nf/check", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "rate_limit_and_credit_balance": map[string]any{ + "estimated_num_videos_remaining": 0, + "rate_limit_reached": true, + }, + }) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: server.URL, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ + ID: 12, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "ok", + "expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339), + }, + } + err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"}) + require.Error(t, err) + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode) +} + +func TestShouldAttemptSoraTokenRecover(t *testing.T) { + t.Parallel() + + require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen")) + require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen")) +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index d7ff297c..8ae89f92 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -61,6 +61,10 @@ type SoraGatewayService struct { cfg *config.Config } +type soraPreflightChecker interface { + PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error +} + func NewSoraGatewayService( soraClient SoraClient, mediaStorage *SoraMediaStorage, @@ -112,11 +116,6 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) return nil, fmt.Errorf("unsupported model: %s", reqModel) } - if modelCfg.Type == "prompt_enhance" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream) - return nil, fmt.Errorf("prompt-enhance not supported") - } - prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) if strings.TrimSpace(prompt) == "" { s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) @@ -131,6 +130,41 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun if cancel != nil { defer cancel() } + if checker, ok := s.soraClient.(soraPreflightChecker); ok { + if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + } + + if modelCfg.Type == "prompt_enhance" { + enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + content := strings.TrimSpace(enhancedPrompt) + if content == "" { + content = prompt + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } var imageData []byte imageFilename := "" @@ -267,7 +301,7 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) ( func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { switch statusCode { - case 401, 402, 403, 429, 529: + case 401, 402, 403, 404, 429, 529: return true default: return statusCode >= 500 @@ -460,7 +494,7 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) } if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { - return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode} + return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body} } msg := upstreamErr.Message if override := soraProErrorMessage(model, msg); override != "" { diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index d6bf9eae..f706d052 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -18,6 +18,8 @@ type stubSoraClientForPoll struct { videoStatus *SoraVideoTaskStatus imageCalls int videoCalls int + enhanced string + enhanceErr error } func (s *stubSoraClientForPoll) Enabled() bool { return true } @@ -30,6 +32,12 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { return "task-video", nil } +func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + if s.enhanced != "" { + return s.enhanced, s.enhanceErr + } + return "enhanced prompt", s.enhanceErr +} func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { s.imageCalls++ return s.imageStatus, nil @@ -62,6 +70,33 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { require.Equal(t, 1, client.imageCalls) } +func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { + client := &stubSoraClientForPoll{ + enhanced: "cinematic prompt", + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + } + body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, "prompt-enhance-short-10s", result.Model) +} + func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { client := &stubSoraClientForPoll{ videoStatus: &SoraVideoTaskStatus{ @@ -178,6 +213,7 @@ func TestSoraProErrorMessage(t *testing.T) { func TestShouldFailoverUpstreamError(t *testing.T) { svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) require.True(t, svc.shouldFailoverUpstreamError(401)) + require.True(t, svc.shouldFailoverUpstreamError(404)) require.True(t, svc.shouldFailoverUpstreamError(429)) require.True(t, svc.shouldFailoverUpstreamError(500)) require.True(t, svc.shouldFailoverUpstreamError(502)) diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go index ab095e46..80b20a4b 100644 --- a/backend/internal/service/sora_models.go +++ b/backend/internal/service/sora_models.go @@ -17,6 +17,9 @@ type SoraModelConfig struct { Model string Size string RequirePro bool + // Prompt-enhance 专用参数 + ExpansionLevel string + DurationS int } var soraModelConfigs = map[string]SoraModelConfig{ @@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{ RequirePro: true, }, "prompt-enhance-short-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 10, }, "prompt-enhance-short-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 15, }, "prompt-enhance-short-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 20, }, "prompt-enhance-medium-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "medium", + DurationS: 10, }, "prompt-enhance-medium-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "medium", + DurationS: 15, }, "prompt-enhance-medium-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "medium", + DurationS: 20, }, "prompt-enhance-long-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 10, }, "prompt-enhance-long-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 15, }, "prompt-enhance-long-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 20, }, } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 9de1c164..a37e0d0a 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -43,10 +43,13 @@ func NewTokenRefreshService( stopCh: make(chan struct{}), } + openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo) + openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts) + // 注册平台特定的刷新器 s.refreshers = []TokenRefresher{ NewClaudeTokenRefresher(oauthService), - NewOpenAITokenRefresher(openaiOAuthService, accountRepo), + openAIRefresher, NewGeminiTokenRefresher(geminiOAuthService), NewAntigravityTokenRefresher(antigravityOAuthService), } diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 46033f75..0dd3cf45 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -86,6 +86,7 @@ type OpenAITokenRefresher struct { openaiOAuthService *OpenAIOAuthService accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 + syncLinkedSora bool } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) { r.soraAccountRepo = repo } +// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。 +func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) { + r.syncLinkedSora = enabled +} + // CanRefresh 检查是否能处理此账号 -// 只处理 openai 平台的 oauth 类型账号 +// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号) func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { - return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) && - account.Type == AccountTypeOAuth + return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth } // NeedsRefresh 检查token是否需要刷新 @@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m } // 异步同步关联的 Sora 账号(不阻塞主流程) - if r.accountRepo != nil { + if r.accountRepo != nil && r.syncLinkedSora { go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) } diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go index c7505037..264d7912 100644 --- a/backend/internal/service/token_refresher_test.go +++ b/backend/internal/service/token_refresher_test.go @@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { }) } } + +func TestOpenAITokenRefresher_CanRefresh(t *testing.T) { + refresher := &OpenAITokenRefresher{} + + tests := []struct { + name string + platform string + accType string + want bool + }{ + { + name: "openai oauth - can refresh", + platform: PlatformOpenAI, + accType: AccountTypeOAuth, + want: true, + }, + { + name: "sora oauth - cannot refresh directly", + platform: PlatformSora, + accType: AccountTypeOAuth, + want: false, + }, + { + name: "openai apikey - cannot refresh", + platform: PlatformOpenAI, + accType: AccountTypeAPIKey, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: tt.platform, + Type: tt.accType, + } + require.Equal(t, tt.want, refresher.CanRefresh(account)) + }) + } +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 5d712f75..652f9e00 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage { return NewSoraMediaStorage(cfg) } +func ProvideSoraDirectClient( + cfg *config.Config, + httpUpstream HTTPUpstream, + tokenProvider *OpenAITokenProvider, + accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, +) *SoraDirectClient { + client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider) + client.SetAccountRepositories(accountRepo, soraAccountRepo) + return client +} + // ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务 func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService { svc := NewSoraMediaCleanupService(storage, cfg) @@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet( NewGatewayService, ProvideSoraMediaStorage, ProvideSoraMediaCleanupService, - NewSoraDirectClient, + ProvideSoraDirectClient, wire.Bind(new(SoraClient), new(*SoraDirectClient)), NewSoraGatewayService, NewOpenAIGatewayService, diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 7f37d59c..f7ba5c9e 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc { if strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/setup/") || path == "/health" || @@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { if strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/setup/") || path == "/health" || diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index 50f5a323..e2cbcf15 100644 --- a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) { "/api/v1/users", "/v1/models", "/v1beta/chat", + "/sora/v1/models", "/antigravity/test", "/setup/init", "/health", @@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) { "/api/users", "/v1/models", "/v1beta/chat", + "/sora/v1/models", "/antigravity/test", "/setup/init", "/health", diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 9fd2d391..0ff1ec02 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -388,7 +388,11 @@ sora: recent_task_limit_max: 200 # Enable debug logs for Sora upstream requests # 启用 Sora 直连调试日志 + # 调试日志会输出上游请求尝试、重试、响应摘要;Authorization/openai-sentinel-token 等敏感头会自动脱敏 debug: false + # Allow Sora client to fetch token via OpenAI token provider + # 是否允许 Sora 客户端通过 OpenAI token provider 取 token(默认 false,避免误走 OpenAI 刷新链路) + use_openai_token_provider: false # Optional custom headers (key-value) # 额外请求头(键值对) headers: {} @@ -431,6 +435,13 @@ sora: # Cron 调度表达式 schedule: "0 3 * * *" +# Token refresh behavior +# token 刷新行为控制 +token_refresh: + # Whether OpenAI refresh flow is allowed to sync linked Sora accounts + # 是否允许 OpenAI 刷新流程同步覆盖 linked_openai_account_id 关联的 Sora 账号 token + sync_linked_sora_accounts: false + # ============================================================================= # API Key Auth Cache Configuration # API Key 认证缓存配置 diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 36bec4e7..e1f502ec 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -220,7 +220,7 @@ export async function generateAuthUrl( */ export async function exchangeCode( endpoint: string, - exchangeData: { session_id: string; code: string; proxy_id?: number } + exchangeData: { session_id: string; code: string; state?: string; proxy_id?: number } ): Promise> { const { data } = await apiClient.post>(endpoint, exchangeData) return data @@ -442,7 +442,8 @@ export async function getAntigravityDefaultModelMapping(): Promise> { const payload: { refresh_token: string; proxy_id?: number } = { refresh_token: refreshToken @@ -450,7 +451,29 @@ export async function refreshOpenAIToken( if (proxyId) { payload.proxy_id = proxyId } - const { data } = await apiClient.post>('/admin/openai/refresh-token', payload) + const { data } = await apiClient.post>(endpoint, payload) + return data +} + +/** + * Validate Sora session token and exchange to access token + * @param sessionToken - Sora session token + * @param proxyId - Optional proxy ID + * @param endpoint - API endpoint path + * @returns Token information including access_token + */ +export async function validateSoraSessionToken( + sessionToken: string, + proxyId?: number | null, + endpoint: string = '/admin/sora/st2at' +): Promise> { + const payload: { session_token: string; proxy_id?: number } = { + session_token: sessionToken + } + if (proxyId) { + payload.proxy_id = proxyId + } + const { data } = await apiClient.post>(endpoint, payload) return data } @@ -475,6 +498,7 @@ export const accountsAPI = { generateAuthUrl, exchangeCode, refreshOpenAIToken, + validateSoraSessionToken, batchCreate, batchUpdateCredentials, bulkUpdate, diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 85785d6a..8024dfb6 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -109,6 +109,28 @@ OpenAI +
+ +
+ +
+ +
+
+
@@ -1747,32 +1801,6 @@
- -
- -
-
@@ -2148,6 +2178,7 @@ interface OAuthFlowExposed { projectId: string sessionKey: string refreshToken: string + sessionToken: string inputMethod: AuthInputMethod reset: () => void } @@ -2156,7 +2187,7 @@ const { t } = useI18n() const authStore = useAuthStore() const oauthStepTitle = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.oauth.openai.title') if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title') if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title') return t('admin.accounts.oauth.title') @@ -2164,13 +2195,13 @@ const oauthStepTitle = computed(() => { // Platform-specific hints for API Key type const baseUrlHint = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.baseUrlHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') return t('admin.accounts.baseUrlHint') }) const apiKeyHint = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.apiKeyHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint') return t('admin.accounts.apiKeyHint') }) @@ -2191,34 +2222,36 @@ const appStore = useAppStore() // OAuth composables const oauth = useAccountOAuth() // For Anthropic OAuth -const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth +const openaiOAuth = useOpenAIOAuth({ platform: 'openai' }) // For OpenAI OAuth +const soraOAuth = useOpenAIOAuth({ platform: 'sora' }) // For Sora OAuth const geminiOAuth = useGeminiOAuth() // For Gemini OAuth const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth +const activeOpenAIOAuth = computed(() => (form.platform === 'sora' ? soraOAuth : openaiOAuth)) // Computed: current OAuth state for template binding const currentAuthUrl = computed(() => { - if (form.platform === 'openai') return openaiOAuth.authUrl.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.authUrl.value if (form.platform === 'gemini') return geminiOAuth.authUrl.value if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value return oauth.authUrl.value }) const currentSessionId = computed(() => { - if (form.platform === 'openai') return openaiOAuth.sessionId.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.sessionId.value if (form.platform === 'gemini') return geminiOAuth.sessionId.value if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value return oauth.sessionId.value }) const currentOAuthLoading = computed(() => { - if (form.platform === 'openai') return openaiOAuth.loading.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.loading.value if (form.platform === 'gemini') return geminiOAuth.loading.value if (form.platform === 'antigravity') return antigravityOAuth.loading.value return oauth.loading.value }) const currentOAuthError = computed(() => { - if (form.platform === 'openai') return openaiOAuth.error.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.error.value if (form.platform === 'gemini') return geminiOAuth.error.value if (form.platform === 'antigravity') return antigravityOAuth.error.value return oauth.error.value @@ -2257,7 +2290,6 @@ const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(true) const openaiPassthroughEnabled = ref(false) const codexCLIOnlyEnabled = ref(false) -const enableSoraOnOpenAIOAuth = ref(false) // OpenAI OAuth 时同时启用 Sora const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream const upstreamBaseUrl = ref('') // For upstream type: base URL @@ -2398,8 +2430,8 @@ const expiresAtInput = computed({ const canExchangeCode = computed(() => { const authCode = oauthFlowRef.value?.authCode || '' - if (form.platform === 'openai') { - return authCode.trim() && openaiOAuth.sessionId.value && !openaiOAuth.loading.value + if (form.platform === 'openai' || form.platform === 'sora') { + return authCode.trim() && activeOpenAIOAuth.value.sessionId.value && !activeOpenAIOAuth.value.loading.value } if (form.platform === 'gemini') { return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value @@ -2459,7 +2491,7 @@ watch( (newPlatform) => { // Reset base URL based on platform apiKeyBaseUrl.value = - newPlatform === 'openai' + (newPlatform === 'openai' || newPlatform === 'sora') ? 'https://api.openai.com' : newPlatform === 'gemini' ? 'https://generativelanguage.googleapis.com' @@ -2485,6 +2517,11 @@ watch( if (newPlatform !== 'anthropic') { interceptWarmupRequests.value = false } + if (newPlatform === 'sora') { + accountCategory.value = 'oauth-based' + addMethod.value = 'oauth' + form.type = 'oauth' + } if (newPlatform !== 'openai') { openaiPassthroughEnabled.value = false codexCLIOnlyEnabled.value = false @@ -2492,6 +2529,7 @@ watch( // Reset OAuth states oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() } @@ -2753,7 +2791,6 @@ const resetForm = () => { autoPauseOnExpired.value = true openaiPassthroughEnabled.value = false codexCLIOnlyEnabled.value = false - enableSoraOnOpenAIOAuth.value = false // Reset quota control state windowCostEnabled.value = false windowCostLimit.value = null @@ -2776,6 +2813,7 @@ const resetForm = () => { geminiTierAIStudio.value = 'aistudio_free' oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() @@ -2807,6 +2845,23 @@ const buildOpenAIExtra = (base?: Record): Record 0 ? extra : undefined } +const buildSoraExtra = ( + base?: Record, + linkedOpenAIAccountId?: string | number +): Record | undefined => { + const extra: Record = { ...(base || {}) } + if (linkedOpenAIAccountId !== undefined && linkedOpenAIAccountId !== null) { + const id = String(linkedOpenAIAccountId).trim() + if (id) { + extra.linked_openai_account_id = id + } + } + delete extra.openai_passthrough + delete extra.openai_oauth_passthrough + delete extra.codex_cli_only + return Object.keys(extra).length > 0 ? extra : undefined +} + // Helper function to create account with mixed channel warning handling const doCreateAccount = async (payload: any) => { submitting.value = true @@ -2922,7 +2977,7 @@ const handleSubmit = async () => { // Determine default base URL based on platform const defaultBaseUrl = - form.platform === 'openai' + (form.platform === 'openai' || form.platform === 'sora') ? 'https://api.openai.com' : form.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' @@ -2974,14 +3029,15 @@ const goBackToBasicInfo = () => { step.value = 1 oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() } const handleGenerateUrl = async () => { - if (form.platform === 'openai') { - await openaiOAuth.generateAuthUrl(form.proxy_id) + if (form.platform === 'openai' || form.platform === 'sora') { + await activeOpenAIOAuth.value.generateAuthUrl(form.proxy_id) } else if (form.platform === 'gemini') { await geminiOAuth.generateAuthUrl( form.proxy_id, @@ -2997,13 +3053,19 @@ const handleGenerateUrl = async () => { } const handleValidateRefreshToken = (rt: string) => { - if (form.platform === 'openai') { + if (form.platform === 'openai' || form.platform === 'sora') { handleOpenAIValidateRT(rt) } else if (form.platform === 'antigravity') { handleAntigravityValidateRT(rt) } } +const handleValidateSessionToken = (sessionToken: string) => { + if (form.platform === 'sora') { + handleSoraValidateST(sessionToken) + } +} + const formatDateTimeLocal = formatDateTimeLocalInput const parseDateTimeLocal = parseDateTimeLocalInput @@ -3039,100 +3101,101 @@ const createAccountAndFinish = async ( // OpenAI OAuth 授权码兑换 const handleOpenAIExchange = async (authCode: string) => { - if (!authCode.trim() || !openaiOAuth.sessionId.value) return + const oauthClient = activeOpenAIOAuth.value + if (!authCode.trim() || !oauthClient.sessionId.value) return - openaiOAuth.loading.value = true - openaiOAuth.error.value = '' + oauthClient.loading.value = true + oauthClient.error.value = '' try { - const tokenInfo = await openaiOAuth.exchangeAuthCode( + const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim() + if (!stateToUse) { + oauthClient.error.value = t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) + return + } + + const tokenInfo = await oauthClient.exchangeAuthCode( authCode.trim(), - openaiOAuth.sessionId.value, + oauthClient.sessionId.value, + stateToUse, form.proxy_id ) if (!tokenInfo) return - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record | undefined + const credentials = oauthClient.buildCredentials(tokenInfo) + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) + const shouldCreateOpenAI = form.platform === 'openai' + const shouldCreateSora = form.platform === 'sora' // 应用临时不可调度配置 if (!applyTempUnschedConfig(credentials)) { return } - // 1. 创建 OpenAI 账号 - const openaiAccount = await adminAPI.accounts.create({ - name: form.name, - notes: form.notes, - platform: 'openai', - type: 'oauth', - credentials, - extra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) + let openaiAccountId: string | number | undefined - appStore.showSuccess(t('admin.accounts.accountCreated')) + if (shouldCreateOpenAI) { + const openaiAccount = await adminAPI.accounts.create({ + name: form.name, + notes: form.notes, + platform: 'openai', + type: 'oauth', + credentials, + extra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + openaiAccountId = openaiAccount.id + appStore.showSuccess(t('admin.accounts.accountCreated')) + } - // 2. 如果启用了 Sora,同时创建 Sora 账号 - if (enableSoraOnOpenAIOAuth.value) { - try { - // Sora 使用相同的 OAuth credentials - const soraCredentials = { - access_token: credentials.access_token, - refresh_token: credentials.refresh_token, - expires_at: credentials.expires_at - } - - // 建立关联关系 - const soraExtra: Record = { - ...(extra || {}), - linked_openai_account_id: String(openaiAccount.id) - } - delete soraExtra.openai_passthrough - delete soraExtra.openai_oauth_passthrough - - await adminAPI.accounts.create({ - name: `${form.name} (Sora)`, - notes: form.notes, - platform: 'sora', - type: 'oauth', - credentials: soraCredentials, - extra: soraExtra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) - - appStore.showSuccess(t('admin.accounts.soraAccountCreated')) - } catch (error: any) { - console.error('创建 Sora 账号失败:', error) - appStore.showWarning(t('admin.accounts.soraAccountFailed')) + if (shouldCreateSora) { + const soraCredentials = { + access_token: credentials.access_token, + refresh_token: credentials.refresh_token, + expires_at: credentials.expires_at } + + const soraName = shouldCreateOpenAI ? `${form.name} (Sora)` : form.name + const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId) + await adminAPI.accounts.create({ + name: soraName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials: soraCredentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + appStore.showSuccess(t('admin.accounts.accountCreated')) } emit('created') handleClose() } catch (error: any) { - openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') - appStore.showError(openaiOAuth.error.value) + oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) } finally { - openaiOAuth.loading.value = false + oauthClient.loading.value = false } } // OpenAI 手动 RT 批量验证和创建 const handleOpenAIValidateRT = async (refreshTokenInput: string) => { + const oauthClient = activeOpenAIOAuth.value if (!refreshTokenInput.trim()) return // Parse multiple refresh tokens (one per line) @@ -3142,53 +3205,86 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { .filter((rt) => rt) if (refreshTokens.length === 0) { - openaiOAuth.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken') + oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken') return } - openaiOAuth.loading.value = true - openaiOAuth.error.value = '' + oauthClient.loading.value = true + oauthClient.error.value = '' let successCount = 0 let failedCount = 0 const errors: string[] = [] + const shouldCreateOpenAI = form.platform === 'openai' + const shouldCreateSora = form.platform === 'sora' try { for (let i = 0; i < refreshTokens.length; i++) { try { - const tokenInfo = await openaiOAuth.validateRefreshToken( + const tokenInfo = await oauthClient.validateRefreshToken( refreshTokens[i], form.proxy_id ) if (!tokenInfo) { failedCount++ - errors.push(`#${i + 1}: ${openaiOAuth.error.value || 'Validation failed'}`) - openaiOAuth.error.value = '' + errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`) + oauthClient.error.value = '' continue } - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record | undefined + const credentials = oauthClient.buildCredentials(tokenInfo) + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) // Generate account name with index for batch const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name - await adminAPI.accounts.create({ - name: accountName, - notes: form.notes, - platform: 'openai', - type: 'oauth', - credentials, - extra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) + let openaiAccountId: string | number | undefined + + if (shouldCreateOpenAI) { + const openaiAccount = await adminAPI.accounts.create({ + name: accountName, + notes: form.notes, + platform: 'openai', + type: 'oauth', + credentials, + extra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + openaiAccountId = openaiAccount.id + } + + if (shouldCreateSora) { + const soraCredentials = { + access_token: credentials.access_token, + refresh_token: credentials.refresh_token, + expires_at: credentials.expires_at + } + const soraName = shouldCreateOpenAI ? `${accountName} (Sora)` : accountName + const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId) + await adminAPI.accounts.create({ + name: soraName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials: soraCredentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + } + successCount++ } catch (error: any) { failedCount++ @@ -3210,14 +3306,99 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { appStore.showWarning( t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) ) - openaiOAuth.error.value = errors.join('\n') + oauthClient.error.value = errors.join('\n') emit('created') } else { - openaiOAuth.error.value = errors.join('\n') + oauthClient.error.value = errors.join('\n') appStore.showError(t('admin.accounts.oauth.batchFailed')) } } finally { - openaiOAuth.loading.value = false + oauthClient.loading.value = false + } +} + +// Sora 手动 ST 批量验证和创建 +const handleSoraValidateST = async (sessionTokenInput: string) => { + const oauthClient = activeOpenAIOAuth.value + if (!sessionTokenInput.trim()) return + + const sessionTokens = sessionTokenInput + .split('\n') + .map((st) => st.trim()) + .filter((st) => st) + + if (sessionTokens.length === 0) { + oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterSessionToken') + return + } + + oauthClient.loading.value = true + oauthClient.error.value = '' + + let successCount = 0 + let failedCount = 0 + const errors: string[] = [] + + try { + for (let i = 0; i < sessionTokens.length; i++) { + try { + const tokenInfo = await oauthClient.validateSessionToken(sessionTokens[i], form.proxy_id) + if (!tokenInfo) { + failedCount++ + errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`) + oauthClient.error.value = '' + continue + } + + const credentials = oauthClient.buildCredentials(tokenInfo) + credentials.session_token = sessionTokens[i] + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined + const soraExtra = buildSoraExtra(oauthExtra) + + const accountName = sessionTokens.length > 1 ? `${form.name} #${i + 1}` : form.name + await adminAPI.accounts.create({ + name: accountName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + successCount++ + } catch (error: any) { + failedCount++ + const errMsg = error.response?.data?.detail || error.message || 'Unknown error' + errors.push(`#${i + 1}: ${errMsg}`) + } + } + + if (successCount > 0 && failedCount === 0) { + appStore.showSuccess( + sessionTokens.length > 1 + ? t('admin.accounts.oauth.batchSuccess', { count: successCount }) + : t('admin.accounts.accountCreated') + ) + emit('created') + handleClose() + } else if (successCount > 0 && failedCount > 0) { + appStore.showWarning( + t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) + ) + oauthClient.error.value = errors.join('\n') + emit('created') + } else { + oauthClient.error.value = errors.join('\n') + appStore.showError(t('admin.accounts.oauth.batchFailed')) + } + } finally { + oauthClient.loading.value = false } } @@ -3462,6 +3643,7 @@ const handleExchangeCode = async () => { switch (form.platform) { case 'openai': + case 'sora': return handleOpenAIExchange(authCode) case 'gemini': return handleGeminiExchange(authCode) diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 9c4b7e4b..8e00d25b 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -48,6 +48,17 @@ t(getOAuthKey('refreshTokenAuth')) }} +
@@ -135,6 +146,87 @@ + +
+
+

+ {{ t(getOAuthKey('sessionTokenDesc')) }} +

+ +
+ + +

+ {{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedSessionTokenCount }) }} +

+
+ +
+

+ {{ error }} +

+
+ + +
+
+
(), { authUrl: '', @@ -540,6 +633,7 @@ const props = withDefaults(defineProps(), { methodLabel: 'Authorization Method', showCookieOption: true, showRefreshTokenOption: false, + showSessionTokenOption: false, platform: 'anthropic', showProjectId: true }) @@ -549,6 +643,7 @@ const emit = defineEmits<{ 'exchange-code': [code: string] 'cookie-auth': [sessionKey: string] 'validate-refresh-token': [refreshToken: string] + 'validate-session-token': [sessionToken: string] 'update:inputMethod': [method: AuthInputMethod] }>() @@ -587,12 +682,13 @@ const inputMethod = ref(props.showCookieOption ? 'manual' : 'ma const authCodeInput = ref('') const sessionKeyInput = ref('') const refreshTokenInput = ref('') +const sessionTokenInput = ref('') const showHelpDialog = ref(false) const oauthState = ref('') const projectId = ref('') // Computed: show method selection when either cookie or refresh token option is enabled -const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption) +const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption) // Clipboard const { copied, copyToClipboard } = useClipboard() @@ -613,6 +709,13 @@ const parsedRefreshTokenCount = computed(() => { .filter((rt) => rt).length }) +const parsedSessionTokenCount = computed(() => { + return sessionTokenInput.value + .split('\n') + .map((st) => st.trim()) + .filter((st) => st).length +}) + // Watchers watch(inputMethod, (newVal) => { emit('update:inputMethod', newVal) @@ -631,7 +734,7 @@ watch(authCodeInput, (newVal) => { const url = new URL(trimmed) const code = url.searchParams.get('code') const stateParam = url.searchParams.get('state') - if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) { + if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) { oauthState.value = stateParam } if (code && code !== trimmed) { @@ -642,7 +745,7 @@ watch(authCodeInput, (newVal) => { // If URL parsing fails, try regex extraction const match = trimmed.match(/[?&]code=([^&]+)/) const stateMatch = trimmed.match(/[?&]state=([^&]+)/) - if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) { + if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) { oauthState.value = stateMatch[1] } if (match && match[1] && match[1] !== trimmed) { @@ -680,6 +783,12 @@ const handleValidateRefreshToken = () => { } } +const handleValidateSessionToken = () => { + if (sessionTokenInput.value.trim()) { + emit('validate-session-token', sessionTokenInput.value.trim()) + } +} + // Expose methods and state defineExpose({ authCode: authCodeInput, @@ -687,6 +796,7 @@ defineExpose({ projectId, sessionKey: sessionKeyInput, refreshToken: refreshTokenInput, + sessionToken: sessionTokenInput, inputMethod, reset: () => { authCodeInput.value = '' @@ -694,6 +804,7 @@ defineExpose({ projectId.value = '' sessionKeyInput.value = '' refreshTokenInput.value = '' + sessionTokenInput.value = '' inputMethod.value = 'manual' showHelpDialog.value = false } diff --git a/frontend/src/components/account/ReAuthAccountModal.vue b/frontend/src/components/account/ReAuthAccountModal.vue index b2734b4f..aab0fe7d 100644 --- a/frontend/src/components/account/ReAuthAccountModal.vue +++ b/frontend/src/components/account/ReAuthAccountModal.vue @@ -14,7 +14,7 @@
('code_as // Computed - check platform const isOpenAI = computed(() => props.account?.platform === 'openai') +const isSora = computed(() => props.account?.platform === 'sora') +const isOpenAILike = computed(() => isOpenAI.value || isSora.value) const isGemini = computed(() => props.account?.platform === 'gemini') const isAnthropic = computed(() => props.account?.platform === 'anthropic') const isAntigravity = computed(() => props.account?.platform === 'antigravity') +const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth)) // Computed - current OAuth state based on platform const currentAuthUrl = computed(() => { - if (isOpenAI.value) return openaiOAuth.authUrl.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value if (isGemini.value) return geminiOAuth.authUrl.value if (isAntigravity.value) return antigravityOAuth.authUrl.value return claudeOAuth.authUrl.value }) const currentSessionId = computed(() => { - if (isOpenAI.value) return openaiOAuth.sessionId.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value if (isGemini.value) return geminiOAuth.sessionId.value if (isAntigravity.value) return antigravityOAuth.sessionId.value return claudeOAuth.sessionId.value }) const currentLoading = computed(() => { - if (isOpenAI.value) return openaiOAuth.loading.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value if (isGemini.value) return geminiOAuth.loading.value if (isAntigravity.value) return antigravityOAuth.loading.value return claudeOAuth.loading.value }) const currentError = computed(() => { - if (isOpenAI.value) return openaiOAuth.error.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value if (isGemini.value) return geminiOAuth.error.value if (isAntigravity.value) return antigravityOAuth.error.value return claudeOAuth.error.value @@ -269,8 +275,8 @@ const currentError = computed(() => { // Computed const isManualInputMethod = computed(() => { - // OpenAI/Gemini/Antigravity always use manual input (no cookie auth option) - return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' + // OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option) + return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' }) const canExchangeCode = computed(() => { @@ -313,6 +319,7 @@ const resetState = () => { geminiOAuthType.value = 'code_assist' claudeOAuth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() @@ -325,8 +332,8 @@ const handleClose = () => { const handleGenerateUrl = async () => { if (!props.account) return - if (isOpenAI.value) { - await openaiOAuth.generateAuthUrl(props.account.proxy_id) + if (isOpenAILike.value) { + await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id) } else if (isGemini.value) { const creds = (props.account.credentials || {}) as Record const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined @@ -345,21 +352,29 @@ const handleExchangeCode = async () => { const authCode = oauthFlowRef.value?.authCode || '' if (!authCode.trim()) return - if (isOpenAI.value) { + if (isOpenAILike.value) { // OpenAI OAuth flow - const sessionId = openaiOAuth.sessionId.value + const oauthClient = activeOpenAIOAuth.value + const sessionId = oauthClient.sessionId.value if (!sessionId) return + const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim() + if (!stateToUse) { + oauthClient.error.value = t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) + return + } - const tokenInfo = await openaiOAuth.exchangeAuthCode( + const tokenInfo = await oauthClient.exchangeAuthCode( authCode.trim(), sessionId, + stateToUse, props.account.proxy_id ) if (!tokenInfo) return // Build credentials and extra info - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const extra = openaiOAuth.buildExtraInfo(tokenInfo) + const credentials = oauthClient.buildCredentials(tokenInfo) + const extra = oauthClient.buildExtraInfo(tokenInfo) try { // Update account with new credentials @@ -376,8 +391,8 @@ const handleExchangeCode = async () => { emit('reauthorized') handleClose() } catch (error: any) { - openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') - appStore.showError(openaiOAuth.error.value) + oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) } } else if (isGemini.value) { const sessionId = geminiOAuth.sessionId.value @@ -490,7 +505,7 @@ const handleExchangeCode = async () => { } const handleCookieAuth = async (sessionKey: string) => { - if (!props.account || isOpenAI.value) return + if (!props.account || isOpenAILike.value) return claudeOAuth.loading.value = true claudeOAuth.error.value = '' diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue index feb09654..38196781 100644 --- a/frontend/src/components/admin/account/AccountTestModal.vue +++ b/frontend/src/components/admin/account/AccountTestModal.vue @@ -238,6 +238,11 @@ const loadAvailableModels = async () => { availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') || availableModels.value.find((m) => m.id === 'gemini-3-pro-preview') selectedModelId.value = preferred?.id || availableModels.value[0].id + } else if (props.account.platform === 'sora') { + const preferred = + availableModels.value.find((m) => m.id === 'gpt-image') || + availableModels.value.find((m) => !m.id.startsWith('prompt-enhance')) + selectedModelId.value = preferred?.id || availableModels.value[0].id } else { // Try to select Sonnet as default, otherwise use first model const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet')) diff --git a/frontend/src/components/admin/account/ReAuthAccountModal.vue b/frontend/src/components/admin/account/ReAuthAccountModal.vue index eeb3f288..c269eea4 100644 --- a/frontend/src/components/admin/account/ReAuthAccountModal.vue +++ b/frontend/src/components/admin/account/ReAuthAccountModal.vue @@ -14,7 +14,7 @@
('code_as // Computed - check platform const isOpenAI = computed(() => props.account?.platform === 'openai') +const isSora = computed(() => props.account?.platform === 'sora') +const isOpenAILike = computed(() => isOpenAI.value || isSora.value) const isGemini = computed(() => props.account?.platform === 'gemini') const isAnthropic = computed(() => props.account?.platform === 'anthropic') const isAntigravity = computed(() => props.account?.platform === 'antigravity') +const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth)) // Computed - current OAuth state based on platform const currentAuthUrl = computed(() => { - if (isOpenAI.value) return openaiOAuth.authUrl.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value if (isGemini.value) return geminiOAuth.authUrl.value if (isAntigravity.value) return antigravityOAuth.authUrl.value return claudeOAuth.authUrl.value }) const currentSessionId = computed(() => { - if (isOpenAI.value) return openaiOAuth.sessionId.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value if (isGemini.value) return geminiOAuth.sessionId.value if (isAntigravity.value) return antigravityOAuth.sessionId.value return claudeOAuth.sessionId.value }) const currentLoading = computed(() => { - if (isOpenAI.value) return openaiOAuth.loading.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value if (isGemini.value) return geminiOAuth.loading.value if (isAntigravity.value) return antigravityOAuth.loading.value return claudeOAuth.loading.value }) const currentError = computed(() => { - if (isOpenAI.value) return openaiOAuth.error.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value if (isGemini.value) return geminiOAuth.error.value if (isAntigravity.value) return antigravityOAuth.error.value return claudeOAuth.error.value @@ -269,8 +275,8 @@ const currentError = computed(() => { // Computed const isManualInputMethod = computed(() => { - // OpenAI/Gemini/Antigravity always use manual input (no cookie auth option) - return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' + // OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option) + return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' }) const canExchangeCode = computed(() => { @@ -313,6 +319,7 @@ const resetState = () => { geminiOAuthType.value = 'code_assist' claudeOAuth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() @@ -325,8 +332,8 @@ const handleClose = () => { const handleGenerateUrl = async () => { if (!props.account) return - if (isOpenAI.value) { - await openaiOAuth.generateAuthUrl(props.account.proxy_id) + if (isOpenAILike.value) { + await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id) } else if (isGemini.value) { const creds = (props.account.credentials || {}) as Record const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined @@ -345,21 +352,29 @@ const handleExchangeCode = async () => { const authCode = oauthFlowRef.value?.authCode || '' if (!authCode.trim()) return - if (isOpenAI.value) { + if (isOpenAILike.value) { // OpenAI OAuth flow - const sessionId = openaiOAuth.sessionId.value + const oauthClient = activeOpenAIOAuth.value + const sessionId = oauthClient.sessionId.value if (!sessionId) return + const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim() + if (!stateToUse) { + oauthClient.error.value = t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) + return + } - const tokenInfo = await openaiOAuth.exchangeAuthCode( + const tokenInfo = await oauthClient.exchangeAuthCode( authCode.trim(), sessionId, + stateToUse, props.account.proxy_id ) if (!tokenInfo) return // Build credentials and extra info - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const extra = openaiOAuth.buildExtraInfo(tokenInfo) + const credentials = oauthClient.buildCredentials(tokenInfo) + const extra = oauthClient.buildExtraInfo(tokenInfo) try { // Update account with new credentials @@ -376,8 +391,8 @@ const handleExchangeCode = async () => { emit('reauthorized', updatedAccount) handleClose() } catch (error: any) { - openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') - appStore.showError(openaiOAuth.error.value) + oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) } } else if (isGemini.value) { const sessionId = geminiOAuth.sessionId.value @@ -490,7 +505,7 @@ const handleExchangeCode = async () => { } const handleCookieAuth = async (sessionKey: string) => { - if (!props.account || isOpenAI.value) return + if (!props.account || isOpenAILike.value) return claudeOAuth.loading.value = true claudeOAuth.error.value = '' diff --git a/frontend/src/composables/useAccountOAuth.ts b/frontend/src/composables/useAccountOAuth.ts index ca200cb3..6f53404c 100644 --- a/frontend/src/composables/useAccountOAuth.ts +++ b/frontend/src/composables/useAccountOAuth.ts @@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' export type AddMethod = 'oauth' | 'setup-token' -export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' +export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' export interface OAuthState { authUrl: string diff --git a/frontend/src/composables/useOpenAIOAuth.ts b/frontend/src/composables/useOpenAIOAuth.ts index 82a77031..32045cbe 100644 --- a/frontend/src/composables/useOpenAIOAuth.ts +++ b/frontend/src/composables/useOpenAIOAuth.ts @@ -19,12 +19,21 @@ export interface OpenAITokenInfo { [key: string]: unknown } -export function useOpenAIOAuth() { +export type OpenAIOAuthPlatform = 'openai' | 'sora' + +interface UseOpenAIOAuthOptions { + platform?: OpenAIOAuthPlatform +} + +export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) { const appStore = useAppStore() + const oauthPlatform = options?.platform ?? 'openai' + const endpointPrefix = oauthPlatform === 'sora' ? '/admin/sora' : '/admin/openai' // State const authUrl = ref('') const sessionId = ref('') + const oauthState = ref('') const loading = ref(false) const error = ref('') @@ -32,6 +41,7 @@ export function useOpenAIOAuth() { const resetState = () => { authUrl.value = '' sessionId.value = '' + oauthState.value = '' loading.value = false error.value = '' } @@ -44,6 +54,7 @@ export function useOpenAIOAuth() { loading.value = true authUrl.value = '' sessionId.value = '' + oauthState.value = '' error.value = '' try { @@ -56,11 +67,17 @@ export function useOpenAIOAuth() { } const response = await adminAPI.accounts.generateAuthUrl( - '/admin/openai/generate-auth-url', + `${endpointPrefix}/generate-auth-url`, payload ) authUrl.value = response.auth_url sessionId.value = response.session_id + try { + const parsed = new URL(response.auth_url) + oauthState.value = parsed.searchParams.get('state') || '' + } catch { + oauthState.value = '' + } return true } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to generate OpenAI auth URL' @@ -75,10 +92,11 @@ export function useOpenAIOAuth() { const exchangeAuthCode = async ( code: string, currentSessionId: string, + state: string, proxyId?: number | null ): Promise => { - if (!code.trim() || !currentSessionId) { - error.value = 'Missing auth code or session ID' + if (!code.trim() || !currentSessionId || !state.trim()) { + error.value = 'Missing auth code, session ID, or state' return null } @@ -86,15 +104,16 @@ export function useOpenAIOAuth() { error.value = '' try { - const payload: { session_id: string; code: string; proxy_id?: number } = { + const payload: { session_id: string; code: string; state: string; proxy_id?: number } = { session_id: currentSessionId, - code: code.trim() + code: code.trim(), + state: state.trim() } if (proxyId) { payload.proxy_id = proxyId } - const tokenInfo = await adminAPI.accounts.exchangeCode('/admin/openai/exchange-code', payload) + const tokenInfo = await adminAPI.accounts.exchangeCode(`${endpointPrefix}/exchange-code`, payload) return tokenInfo as OpenAITokenInfo } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to exchange OpenAI auth code' @@ -120,7 +139,11 @@ export function useOpenAIOAuth() { try { // Use dedicated refresh-token endpoint - const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(refreshToken.trim(), proxyId) + const tokenInfo = await adminAPI.accounts.refreshOpenAIToken( + refreshToken.trim(), + proxyId, + `${endpointPrefix}/refresh-token` + ) return tokenInfo as OpenAITokenInfo } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to validate refresh token' @@ -131,6 +154,33 @@ export function useOpenAIOAuth() { } } + // Validate Sora session token and get access token + const validateSessionToken = async ( + sessionToken: string, + proxyId?: number | null + ): Promise => { + if (!sessionToken.trim()) { + error.value = 'Missing session token' + return null + } + loading.value = true + error.value = '' + try { + const tokenInfo = await adminAPI.accounts.validateSoraSessionToken( + sessionToken.trim(), + proxyId, + `${endpointPrefix}/st2at` + ) + return tokenInfo as OpenAITokenInfo + } catch (err: any) { + error.value = err.response?.data?.detail || 'Failed to validate session token' + appStore.showError(error.value) + return null + } finally { + loading.value = false + } + } + // Build credentials for OpenAI OAuth account const buildCredentials = (tokenInfo: OpenAITokenInfo): Record => { const creds: Record = { @@ -172,6 +222,7 @@ export function useOpenAIOAuth() { // State authUrl, sessionId, + oauthState, loading, error, // Methods @@ -179,6 +230,7 @@ export function useOpenAIOAuth() { generateAuthUrl, exchangeAuthCode, validateRefreshToken, + validateSessionToken, buildCredentials, buildExtraInfo } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 293af1da..0dd87f8a 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1740,9 +1740,13 @@ export default { refreshTokenAuth: 'Manual RT Input', refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line', + sessionTokenAuth: 'Manual ST Input', + sessionTokenDesc: 'Enter your existing Sora Session Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', + sessionTokenPlaceholder: 'Paste your Sora Session Token...\nSupports multiple, one per line', validating: 'Validating...', validateAndCreate: 'Validate & Create Account', - pleaseEnterRefreshToken: 'Please enter Refresh Token' + pleaseEnterRefreshToken: 'Please enter Refresh Token', + pleaseEnterSessionToken: 'Please enter Session Token' }, // Gemini specific gemini: { @@ -1963,6 +1967,7 @@ export default { reAuthorizeAccount: 'Re-Authorize Account', claudeCodeAccount: 'Claude Code Account', openaiAccount: 'OpenAI Account', + soraAccount: 'Sora Account', geminiAccount: 'Gemini Account', antigravityAccount: 'Antigravity Account', inputMethod: 'Input Method', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 08f1aeef..f28045e3 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1879,9 +1879,13 @@ export default { refreshTokenAuth: '手动输入 RT', refreshTokenDesc: '输入您已有的 OpenAI Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。', refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个,每行一个', + sessionTokenAuth: '手动输入 ST', + sessionTokenDesc: '输入您已有的 Sora Session Token,支持批量输入(每行一个),系统将自动验证并创建账号。', + sessionTokenPlaceholder: '粘贴您的 Sora Session Token...\n支持多个,每行一个', validating: '验证中...', validateAndCreate: '验证并创建账号', - pleaseEnterRefreshToken: '请输入 Refresh Token' + pleaseEnterRefreshToken: '请输入 Refresh Token', + pleaseEnterSessionToken: '请输入 Session Token' }, // Gemini specific gemini: { @@ -2097,6 +2101,7 @@ export default { reAuthorizeAccount: '重新授权账号', claudeCodeAccount: 'Claude Code 账号', openaiAccount: 'OpenAI 账号', + soraAccount: 'Sora 账号', geminiAccount: 'Gemini 账号', antigravityAccount: 'Antigravity 账号', inputMethod: '输入方式', From 5d2219d299b98e113ac5ca4e8ca5b90e0593c2c2 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 19 Feb 2026 08:23:00 +0800 Subject: [PATCH 18/28] =?UTF-8?q?fix(sora):=20=E4=BF=AE=E5=A4=8D=E4=BB=A4?= =?UTF-8?q?=E7=89=8C=E5=88=B7=E6=96=B0=E8=AF=B7=E6=B1=82=E6=A0=BC=E5=BC=8F?= =?UTF-8?q?=E4=B8=8E=E6=B5=81=E5=BC=8F=E9=94=99=E8=AF=AF=E8=BD=AC=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 refresh_token 恢复请求改为表单编码并匹配 OAuth 约定 - 流式错误改为 JSON 序列化,避免消息含引号或换行导致 SSE 非法 - 补充 Sora token 恢复与 failover 流式错误透传回归测试 Co-Authored-By: Claude Opus 4.6 --- .../internal/handler/sora_gateway_handler.go | 14 +++- .../handler/sora_gateway_handler_test.go | 81 +++++++++++++++++++ backend/internal/service/sora_client.go | 19 ++--- backend/internal/service/sora_client_test.go | 6 ++ 4 files changed, 107 insertions(+), 13 deletions(-) diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 9c9f53b1..3a5ddcb0 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -442,7 +443,18 @@ func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status in if streamStarted { flusher, ok := c.Writer.(http.Flusher) if ok { - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 39e2eed6..d80b959c 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -498,3 +498,84 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) { require.NotEmpty(t, hash3) require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash } + +func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) { + tests := []struct { + name string + errType string + message string + }{ + { + name: "包含双引号", + errType: "upstream_error", + message: `upstream returned "invalid" payload`, + }, + { + name: "包含换行和制表符", + errType: "rate_limit_error", + message: "line1\nline2\ttab", + }, + { + name: "包含反斜杠", + errType: "upstream_error", + message: `path C:\Users\test\file.txt not found`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &SoraGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true) + + body := w.Body.String() + require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头") + require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾") + + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行") + require.Equal(t, "event: error", lines[0]) + require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀") + + jsonStr := strings.TrimPrefix(lines[1], "data: ") + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON") + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok, "JSON 中应包含 error 对象") + require.Equal(t, tt.errType, errorObj["type"]) + require.Equal(t, tt.message, errorObj["message"]) + }) + } +} + +func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &SoraGatewayHandler{} + resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`) + h.handleFailoverExhausted(c, http.StatusBadGateway, resp, true) + + body := w.Body.String() + require.True(t, strings.HasPrefix(body, "event: error\n")) + require.True(t, strings.HasSuffix(body, "\n\n")) + + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errorObj["type"]) + require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"]) +} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 38be7a04..38c1b3cc 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -779,22 +779,17 @@ func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Ac } tried[clientID] = struct{}{} - payload := map[string]any{ - "client_id": clientID, - "grant_type": "refresh_token", - "refresh_token": refreshToken, - "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", - } - bodyBytes, err := json.Marshal(payload) - if err != nil { - return "", "", "", err - } + formData := url.Values{} + formData.Set("client_id", clientID) + formData.Set("grant_type", "refresh_token") + formData.Set("refresh_token", refreshToken) + formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback") headers := http.Header{} headers.Set("Accept", "application/json") - headers.Set("Content-Type", "application/json") + headers.Set("Content-Type", "application/x-www-form-urlencoded") headers.Set("User-Agent", c.defaultUserAgent()) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false) + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false) if err != nil { lastErr = err if c.debugEnabled() { diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index 3e88c9f9..e566f06b 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -281,6 +281,12 @@ func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodPost, r.Method) require.Equal(t, "/oauth/token", r.URL.Path) + require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + require.NoError(t, r.ParseForm()) + require.Equal(t, "refresh_token", r.FormValue("grant_type")) + require.Equal(t, "refresh-token-old", r.FormValue("refresh_token")) + require.NotEmpty(t, r.FormValue("client_id")) + require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri")) w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ "access_token": "refresh-access-token", From be09188bdaff441cd8cfc680f573cf22448abe31 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 19 Feb 2026 08:29:51 +0800 Subject: [PATCH 19/28] =?UTF-8?q?feat(account-test):=20=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=20Sora=20=E8=B4=A6=E5=8F=B7=E6=B5=8B=E8=AF=95=E8=83=BD?= =?UTF-8?q?=E5=8A=9B=E6=8E=A2=E6=B5=8B=E4=B8=8E=E5=BC=B9=E7=AA=97=E4=BA=A4?= =?UTF-8?q?=E4=BA=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 后端新增 Sora2 邀请码与剩余额度探测,并补充对应结果解析 - Sora 测试流程补齐请求头与 Cloudflare 场景提示,完善单测覆盖 - 前端测试弹窗对 Sora 账号改为免选模型流程,并新增中英文提示文案 Co-Authored-By: Claude Opus 4.6 --- .../internal/service/account_test_service.go | 175 ++++++++++++++++++ .../service/account_test_service_sora_test.go | 17 +- .../components/account/AccountTestModal.vue | 36 +++- .../admin/account/AccountTestModal.vue | 41 ++-- frontend/src/i18n/locales/en.ts | 4 + frontend/src/i18n/locales/zh.ts | 4 + 6 files changed, 251 insertions(+), 26 deletions(-) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 67c9ef0c..e6c1cf4c 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -34,6 +34,9 @@ const ( chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" + soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine" + soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap" + soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check" ) // TestEvent represents a SSE event for account testing @@ -498,6 +501,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") // Get proxy URL proxyURL := "" @@ -543,6 +549,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * subReq.Header.Set("Authorization", "Bearer "+authToken) subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") subReq.Header.Set("Accept", "application/json") + subReq.Header.Set("Accept-Language", "en-US,en;q=0.9") + subReq.Header.Set("Origin", "https://sora.chatgpt.com") + subReq.Header.Set("Referer", "https://sora.chatgpt.com/") subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) if subErr != nil { @@ -566,10 +575,134 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * } } + // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 + s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) return nil } +func (s *AccountTestService) testSora2Capabilities( + c *gin.Context, + ctx context.Context, + account *Account, + authToken string, + proxyURL string, + enableTLSFingerprint bool, +) { + inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())}) + return + } + + if inviteStatus == http.StatusUnauthorized { + bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraBootstrapURL, + proxyURL, + enableTLSFingerprint, + ) + if bootstrapErr == nil && bootstrapStatus == http.StatusOK { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"}) + inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())}) + return + } + } + } + + if inviteStatus != http.StatusOK { + if isCloudflareChallengeResponse(inviteStatus, inviteBody) { + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 invite check blocked by Cloudflare challenge (HTTP 403)", inviteHeader, inviteBody)}) + return + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) + return + } + + if summary := parseSoraInviteSummary(inviteBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"}) + } + + remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraRemainingURL, + proxyURL, + enableTLSFingerprint, + ) + if remainingErr != nil { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) + return + } + if remainingStatus != http.StatusOK { + if isCloudflareChallengeResponse(remainingStatus, remainingBody) { + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 remaining check blocked by Cloudflare challenge (HTTP 403)", remainingHeader, remainingBody)}) + return + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) + return + } + if summary := parseSoraRemainingSummary(remainingBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"}) + } +} + +func (s *AccountTestService) fetchSoraTestEndpoint( + ctx context.Context, + account *Account, + authToken string, + url string, + proxyURL string, + enableTLSFingerprint bool, +) (int, http.Header, []byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return 0, nil, nil, err + } + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint) + if err != nil { + return 0, nil, nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return resp.StatusCode, resp.Header, nil, readErr + } + return resp.StatusCode, resp.Header, body, nil +} + func parseSoraSubscriptionSummary(body []byte) string { var subResp struct { Data []struct { @@ -604,6 +737,48 @@ func parseSoraSubscriptionSummary(body []byte) string { return "Subscription: " + strings.Join(parts, " | ") } +func parseSoraInviteSummary(body []byte) string { + var inviteResp struct { + InviteCode string `json:"invite_code"` + RedeemedCount int64 `json:"redeemed_count"` + TotalCount int64 `json:"total_count"` + } + if err := json.Unmarshal(body, &inviteResp); err != nil { + return "" + } + + parts := []string{"Sora2: supported"} + if inviteResp.InviteCode != "" { + parts = append(parts, "invite="+inviteResp.InviteCode) + } + if inviteResp.TotalCount > 0 { + parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount)) + } + return strings.Join(parts, " | ") +} + +func parseSoraRemainingSummary(body []byte) string { + var remainingResp struct { + RateLimitAndCreditBalance struct { + EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"` + RateLimitReached bool `json:"rate_limit_reached"` + AccessResetsInSeconds int64 `json:"access_resets_in_seconds"` + } `json:"rate_limit_and_credit_balance"` + } + if err := json.Unmarshal(body, &remainingResp); err != nil { + return "" + } + info := remainingResp.RateLimitAndCreditBalance + parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)} + if info.RateLimitReached { + parts = append(parts, "rate_limited=true") + } + if info.AccessResetsInSeconds > 0 { + parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds)) + } + return strings.Join(parts, " | ") +} + func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { if s == nil || s.cfg == nil { return false diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go index fbbc8ff1..0c09bf18 100644 --- a/backend/internal/service/account_test_service_sora_test.go +++ b/backend/internal/service/account_test_service_sora_test.go @@ -61,6 +61,8 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin responses: []*http.Response{ newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), + newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`), + newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`), }, } svc := &AccountTestService{ @@ -92,17 +94,21 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin err := svc.testSoraAccountConnection(c, account) require.NoError(t, err) - require.Len(t, upstream.requests, 2) + require.Len(t, upstream.requests, 4) require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) + require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String()) + require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String()) require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) - require.Equal(t, []bool{true, true}, upstream.tlsFlags) + require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags) body := rec.Body.String() require.Contains(t, body, `"type":"test_start"`) require.Contains(t, body, "Sora connection OK - Email: demo@example.com") require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") + require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") + require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") require.Contains(t, body, `"type":"test_complete","success":true`) } @@ -111,6 +117,8 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc responses: []*http.Response{ newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), }, } svc := &AccountTestService{httpUpstream: upstream} @@ -128,10 +136,11 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc err := svc.testSoraAccountConnection(c, account) require.NoError(t, err) - require.Len(t, upstream.requests, 2) + require.Len(t, upstream.requests, 4) body := rec.Body.String() require.Contains(t, body, "Sora connection OK - User: demo-user") require.Contains(t, body, "Subscription check returned 403") + require.Contains(t, body, "Sora2 invite check returned 401") require.Contains(t, body, `"type":"test_complete","success":true`) } @@ -169,6 +178,7 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChal responses: []*http.Response{ newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), newJSONResponse(http.StatusForbidden, `Just a moment...`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), }, } svc := &AccountTestService{httpUpstream: upstream} @@ -188,6 +198,7 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChal require.NoError(t, err) body := rec.Body.String() require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)") require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") require.Contains(t, body, `"type":"test_complete","success":true`) } diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index dfa1503e..792a8f45 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -41,7 +41,7 @@
-
+
@@ -54,6 +54,12 @@ :placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')" />
+
+ {{ t('admin.accounts.soraTestHint') }} +
@@ -135,12 +141,12 @@
- {{ t('admin.accounts.testModel') }} + {{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
- {{ t('admin.accounts.testPrompt') }} + {{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
@@ -156,10 +162,10 @@
-
+
@@ -54,6 +54,12 @@ :placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')" />
+
+ {{ t('admin.accounts.soraTestHint') }} +
@@ -114,12 +120,12 @@
- {{ t('admin.accounts.testModel') }} + {{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
- {{ t('admin.accounts.testPrompt') }} + {{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
@@ -135,10 +141,10 @@ + +