From 3c619a8da58a11bc698c17af74431ff445e427aa Mon Sep 17 00:00:00 2001 From: huangenjun <1021217094@qq.com> Date: Wed, 25 Feb 2026 10:15:38 +0800 Subject: [PATCH 1/4] =?UTF-8?q?refactor:=20=E4=BD=BF=E7=94=A8=20go-sora2ap?= =?UTF-8?q?i=20SDK=20=E6=9B=BF=E4=BB=A3=E8=87=AA=E5=BB=BA=20Sora=20?= =?UTF-8?q?=E5=AE=A2=E6=88=B7=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 使用 go-sora2api v1.1.0 SDK 替代原有 ~2000 行自建 HTTP/PoW/TLS 指纹代码, SDK 提供高并发性能优化(实例级 rand、PoW 缓冲区复用、context.Context 支持)。 - 新增 SoraSDKClient 适配器实现 SoraClient 接口 - 精简 sora_client.go 为仅保留接口和类型定义 - 更新 Wire 绑定使用 SoraSDKClient - 删除 SoraDirectClient、sora_curl_cffi_sidecar、sora_request_guard 等旧代码 Co-Authored-By: Claude Opus 4.6 --- backend/cmd/server/wire_gen.go | 4 +- backend/go.mod | 19 +- backend/go.sum | 48 +- backend/internal/service/sora_client.go | 2007 ----------------- .../service/sora_client_gjson_test.go | 515 ----- backend/internal/service/sora_client_test.go | 1075 --------- .../service/sora_curl_cffi_sidecar.go | 260 --- .../internal/service/sora_gateway_service.go | 3 +- .../internal/service/sora_media_storage.go | 20 +- .../internal/service/sora_request_guard.go | 266 --- backend/internal/service/sora_sdk_client.go | 803 +++++++ backend/internal/service/wire.go | 10 +- 12 files changed, 880 insertions(+), 4150 deletions(-) delete mode 100644 backend/internal/service/sora_client_gjson_test.go delete mode 100644 backend/internal/service/sora_client_test.go delete mode 100644 backend/internal/service/sora_curl_cffi_sidecar.go delete mode 100644 backend/internal/service/sora_request_guard.go create mode 100644 backend/internal/service/sora_sdk_client.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 7a277112..287f8176 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -187,9 +187,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) - soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) + soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) - soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) + soraGatewayService := service.NewSoraGatewayService(soraSDKClient, soraMediaStorage, rateLimitService, configConfig) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) diff --git a/backend/go.mod b/backend/go.mod index ec3cf509..0adddadf 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -5,6 +5,7 @@ go 1.25.7 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/DouDOU-start/go-sora2api v1.1.0 github.com/alitto/pond/v2 v2.6.2 github.com/cespare/xxhash/v2 v2.3.0 github.com/dgraph-io/ristretto v0.2.0 @@ -29,10 +30,10 @@ require ( github.com/tidwall/sjson v1.2.5 github.com/zeromicro/go-zero v1.9.4 go.uber.org/zap v1.24.0 - golang.org/x/crypto v0.47.0 + golang.org/x/crypto v0.48.0 golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 - golang.org/x/term v0.39.0 + golang.org/x/term v0.40.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 @@ -46,7 +47,14 @@ require ( github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/bdandy/go-errors v1.2.2 // indirect + github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect + github.com/bogdanfinn/fhttp v0.6.8 // indirect + github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect + github.com/bogdanfinn/tls-client v1.14.0 // indirect + github.com/bogdanfinn/utls v1.7.7-barnius // indirect + github.com/bogdanfinn/websocket v1.5.5-barnius // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -123,6 +131,7 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect github.com/testcontainers/testcontainers-go v0.40.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect @@ -144,9 +153,9 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.31.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/mod v0.32.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index d1728e48..efe6c145 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -10,6 +10,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU= +github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= @@ -20,10 +22,24 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= +github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= +github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= +github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= +github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4= +github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M= +github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s= +github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg= +github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A= +github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM= +github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU= +github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg= +github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI= +github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -279,6 +295,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc= +github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= @@ -345,18 +363,21 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -364,16 +385,19 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= -golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 7cecfa03..4680538c 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -1,116 +1,11 @@ package service import ( - "bytes" "context" - "encoding/base64" - "encoding/hex" - "encoding/json" - "errors" "fmt" - "hash/fnv" - "io" - "log" - "math/rand" - "mime" - "mime/multipart" "net/http" - "net/textproto" - "net/url" - "path" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai" - "github.com/Wei-Shaw/sub2api/internal/util/logredact" - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" - "github.com/google/uuid" - "github.com/tidwall/gjson" - "golang.org/x/crypto/sha3" ) -const ( - soraChatGPTBaseURL = "https://chatgpt.com" - soraSentinelFlow = "sora_2_create_task" - soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" -) - -var ( - soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" - soraOAuthTokenURL = "https://auth.openai.com/oauth/token" -) - -const ( - soraPowMaxIteration = 500000 -) - -var soraPowCores = []int{8, 16, 24, 32} - -var soraPowScripts = []string{ - "https://cdn.oaistatic.com/_next/static/cXh69klOLzS0Gy2joLDRS/_ssgManifest.js?dpl=453ebaec0d44c2decab71692e1bfe39be35a24b3", -} - -var soraPowDPL = []string{ - "prod-f501fe933b3edf57aea882da888e1a544df99840", -} - -var soraPowNavigatorKeys = []string{ - "registerProtocolHandler−function registerProtocolHandler() { [native code] }", - "storage−[object StorageManager]", - "locks−[object LockManager]", - "appCodeName−Mozilla", - "permissions−[object Permissions]", - "webdriver−false", - "vendor−Google Inc.", - "mediaDevices−[object MediaDevices]", - "cookieEnabled−true", - "product−Gecko", - "productSub−20030107", - "hardwareConcurrency−32", - "onLine−true", -} - -var soraPowDocumentKeys = []string{ - "_reactListeningo743lnnpvdg", - "location", -} - -var soraPowWindowKeys = []string{ - "0", "window", "self", "document", "name", "location", - "navigator", "screen", "innerWidth", "innerHeight", - "localStorage", "sessionStorage", "crypto", "performance", - "fetch", "setTimeout", "setInterval", "console", -} - -var soraDesktopUserAgents = []string{ - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36", - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36", - "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", - "Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", -} - -var soraMobileUserAgents = []string{ - "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)", - "Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)", - "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)", - "Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)", - "Sora/1.2026.007 (Android 15; 2211133C; build 2600700)", - "Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)", - "Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)", -} - -var soraRand = rand.New(rand.NewSource(time.Now().UnixNano())) -var soraRandMu sync.Mutex -var soraPerfStart = time.Now() -var soraPowTokenGenerator = soraGetPowToken - // SoraClient 定义直连 Sora 的任务操作接口。 type SoraClient interface { Enabled() bool @@ -219,1905 +114,3 @@ func (e *SoraUpstreamError) Error() string { } return fmt.Sprintf("sora upstream error: %d", e.StatusCode) } - -// SoraDirectClient 直连 Sora 实现 -type SoraDirectClient struct { - cfg *config.Config - httpUpstream HTTPUpstream - tokenProvider *OpenAITokenProvider - accountRepo AccountRepository - soraAccountRepo SoraAccountRepository - baseURL string - challengeCooldownMu sync.RWMutex - challengeCooldowns map[string]soraChallengeCooldownEntry - sidecarSessionMu sync.RWMutex - sidecarSessions map[string]soraSidecarSessionEntry -} - -type soraRequestTraceContextKey struct{} - -type soraRequestTrace struct { - ID string - ProxyKey string - UAHash string -} - -// NewSoraDirectClient 创建 Sora 直连客户端 -func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient { - baseURL := "" - if cfg != nil { - rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/") - baseURL = normalizeSoraBaseURL(rawBaseURL) - if rawBaseURL != "" && baseURL != rawBaseURL { - log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL)) - } - } - return &SoraDirectClient{ - cfg: cfg, - httpUpstream: httpUpstream, - tokenProvider: tokenProvider, - baseURL: baseURL, - challengeCooldowns: make(map[string]soraChallengeCooldownEntry), - sidecarSessions: make(map[string]soraSidecarSessionEntry), - } -} - -func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) { - if c == nil { - return - } - c.accountRepo = accountRepo - c.soraAccountRepo = soraAccountRepo -} - -// Enabled 判断是否启用 Sora 直连 -func (c *SoraDirectClient) Enabled() bool { - if c == nil { - return false - } - if strings.TrimSpace(c.baseURL) != "" { - return true - } - if c.cfg == nil { - return false - } - return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != "" -} - -// PreflightCheck 在创建任务前执行账号能力预检。 -// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。 -func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error { - if modelCfg.Type != "video" { - return nil - } - token, err := c.getAccessToken(ctx, account) - if err != nil { - return err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Accept", "application/json") - body, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false) - if err != nil { - var upstreamErr *SoraUpstreamError - if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound { - return &SoraUpstreamError{ - StatusCode: http.StatusForbidden, - Message: "当前账号未开通 Sora2 能力或无可用配额", - Headers: upstreamErr.Headers, - Body: upstreamErr.Body, - } - } - return err - } - - rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool() - remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining") - if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) { - msg := "当前账号 Sora2 可用配额不足" - if requestedModel != "" { - msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel) - } - return &SoraUpstreamError{ - StatusCode: http.StatusTooManyRequests, - Message: msg, - Headers: http.Header{}, - } - } - return nil -} - -func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { - if len(data) == 0 { - return "", errors.New("empty image data") - } - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - if filename == "" { - filename = "image.png" - } - var body bytes.Buffer - writer := multipart.NewWriter(&body) - contentType := mime.TypeByExtension(path.Ext(filename)) - if contentType == "" { - contentType = "application/octet-stream" - } - partHeader := make(textproto.MIMEHeader) - partHeader.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; filename="%s"`, filename)) - partHeader.Set("Content-Type", contentType) - part, err := writer.CreatePart(partHeader) - if err != nil { - return "", err - } - if _, err := part.Write(data); err != nil { - return "", err - } - if err := writer.WriteField("file_name", filename); err != nil { - return "", err - } - if err := writer.Close(); err != nil { - return "", err - } - - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", writer.FormDataContentType()) - - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/uploads"), headers, &body, false) - if err != nil { - return "", err - } - id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) - if id == "" { - return "", errors.New("upload response missing id") - } - return id, nil -} - -func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) - operation := "simple_compose" - inpaintItems := []map[string]any{} - if strings.TrimSpace(req.MediaID) != "" { - operation = "remix" - inpaintItems = append(inpaintItems, map[string]any{ - "type": "image", - "frame_index": 0, - "upload_media_id": req.MediaID, - }) - } - payload := map[string]any{ - "type": "image_gen", - "operation": operation, - "prompt": req.Prompt, - "width": req.Width, - "height": req.Height, - "n_variants": 1, - "n_frames": 1, - "inpaint_items": inpaintItems, - } - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - - body, err := json.Marshal(payload) - if err != nil { - return "", err - } - sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) - if err != nil { - return "", err - } - headers.Set("openai-sentinel-token", sentinel) - - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true) - if err != nil { - return "", err - } - taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) - if taskID == "" { - return "", errors.New("image task response missing id") - } - return taskID, nil -} - -func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) - orientation := req.Orientation - if orientation == "" { - orientation = "landscape" - } - nFrames := req.Frames - if nFrames <= 0 { - nFrames = 450 - } - model := req.Model - if model == "" { - model = "sy_8" - } - size := req.Size - if size == "" { - size = "small" - } - - inpaintItems := []map[string]any{} - if strings.TrimSpace(req.MediaID) != "" { - inpaintItems = append(inpaintItems, map[string]any{ - "kind": "upload", - "upload_id": req.MediaID, - }) - } - payload := map[string]any{ - "kind": "video", - "prompt": req.Prompt, - "orientation": orientation, - "size": size, - "n_frames": nFrames, - "model": model, - "inpaint_items": inpaintItems, - } - if strings.TrimSpace(req.RemixTargetID) != "" { - payload["remix_target_id"] = req.RemixTargetID - payload["cameo_ids"] = []string{} - payload["cameo_replacements"] = map[string]any{} - } else if len(req.CameoIDs) > 0 { - payload["cameo_ids"] = req.CameoIDs - payload["cameo_replacements"] = map[string]any{} - } - - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - body, err := json.Marshal(payload) - if err != nil { - return "", err - } - sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) - if err != nil { - return "", err - } - headers.Set("openai-sentinel-token", sentinel) - - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true) - if err != nil { - return "", err - } - taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) - if taskID == "" { - return "", errors.New("video task response missing id") - } - return taskID, nil -} - -func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) - orientation := req.Orientation - if orientation == "" { - orientation = "landscape" - } - nFrames := req.Frames - if nFrames <= 0 { - nFrames = 450 - } - model := req.Model - if model == "" { - model = "sy_8" - } - size := req.Size - if size == "" { - size = "small" - } - - inpaintItems := []map[string]any{} - if strings.TrimSpace(req.MediaID) != "" { - inpaintItems = append(inpaintItems, map[string]any{ - "kind": "upload", - "upload_id": req.MediaID, - }) - } - payload := map[string]any{ - "kind": "video", - "prompt": req.Prompt, - "title": "Draft your video", - "orientation": orientation, - "size": size, - "n_frames": nFrames, - "storyboard_id": nil, - "inpaint_items": inpaintItems, - "remix_target_id": nil, - "model": model, - "metadata": nil, - "style_id": nil, - "cameo_ids": nil, - "cameo_replacements": nil, - "audio_caption": nil, - "audio_transcript": nil, - "video_caption": nil, - } - - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - body, err := json.Marshal(payload) - if err != nil { - return "", err - } - sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) - if err != nil { - return "", err - } - headers.Set("openai-sentinel-token", sentinel) - - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true) - if err != nil { - return "", err - } - taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) - if taskID == "" { - return "", errors.New("storyboard task response missing id") - } - return taskID, nil -} - -func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { - if len(data) == 0 { - return "", errors.New("empty video data") - } - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - - var body bytes.Buffer - writer := multipart.NewWriter(&body) - partHeader := make(textproto.MIMEHeader) - partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`) - partHeader.Set("Content-Type", "video/mp4") - part, err := writer.CreatePart(partHeader) - if err != nil { - return "", err - } - if _, err := part.Write(data); err != nil { - return "", err - } - if err := writer.WriteField("timestamps", "0,3"); err != nil { - return "", err - } - if err := writer.Close(); err != nil { - return "", err - } - - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", writer.FormDataContentType()) - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false) - if err != nil { - return "", err - } - cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) - if cameoID == "" { - return "", errors.New("character upload response missing id") - } - return cameoID, nil -} - -func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return nil, err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - headers := c.buildBaseHeaders(token, userAgent) - respBody, _, err := c.doRequestWithProxy( - ctx, - account, - proxyURL, - http.MethodGet, - c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)), - headers, - nil, - false, - ) - if err != nil { - return nil, err - } - return &SoraCameoStatus{ - Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()), - StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()), - DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()), - UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()), - ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()), - InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(), - InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(), - }, nil -} - -func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return nil, err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Accept", "image/*,*/*;q=0.8") - - respBody, _, err := c.doRequestWithProxy( - ctx, - account, - proxyURL, - http.MethodGet, - strings.TrimSpace(imageURL), - headers, - nil, - false, - ) - if err != nil { - return nil, err - } - return respBody, nil -} - -func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { - if len(data) == 0 { - return "", errors.New("empty character image") - } - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - - var body bytes.Buffer - writer := multipart.NewWriter(&body) - partHeader := make(textproto.MIMEHeader) - partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`) - partHeader.Set("Content-Type", "image/webp") - part, err := writer.CreatePart(partHeader) - if err != nil { - return "", err - } - if _, err := part.Write(data); err != nil { - return "", err - } - if err := writer.WriteField("use_case", "profile"); err != nil { - return "", err - } - if err := writer.Close(); err != nil { - return "", err - } - - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", writer.FormDataContentType()) - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false) - if err != nil { - return "", err - } - assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String()) - if assetPointer == "" { - return "", errors.New("character image upload response missing asset_pointer") - } - return assetPointer, nil -} - -func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) - payload := map[string]any{ - "cameo_id": req.CameoID, - "username": req.Username, - "display_name": req.DisplayName, - "profile_asset_pointer": req.ProfileAssetPointer, - "instruction_set": nil, - "safety_instruction_set": nil, - } - body, err := json.Marshal(payload) - if err != nil { - return "", err - } - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false) - if err != nil { - return "", err - } - characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String()) - if characterID == "" { - return "", errors.New("character finalize response missing character_id") - } - return characterID, nil -} - -func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - payload := map[string]any{"visibility": "public"} - body, err := json.Marshal(payload) - if err != nil { - return err - } - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - _, _, err = c.doRequestWithProxy( - ctx, - account, - proxyURL, - http.MethodPost, - c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"), - headers, - bytes.NewReader(body), - false, - ) - return err -} - -func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - headers := c.buildBaseHeaders(token, userAgent) - _, _, err = c.doRequestWithProxy( - ctx, - account, - proxyURL, - http.MethodDelete, - c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)), - headers, - nil, - false, - ) - return err -} - -func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) - payload := map[string]any{ - "attachments_to_create": []map[string]any{ - { - "generation_id": generationID, - "kind": "sora", - }, - }, - "post_text": "", - } - body, err := json.Marshal(payload) - if err != nil { - return "", err - } - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) - if err != nil { - return "", err - } - headers.Set("openai-sentinel-token", sentinel) - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true) - if err != nil { - return "", err - } - postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String()) - if postID == "" { - return "", errors.New("watermark-free publish response missing post.id") - } - return postID, nil -} - -func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - headers := c.buildBaseHeaders(token, userAgent) - _, _, err = c.doRequestWithProxy( - ctx, - account, - proxyURL, - http.MethodDelete, - c.buildURL("/project_y/post/"+strings.TrimSpace(postID)), - headers, - nil, - false, - ) - return err -} - -func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { - parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/") - if parseURL == "" { - return "", errors.New("custom parse url is required") - } - if strings.TrimSpace(parseToken) == "" { - return "", errors.New("custom parse token is required") - } - shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID) - payload := map[string]any{ - "url": shareURL, - "token": strings.TrimSpace(parseToken), - } - body, err := json.Marshal(payload) - if err != nil { - return "", err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body)) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") - - proxyURL := c.resolveProxyURL(account) - accountID := int64(0) - accountConcurrency := 0 - if account != nil { - accountID = account.ID - accountConcurrency = account.Concurrency - } - var resp *http.Response - if c.httpUpstream != nil { - resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency) - } else { - resp, err = http.DefaultClient.Do(req) - } - if err != nil { - return "", err - } - defer func() { _ = resp.Body.Close() }() - raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) - if err != nil { - return "", err - } - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256)) - } - downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String()) - if downloadLink == "" { - return "", errors.New("custom parse response missing download_link") - } - return downloadLink, nil -} - -func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return "", err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - if strings.TrimSpace(expansionLevel) == "" { - expansionLevel = "medium" - } - if durationS <= 0 { - durationS = 10 - } - - payload := map[string]any{ - "prompt": prompt, - "expansion_level": expansionLevel, - "duration_s": durationS, - } - body, err := json.Marshal(payload) - if err != nil { - return "", err - } - - headers := c.buildBaseHeaders(token, userAgent) - headers.Set("Content-Type", "application/json") - headers.Set("Accept", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false) - if err != nil { - return "", err - } - enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String()) - if enhancedPrompt == "" { - return "", errors.New("enhance_prompt response missing enhanced_prompt") - } - return enhancedPrompt, nil -} - -func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { - status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit()) - if err != nil { - return nil, err - } - if found { - return status, nil - } - maxLimit := c.recentTaskLimitMax() - if maxLimit > 0 && maxLimit != c.recentTaskLimit() { - status, found, err = c.fetchRecentImageTask(ctx, account, taskID, maxLimit) - if err != nil { - return nil, err - } - if found { - return status, nil - } - } - return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil -} - -func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Account, taskID string, limit int) (*SoraImageTaskStatus, bool, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return nil, false, err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - headers := c.buildBaseHeaders(token, userAgent) - if limit <= 0 { - limit = 20 - } - endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit) - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL(endpoint), headers, nil, false) - if err != nil { - return nil, false, err - } - var found *SoraImageTaskStatus - gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool { - if item.Get("id").String() != taskID { - return true // continue - } - status := strings.TrimSpace(item.Get("status").String()) - progress := item.Get("progress_pct").Float() - var urls []string - item.Get("generations").ForEach(func(_, gen gjson.Result) bool { - if u := strings.TrimSpace(gen.Get("url").String()); u != "" { - urls = append(urls, u) - } - return true - }) - found = &SoraImageTaskStatus{ - ID: taskID, - Status: status, - ProgressPct: progress, - URLs: urls, - } - return false // break - }) - if found != nil { - return found, true, nil - } - return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil -} - -func (c *SoraDirectClient) recentTaskLimit() int { - if c == nil || c.cfg == nil { - return 20 - } - if c.cfg.Sora.Client.RecentTaskLimit > 0 { - return c.cfg.Sora.Client.RecentTaskLimit - } - return 20 -} - -func (c *SoraDirectClient) recentTaskLimitMax() int { - if c == nil || c.cfg == nil { - return 0 - } - if c.cfg.Sora.Client.RecentTaskLimitMax > 0 { - return c.cfg.Sora.Client.RecentTaskLimitMax - } - return 0 -} - -func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { - token, err := c.getAccessToken(ctx, account) - if err != nil { - return nil, err - } - userAgent := c.taskUserAgent() - proxyURL := c.resolveProxyURL(account) - headers := c.buildBaseHeaders(token, userAgent) - - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false) - if err != nil { - return nil, err - } - // 搜索 pending 列表(JSON 数组) - pendingResult := gjson.ParseBytes(respBody) - if pendingResult.IsArray() { - var pendingFound *SoraVideoTaskStatus - pendingResult.ForEach(func(_, task gjson.Result) bool { - if task.Get("id").String() != taskID { - return true - } - progress := 0 - if v := task.Get("progress_pct"); v.Exists() { - progress = int(v.Float() * 100) - } - status := strings.TrimSpace(task.Get("status").String()) - pendingFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: status, - ProgressPct: progress, - } - return false - }) - if pendingFound != nil { - return pendingFound, nil - } - } - - respBody, _, err = c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false) - if err != nil { - return nil, err - } - var draftFound *SoraVideoTaskStatus - gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool { - if draft.Get("task_id").String() != taskID { - return true - } - generationID := strings.TrimSpace(draft.Get("id").String()) - kind := strings.TrimSpace(draft.Get("kind").String()) - reason := strings.TrimSpace(draft.Get("reason_str").String()) - if reason == "" { - reason = strings.TrimSpace(draft.Get("markdown_reason_str").String()) - } - urlStr := strings.TrimSpace(draft.Get("downloadable_url").String()) - if urlStr == "" { - urlStr = strings.TrimSpace(draft.Get("url").String()) - } - - if kind == "sora_content_violation" || reason != "" || urlStr == "" { - msg := reason - if msg == "" { - msg = "Content violates guardrails" - } - draftFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: "failed", - GenerationID: generationID, - ErrorMsg: msg, - } - } else { - draftFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: "completed", - GenerationID: generationID, - URLs: []string{urlStr}, - } - } - return false - }) - if draftFound != nil { - return draftFound, nil - } - - return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil -} - -func (c *SoraDirectClient) buildURL(endpoint string) string { - base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/") - if base == "" && c != nil && c.cfg != nil { - base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL) - c.baseURL = base - } - if base == "" { - return endpoint - } - if strings.HasPrefix(endpoint, "/") { - return base + endpoint - } - return base + "/" + endpoint -} - -func (c *SoraDirectClient) defaultUserAgent() string { - if c == nil || c.cfg == nil { - return soraDefaultUserAgent - } - ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent) - if ua == "" { - return soraDefaultUserAgent - } - return ua -} - -func (c *SoraDirectClient) taskUserAgent() string { - if c != nil && c.cfg != nil { - if ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent); ua != "" { - return ua - } - } - if len(soraMobileUserAgents) > 0 { - return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))] - } - if len(soraDesktopUserAgents) > 0 { - return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))] - } - return soraDefaultUserAgent -} - -func (c *SoraDirectClient) resolveProxyURL(account *Account) string { - if account == nil || account.ProxyID == nil || account.Proxy == nil { - return "" - } - return strings.TrimSpace(account.Proxy.URL()) -} - -func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) { - if account == nil { - return "", errors.New("account is nil") - } - - allowProvider := c.allowOpenAITokenProvider(account) - var providerErr error - if allowProvider && c.tokenProvider != nil { - token, err := c.tokenProvider.GetAccessToken(ctx, account) - if err == nil && strings.TrimSpace(token) != "" { - c.logTokenSource(account, "openai_token_provider") - return token, nil - } - providerErr = err - if err != nil && c.debugEnabled() { - c.debugLogf( - "token_provider_failed account_id=%d platform=%s err=%s", - account.ID, - account.Platform, - logredact.RedactText(err.Error()), - ) - } - } - token := strings.TrimSpace(account.GetCredential("access_token")) - if token != "" { - expiresAt := account.GetCredentialAsTime("expires_at") - if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute { - refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring") - if refreshErr == nil && strings.TrimSpace(refreshed) != "" { - c.logTokenSource(account, "refresh_token_recovered") - return refreshed, nil - } - if refreshErr != nil && c.debugEnabled() { - c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error())) - } - } - c.logTokenSource(account, "account_credentials") - return token, nil - } - - recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing") - if recoverErr == nil && strings.TrimSpace(recovered) != "" { - c.logTokenSource(account, "session_or_refresh_recovered") - return recovered, nil - } - if recoverErr != nil && c.debugEnabled() { - c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error())) - } - if providerErr != nil { - return "", providerErr - } - if c.tokenProvider != nil && !allowProvider { - c.logTokenSource(account, "account_credentials(provider_disabled)") - } - return "", errors.New("access_token not found") -} - -func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) { - if account == nil { - return "", errors.New("account is nil") - } - - if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" { - accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken) - if err == nil && strings.TrimSpace(accessToken) != "" { - c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken) - c.logTokenRecover(account, "session_token", reason, true, nil) - return accessToken, nil - } - c.logTokenRecover(account, "session_token", reason, false, err) - } - - refreshToken := strings.TrimSpace(account.GetCredential("refresh_token")) - if refreshToken == "" { - return "", errors.New("session_token/refresh_token not found") - } - accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken) - if err != nil { - c.logTokenRecover(account, "refresh_token", reason, false, err) - return "", err - } - if strings.TrimSpace(accessToken) == "" { - return "", errors.New("refreshed access_token is empty") - } - c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "") - c.logTokenRecover(account, "refresh_token", reason, true, nil) - return accessToken, nil -} - -func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) { - headers := http.Header{} - headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken) - headers.Set("Accept", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - headers.Set("User-Agent", c.defaultUserAgent()) - body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false) - if err != nil { - return "", "", err - } - accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String()) - if accessToken == "" { - return "", "", errors.New("session exchange missing accessToken") - } - expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String()) - return accessToken, expiresAt, nil -} - -func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) { - clientIDs := []string{ - strings.TrimSpace(account.GetCredential("client_id")), - openaioauth.SoraClientID, - openaioauth.ClientID, - } - tried := make(map[string]struct{}, len(clientIDs)) - var lastErr error - - for _, clientID := range clientIDs { - if clientID == "" { - continue - } - if _, ok := tried[clientID]; ok { - continue - } - tried[clientID] = struct{}{} - - formData := url.Values{} - formData.Set("client_id", clientID) - formData.Set("grant_type", "refresh_token") - formData.Set("refresh_token", refreshToken) - formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback") - headers := http.Header{} - headers.Set("Accept", "application/json") - headers.Set("Content-Type", "application/x-www-form-urlencoded") - headers.Set("User-Agent", c.defaultUserAgent()) - - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false) - if err != nil { - lastErr = err - if c.debugEnabled() { - c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error())) - } - continue - } - accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String()) - if accessToken == "" { - lastErr = errors.New("oauth refresh response missing access_token") - continue - } - newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String()) - expiresIn := gjson.GetBytes(respBody, "expires_in").Int() - expiresAt := "" - if expiresIn > 0 { - expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) - } - return accessToken, newRefreshToken, expiresAt, nil - } - - if lastErr != nil { - return "", "", "", lastErr - } - return "", "", "", errors.New("no available client_id for refresh_token exchange") -} - -func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) { - if account == nil { - return - } - if account.Credentials == nil { - account.Credentials = make(map[string]any) - } - if strings.TrimSpace(accessToken) != "" { - account.Credentials["access_token"] = accessToken - } - if strings.TrimSpace(refreshToken) != "" { - account.Credentials["refresh_token"] = refreshToken - } - if strings.TrimSpace(expiresAt) != "" { - account.Credentials["expires_at"] = expiresAt - } - if strings.TrimSpace(sessionToken) != "" { - account.Credentials["session_token"] = sessionToken - } - - if c.accountRepo != nil { - if err := c.accountRepo.Update(ctx, account); err != nil { - if c.debugEnabled() { - c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) - } - } - } - c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken) -} - -func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) { - if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 { - return - } - updates := make(map[string]any) - if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" { - updates["access_token"] = accessToken - updates["refresh_token"] = refreshToken - } - if strings.TrimSpace(sessionToken) != "" { - updates["session_token"] = sessionToken - } - if len(updates) == 0 { - return - } - if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() { - c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) - } -} - -func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) { - if !c.debugEnabled() || account == nil { - return - } - if success { - c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) - return - } - if err == nil { - c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) - return - } - c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error())) -} - -func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool { - if c == nil || c.tokenProvider == nil { - return false - } - if account != nil && account.Platform == PlatformSora { - return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider - } - return true -} - -func (c *SoraDirectClient) logTokenSource(account *Account, source string) { - if !c.debugEnabled() || account == nil { - return - } - c.debugLogf( - "token_selected account_id=%d platform=%s account_type=%s source=%s", - account.ID, - account.Platform, - account.Type, - source, - ) -} - -func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header { - headers := http.Header{} - if token != "" { - headers.Set("Authorization", "Bearer "+token) - } - if userAgent != "" { - headers.Set("User-Agent", userAgent) - } - if c != nil && c.cfg != nil { - for key, value := range c.cfg.Sora.Client.Headers { - if strings.EqualFold(key, "authorization") || strings.EqualFold(key, "openai-sentinel-token") { - continue - } - headers.Set(key, value) - } - } - return headers -} - -func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) { - return c.doRequestWithProxy(ctx, account, c.resolveProxyURL(account), method, urlStr, headers, body, allowRetry) -} - -func (c *SoraDirectClient) doRequestWithProxy( - ctx context.Context, - account *Account, - proxyURL string, - method, - urlStr string, - headers http.Header, - body io.Reader, - allowRetry bool, -) ([]byte, http.Header, error) { - if strings.TrimSpace(urlStr) == "" { - return nil, nil, errors.New("empty upstream url") - } - proxyURL = strings.TrimSpace(proxyURL) - if proxyURL == "" { - proxyURL = c.resolveProxyURL(account) - } - if cooldownErr := c.checkCloudflareChallengeCooldown(account, proxyURL); cooldownErr != nil { - return nil, nil, cooldownErr - } - traceID, traceProxyKey, traceUAHash := c.requestTraceFields(ctx, proxyURL, headers.Get("User-Agent")) - timeout := 0 - if c != nil && c.cfg != nil { - timeout = c.cfg.Sora.Client.TimeoutSeconds - } - if timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - } - maxRetries := 0 - if allowRetry && c != nil && c.cfg != nil { - maxRetries = c.cfg.Sora.Client.MaxRetries - } - if maxRetries < 0 { - maxRetries = 0 - } - - var bodyBytes []byte - if body != nil { - b, err := io.ReadAll(body) - if err != nil { - return nil, nil, err - } - bodyBytes = b - } - - attempts := maxRetries + 1 - authRecovered := false - authRecoverExtraAttemptGranted := false - challengeRetried := false - sawCFChallenge := false - var lastErr error - for attempt := 1; attempt <= attempts; attempt++ { - if c.debugEnabled() { - c.debugLogf( - "request_start trace_id=%s method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t proxy_key=%s ua_hash=%s headers=%s", - traceID, - method, - sanitizeSoraLogURL(urlStr), - attempt, - attempts, - timeout, - len(bodyBytes), - proxyURL != "", - traceProxyKey, - traceUAHash, - formatSoraHeaders(headers), - ) - } - - var reader io.Reader - if bodyBytes != nil { - reader = bytes.NewReader(bodyBytes) - } - req, err := http.NewRequestWithContext(ctx, method, urlStr, reader) - if err != nil { - return nil, nil, err - } - req.Header = headers.Clone() - start := time.Now() - - resp, err := c.doHTTP(req, proxyURL, account) - if err != nil { - lastErr = err - if c.debugEnabled() { - c.debugLogf( - "request_transport_error trace_id=%s method=%s url=%s attempt=%d/%d err=%s", - traceID, - method, - sanitizeSoraLogURL(urlStr), - attempt, - attempts, - logredact.RedactText(err.Error()), - ) - } - if attempt < attempts && allowRetry { - if c.debugEnabled() { - c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=transport_error next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), attempt+1, attempts) - } - c.sleepRetry(attempt) - continue - } - return nil, nil, err - } - - respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - if readErr != nil { - return nil, resp.Header, readErr - } - - if c.cfg != nil && c.cfg.Sora.Client.Debug { - c.debugLogf( - "response_received trace_id=%s method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s", - traceID, - method, - sanitizeSoraLogURL(urlStr), - attempt, - attempts, - resp.StatusCode, - time.Since(start), - len(respBody), - formatSoraHeaders(resp.Header), - ) - } - - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - isCFChallenge := soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, respBody) - if isCFChallenge { - sawCFChallenge = true - c.recordCloudflareChallengeCooldown(account, proxyURL, resp.StatusCode, resp.Header, respBody) - if allowRetry && attempt < attempts && !challengeRetried { - challengeRetried = true - if c.debugEnabled() { - c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=cloudflare_challenge status=%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts) - } - c.sleepRetry(attempt) - continue - } - } - if !isCFChallenge && !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil { - if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" { - headers.Set("Authorization", "Bearer "+recovered) - authRecovered = true - if attempt == attempts && !authRecoverExtraAttemptGranted { - attempts++ - authRecoverExtraAttemptGranted = true - } - if c.debugEnabled() { - c.debugLogf("request_retry_with_recovered_token trace_id=%s method=%s url=%s status=%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode) - } - continue - } else if recoverErr != nil && c.debugEnabled() { - c.debugLogf("request_recover_token_failed trace_id=%s method=%s url=%s status=%d err=%s", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error())) - } - } - if c.debugEnabled() { - c.debugLogf( - "response_non_success trace_id=%s method=%s url=%s attempt=%d/%d status=%d body=%s", - traceID, - method, - sanitizeSoraLogURL(urlStr), - attempt, - attempts, - resp.StatusCode, - summarizeSoraResponseBody(respBody, 512), - ) - } - upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr) - lastErr = upstreamErr - if isCFChallenge { - return nil, resp.Header, upstreamErr - } - if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) { - if c.debugEnabled() { - c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=status_%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts) - } - c.sleepRetry(attempt) - continue - } - return nil, resp.Header, upstreamErr - } - if sawCFChallenge { - c.clearCloudflareChallengeCooldown(account, proxyURL) - } - return respBody, resp.Header, nil - } - if lastErr != nil { - return nil, nil, lastErr - } - return nil, nil, errors.New("upstream retries exhausted") -} - -func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool { - switch statusCode { - case http.StatusUnauthorized, http.StatusForbidden: - parsed, err := url.Parse(strings.TrimSpace(rawURL)) - if err != nil { - return false - } - host := strings.ToLower(parsed.Hostname()) - if host != "sora.chatgpt.com" && host != "chatgpt.com" { - return false - } - // 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。 - path := strings.ToLower(strings.TrimSpace(parsed.Path)) - if path == "/api/auth/session" { - return false - } - return true - default: - return false - } -} - -func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { - if c != nil && c.cfg != nil && c.cfg.Sora.Client.CurlCFFISidecar.Enabled { - resp, err := c.doHTTPViaCurlCFFISidecar(req, proxyURL, account) - if err != nil { - return nil, err - } - return resp, nil - } - - enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint - if c.httpUpstream != nil { - accountID := int64(0) - accountConcurrency := 0 - if account != nil { - accountID = account.ID - accountConcurrency = account.Concurrency - } - return c.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS) - } - return http.DefaultClient.Do(req) -} - -func (c *SoraDirectClient) sleepRetry(attempt int) { - backoff := time.Duration(attempt*attempt) * time.Second - if backoff > 10*time.Second { - backoff = 10 * time.Second - } - time.Sleep(backoff) -} - -func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error { - msg := strings.TrimSpace(extractUpstreamErrorMessage(body)) - msg = sanitizeUpstreamErrorMessage(msg) - if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") { - if hint := soraBaseURLNotFoundHint(requestURL); hint != "" { - msg = strings.TrimSpace(msg + " " + hint) - } - } - if msg == "" { - msg = truncateForLog(body, 256) - } - return &SoraUpstreamError{ - StatusCode: status, - Message: msg, - Headers: headers, - Body: body, - } -} - -func normalizeSoraBaseURL(raw string) string { - trimmed := strings.TrimRight(strings.TrimSpace(raw), "/") - if trimmed == "" { - return "" - } - parsed, err := url.Parse(trimmed) - if err != nil || parsed.Scheme == "" || parsed.Host == "" { - return trimmed - } - host := strings.ToLower(parsed.Hostname()) - if host != "sora.chatgpt.com" && host != "chatgpt.com" { - return trimmed - } - pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/") - switch pathVal { - case "", "/": - parsed.Path = "/backend" - case "/backend-api": - parsed.Path = "/backend" - } - return strings.TrimRight(parsed.String(), "/") -} - -func soraBaseURLNotFoundHint(requestURL string) string { - parsed, err := url.Parse(strings.TrimSpace(requestURL)) - if err != nil || parsed.Host == "" { - return "" - } - host := strings.ToLower(parsed.Hostname()) - if host != "sora.chatgpt.com" && host != "chatgpt.com" { - return "" - } - pathVal := strings.TrimSpace(parsed.Path) - if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" { - return "" - } - return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)" -} - -func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken, userAgent, proxyURL string) (string, error) { - reqID := uuid.NewString() - userAgent = strings.TrimSpace(userAgent) - if userAgent == "" { - userAgent = c.taskUserAgent() - } - powToken := soraPowTokenGenerator(userAgent) - payload := map[string]any{ - "p": powToken, - "flow": soraSentinelFlow, - "id": reqID, - } - body, err := json.Marshal(payload) - if err != nil { - return "", err - } - headers := http.Header{} - headers.Set("Accept", "application/json, text/plain, */*") - headers.Set("Content-Type", "application/json") - headers.Set("Origin", "https://sora.chatgpt.com") - headers.Set("Referer", "https://sora.chatgpt.com/") - headers.Set("User-Agent", userAgent) - if accessToken != "" { - headers.Set("Authorization", "Bearer "+accessToken) - } - - urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req" - respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, urlStr, headers, bytes.NewReader(body), true) - if err != nil { - return "", err - } - var resp map[string]any - if err := json.Unmarshal(respBody, &resp); err != nil { - return "", err - } - - sentinel := soraBuildSentinelToken(soraSentinelFlow, reqID, powToken, resp, userAgent) - if sentinel == "" { - return "", errors.New("failed to build sentinel token") - } - return sentinel, nil -} - -func soraGetPowToken(userAgent string) string { - configList := soraBuildPowConfig(userAgent) - seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64) - difficulty := "0fffff" - solution, _ := soraSolvePow(seed, difficulty, configList) - return "gAAAAAC" + solution -} - -func soraRandFloat() float64 { - soraRandMu.Lock() - defer soraRandMu.Unlock() - return soraRand.Float64() -} - -func soraRandInt(max int) int { - if max <= 1 { - return 0 - } - soraRandMu.Lock() - defer soraRandMu.Unlock() - return soraRand.Intn(max) -} - -func soraBuildPowConfig(userAgent string) []any { - userAgent = strings.TrimSpace(userAgent) - if userAgent == "" && len(soraDesktopUserAgents) > 0 { - userAgent = soraDesktopUserAgents[0] - } - screenVal := soraStableChoiceInt([]int{ - 1920 + 1080, - 2560 + 1440, - 1920 + 1200, - 2560 + 1600, - }, userAgent+"|screen") - perfMs := float64(time.Since(soraPerfStart).Milliseconds()) - wallMs := float64(time.Now().UnixNano()) / 1e6 - diff := wallMs - perfMs - return []any{ - screenVal, - soraPowParseTime(), - 4294705152, - 0, - userAgent, - soraStableChoice(soraPowScripts, userAgent+"|script"), - soraStableChoice(soraPowDPL, userAgent+"|dpl"), - "en-US", - "en-US,es-US,en,es", - 0, - soraStableChoice(soraPowNavigatorKeys, userAgent+"|navigator"), - soraStableChoice(soraPowDocumentKeys, userAgent+"|document"), - soraStableChoice(soraPowWindowKeys, userAgent+"|window"), - perfMs, - uuid.NewString(), - "", - soraStableChoiceInt(soraPowCores, userAgent+"|cores"), - diff, - } -} - -func soraStableChoice(items []string, seed string) string { - if len(items) == 0 { - return "" - } - idx := soraStableIndex(seed, len(items)) - return items[idx] -} - -func soraStableChoiceInt(items []int, seed string) int { - if len(items) == 0 { - return 0 - } - idx := soraStableIndex(seed, len(items)) - return items[idx] -} - -func soraStableIndex(seed string, size int) int { - if size <= 0 { - return 0 - } - h := fnv.New32a() - _, _ = h.Write([]byte(seed)) - return int(h.Sum32() % uint32(size)) -} - -func soraPowParseTime() string { - loc := time.FixedZone("EST", -5*3600) - return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)") -} - -func soraSolvePow(seed, difficulty string, configList []any) (string, bool) { - diffLen := len(difficulty) / 2 - target, err := hexDecodeString(difficulty) - if err != nil { - return "", false - } - seedBytes := []byte(seed) - - part1 := mustMarshalJSON(configList[:3]) - part2 := mustMarshalJSON(configList[4:9]) - part3 := mustMarshalJSON(configList[10:]) - - staticPart1 := append(part1[:len(part1)-1], ',') - staticPart2 := append([]byte(","), append(part2[1:len(part2)-1], ',')...) - staticPart3 := append([]byte(","), part3[1:]...) - - for i := 0; i < soraPowMaxIteration; i++ { - dynamicI := []byte(strconv.Itoa(i)) - dynamicJ := []byte(strconv.Itoa(i >> 1)) - finalJSON := make([]byte, 0, len(staticPart1)+len(dynamicI)+len(staticPart2)+len(dynamicJ)+len(staticPart3)) - finalJSON = append(finalJSON, staticPart1...) - finalJSON = append(finalJSON, dynamicI...) - finalJSON = append(finalJSON, staticPart2...) - finalJSON = append(finalJSON, dynamicJ...) - finalJSON = append(finalJSON, staticPart3...) - - b64 := base64.StdEncoding.EncodeToString(finalJSON) - hash := sha3.Sum512(append(seedBytes, []byte(b64)...)) - if bytes.Compare(hash[:diffLen], target[:diffLen]) <= 0 { - return b64, true - } - } - - errorToken := "wQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\"%s\"", seed))) - return errorToken, false -} - -func soraBuildSentinelToken(flow, reqID, powToken string, resp map[string]any, userAgent string) string { - finalPow := powToken - proof, _ := resp["proofofwork"].(map[string]any) - if required, _ := proof["required"].(bool); required { - seed, _ := proof["seed"].(string) - difficulty, _ := proof["difficulty"].(string) - if seed != "" && difficulty != "" { - configList := soraBuildPowConfig(userAgent) - solution, _ := soraSolvePow(seed, difficulty, configList) - finalPow = "gAAAAAB" + solution - } - } - if !strings.HasSuffix(finalPow, "~S") { - finalPow += "~S" - } - turnstile, _ := resp["turnstile"].(map[string]any) - tokenPayload := map[string]any{ - "p": finalPow, - "t": safeMapString(turnstile, "dx"), - "c": safeString(resp["token"]), - "id": reqID, - "flow": flow, - } - encoded, _ := json.Marshal(tokenPayload) - return string(encoded) -} - -func safeMapString(m map[string]any, key string) string { - if m == nil { - return "" - } - if v, ok := m[key]; ok { - return safeString(v) - } - return "" -} - -func safeString(v any) string { - switch val := v.(type) { - case string: - return val - default: - return fmt.Sprintf("%v", val) - } -} - -func mustMarshalJSON(v any) []byte { - b, _ := json.Marshal(v) - return b -} - -func hexDecodeString(s string) ([]byte, error) { - dst := make([]byte, len(s)/2) - _, err := hex.Decode(dst, []byte(s)) - return dst, err -} - -func (c *SoraDirectClient) withRequestTrace(ctx context.Context, account *Account, proxyURL, userAgent string) context.Context { - if ctx == nil { - ctx = context.Background() - } - if existing, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && existing != nil && existing.ID != "" { - return ctx - } - accountID := int64(0) - if account != nil { - accountID = account.ID - } - seed := fmt.Sprintf("%d|%s|%s|%d", accountID, normalizeSoraProxyKey(proxyURL), strings.TrimSpace(userAgent), time.Now().UnixNano()) - trace := &soraRequestTrace{ - ID: "sora-" + soraHashForLog(seed), - ProxyKey: normalizeSoraProxyKey(proxyURL), - UAHash: soraHashForLog(strings.TrimSpace(userAgent)), - } - return context.WithValue(ctx, soraRequestTraceContextKey{}, trace) -} - -func (c *SoraDirectClient) requestTraceFields(ctx context.Context, proxyURL, userAgent string) (string, string, string) { - proxyKey := normalizeSoraProxyKey(proxyURL) - uaHash := soraHashForLog(strings.TrimSpace(userAgent)) - traceID := "" - if ctx != nil { - if trace, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && trace != nil { - if strings.TrimSpace(trace.ID) != "" { - traceID = strings.TrimSpace(trace.ID) - } - if strings.TrimSpace(trace.ProxyKey) != "" { - proxyKey = strings.TrimSpace(trace.ProxyKey) - } - if strings.TrimSpace(trace.UAHash) != "" { - uaHash = strings.TrimSpace(trace.UAHash) - } - } - } - if traceID == "" { - traceID = "sora-" + soraHashForLog(fmt.Sprintf("%s|%d", proxyKey, time.Now().UnixNano())) - } - return traceID, proxyKey, uaHash -} - -func soraHashForLog(raw string) string { - h := fnv.New32a() - _, _ = h.Write([]byte(raw)) - return fmt.Sprintf("%08x", h.Sum32()) -} - -func sanitizeSoraLogURL(raw string) string { - parsed, err := url.Parse(raw) - if err != nil { - return raw - } - q := parsed.Query() - q.Del("sig") - q.Del("expires") - parsed.RawQuery = q.Encode() - return parsed.String() -} - -func (c *SoraDirectClient) debugEnabled() bool { - return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug -} - -func (c *SoraDirectClient) debugLogf(format string, args ...any) { - if !c.debugEnabled() { - return - } - log.Printf("[SoraClient] "+format, args...) -} - -func formatSoraHeaders(headers http.Header) string { - if len(headers) == 0 { - return "{}" - } - keys := make([]string, 0, len(headers)) - for key := range headers { - keys = append(keys, key) - } - sort.Strings(keys) - out := make(map[string]string, len(keys)) - for _, key := range keys { - values := headers.Values(key) - if len(values) == 0 { - continue - } - val := strings.Join(values, ",") - if isSensitiveHeader(key) { - out[key] = "***" - continue - } - out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160) - } - encoded, err := json.Marshal(out) - if err != nil { - return "{}" - } - return string(encoded) -} - -func isSensitiveHeader(key string) bool { - k := strings.ToLower(strings.TrimSpace(key)) - switch k { - case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key": - return true - default: - return false - } -} - -func summarizeSoraResponseBody(body []byte, maxLen int) string { - if len(body) == 0 { - return "" - } - var text string - if json.Valid(body) { - text = logredact.RedactJSON(body) - } else { - text = logredact.RedactText(string(body)) - } - text = strings.TrimSpace(text) - if maxLen <= 0 || len(text) <= maxLen { - return text - } - return text[:maxLen] + "...(truncated)" -} diff --git a/backend/internal/service/sora_client_gjson_test.go b/backend/internal/service/sora_client_gjson_test.go deleted file mode 100644 index d38cfa57..00000000 --- a/backend/internal/service/sora_client_gjson_test.go +++ /dev/null @@ -1,515 +0,0 @@ -//go:build unit - -package service - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" -) - -// ---------- 辅助解析函数(复制生产代码中的 gjson 解析逻辑,用于单元测试) ---------- - -// testParseUploadOrCreateTaskID 模拟 UploadImage / CreateImageTask / CreateVideoTask 中 -// 用 gjson.GetBytes(respBody, "id") 提取 id 的逻辑。 -func testParseUploadOrCreateTaskID(respBody []byte) (string, error) { - id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) - if id == "" { - return "", assert.AnError // 占位错误,表示 "missing id" - } - return id, nil -} - -// testParseFetchRecentImageTask 模拟 fetchRecentImageTask 中的 gjson.ForEach 解析逻辑。 -func testParseFetchRecentImageTask(respBody []byte, taskID string) (*SoraImageTaskStatus, bool) { - var found *SoraImageTaskStatus - gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool { - if item.Get("id").String() != taskID { - return true // continue - } - status := strings.TrimSpace(item.Get("status").String()) - progress := item.Get("progress_pct").Float() - var urls []string - item.Get("generations").ForEach(func(_, gen gjson.Result) bool { - if u := strings.TrimSpace(gen.Get("url").String()); u != "" { - urls = append(urls, u) - } - return true - }) - found = &SoraImageTaskStatus{ - ID: taskID, - Status: status, - ProgressPct: progress, - URLs: urls, - } - return false // break - }) - if found != nil { - return found, true - } - return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false -} - -// testParseGetVideoTaskPending 模拟 GetVideoTask 中解析 pending 列表的逻辑。 -func testParseGetVideoTaskPending(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) { - pendingResult := gjson.ParseBytes(respBody) - if !pendingResult.IsArray() { - return nil, false - } - var pendingFound *SoraVideoTaskStatus - pendingResult.ForEach(func(_, task gjson.Result) bool { - if task.Get("id").String() != taskID { - return true - } - progress := 0 - if v := task.Get("progress_pct"); v.Exists() { - progress = int(v.Float() * 100) - } - status := strings.TrimSpace(task.Get("status").String()) - pendingFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: status, - ProgressPct: progress, - } - return false - }) - if pendingFound != nil { - return pendingFound, true - } - return nil, false -} - -// testParseGetVideoTaskDrafts 模拟 GetVideoTask 中解析 drafts 列表的逻辑。 -func testParseGetVideoTaskDrafts(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) { - var draftFound *SoraVideoTaskStatus - gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool { - if draft.Get("task_id").String() != taskID { - return true - } - kind := strings.TrimSpace(draft.Get("kind").String()) - reason := strings.TrimSpace(draft.Get("reason_str").String()) - if reason == "" { - reason = strings.TrimSpace(draft.Get("markdown_reason_str").String()) - } - urlStr := strings.TrimSpace(draft.Get("downloadable_url").String()) - if urlStr == "" { - urlStr = strings.TrimSpace(draft.Get("url").String()) - } - - if kind == "sora_content_violation" || reason != "" || urlStr == "" { - msg := reason - if msg == "" { - msg = "Content violates guardrails" - } - draftFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: "failed", - ErrorMsg: msg, - } - } else { - draftFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: "completed", - URLs: []string{urlStr}, - } - } - return false - }) - if draftFound != nil { - return draftFound, true - } - return nil, false -} - -// ===================== Test 1: TestSoraParseUploadResponse ===================== - -func TestSoraParseUploadResponse(t *testing.T) { - tests := []struct { - name string - body string - wantID string - wantErr bool - }{ - { - name: "正常 id", - body: `{"id":"file-abc123","status":"uploaded"}`, - wantID: "file-abc123", - }, - { - name: "空 id", - body: `{"id":"","status":"uploaded"}`, - wantErr: true, - }, - { - name: "无 id 字段", - body: `{"status":"uploaded"}`, - wantErr: true, - }, - { - name: "id 全为空白", - body: `{"id":" ","status":"uploaded"}`, - wantErr: true, - }, - { - name: "id 前后有空白", - body: `{"id":" file-trimmed ","status":"uploaded"}`, - wantID: "file-trimmed", - }, - { - name: "空 JSON 对象", - body: `{}`, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - id, err := testParseUploadOrCreateTaskID([]byte(tt.body)) - if tt.wantErr { - require.Error(t, err, "应返回错误") - return - } - require.NoError(t, err) - require.Equal(t, tt.wantID, id) - }) - } -} - -// ===================== Test 2: TestSoraParseCreateTaskResponse ===================== - -func TestSoraParseCreateTaskResponse(t *testing.T) { - tests := []struct { - name string - body string - wantID string - wantErr bool - }{ - { - name: "正常任务 id", - body: `{"id":"task-123"}`, - wantID: "task-123", - }, - { - name: "缺失 id", - body: `{"status":"created"}`, - wantErr: true, - }, - { - name: "空 id", - body: `{"id":" "}`, - wantErr: true, - }, - { - name: "id 为数字(gjson 转字符串)", - body: `{"id":123}`, - wantID: "123", - }, - { - name: "id 含特殊字符", - body: `{"id":"task-abc-def-456-ghi"}`, - wantID: "task-abc-def-456-ghi", - }, - { - name: "额外字段不影响解析", - body: `{"id":"task-999","type":"image_gen","extra":"data"}`, - wantID: "task-999", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - id, err := testParseUploadOrCreateTaskID([]byte(tt.body)) - if tt.wantErr { - require.Error(t, err, "应返回错误") - return - } - require.NoError(t, err) - require.Equal(t, tt.wantID, id) - }) - } -} - -// ===================== Test 3: TestSoraParseFetchRecentImageTask ===================== - -func TestSoraParseFetchRecentImageTask(t *testing.T) { - tests := []struct { - name string - body string - taskID string - wantFound bool - wantStatus string - wantProgress float64 - wantURLs []string - }{ - { - name: "匹配已完成任务", - body: `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1.0,"generations":[{"url":"https://example.com/img.png"}]}]}`, - taskID: "task-1", - wantFound: true, - wantStatus: "completed", - wantProgress: 1.0, - wantURLs: []string{"https://example.com/img.png"}, - }, - { - name: "匹配处理中任务", - body: `{"task_responses":[{"id":"task-2","status":"processing","progress_pct":0.5,"generations":[]}]}`, - taskID: "task-2", - wantFound: true, - wantStatus: "processing", - wantProgress: 0.5, - wantURLs: nil, - }, - { - name: "无匹配任务", - body: `{"task_responses":[{"id":"other","status":"completed"}]}`, - taskID: "task-1", - wantFound: false, - wantStatus: "processing", - }, - { - name: "空 task_responses", - body: `{"task_responses":[]}`, - taskID: "task-1", - wantFound: false, - wantStatus: "processing", - }, - { - name: "缺少 task_responses 字段", - body: `{"other":"data"}`, - taskID: "task-1", - wantFound: false, - wantStatus: "processing", - }, - { - name: "多个任务中精准匹配", - body: `{"task_responses":[{"id":"task-a","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"}]},{"id":"task-b","status":"processing","progress_pct":0.3,"generations":[]},{"id":"task-c","status":"failed","progress_pct":0}]}`, - taskID: "task-b", - wantFound: true, - wantStatus: "processing", - wantProgress: 0.3, - wantURLs: nil, - }, - { - name: "多个 generations", - body: `{"task_responses":[{"id":"task-m","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"},{"url":"https://a.com/2.png"},{"url":""}]}]}`, - taskID: "task-m", - wantFound: true, - wantStatus: "completed", - wantProgress: 1.0, - wantURLs: []string{"https://a.com/1.png", "https://a.com/2.png"}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - status, found := testParseFetchRecentImageTask([]byte(tt.body), tt.taskID) - require.Equal(t, tt.wantFound, found, "found 不匹配") - require.NotNil(t, status) - require.Equal(t, tt.taskID, status.ID) - require.Equal(t, tt.wantStatus, status.Status) - if tt.wantFound { - require.InDelta(t, tt.wantProgress, status.ProgressPct, 0.001, "进度不匹配") - require.Equal(t, tt.wantURLs, status.URLs) - } - }) - } -} - -// ===================== Test 4: TestSoraParseGetVideoTaskPending ===================== - -func TestSoraParseGetVideoTaskPending(t *testing.T) { - tests := []struct { - name string - body string - taskID string - wantFound bool - wantStatus string - wantProgress int - }{ - { - name: "匹配 pending 任务", - body: `[{"id":"task-1","status":"processing","progress_pct":0.5}]`, - taskID: "task-1", - wantFound: true, - wantStatus: "processing", - wantProgress: 50, - }, - { - name: "进度为 0", - body: `[{"id":"task-2","status":"queued","progress_pct":0}]`, - taskID: "task-2", - wantFound: true, - wantStatus: "queued", - wantProgress: 0, - }, - { - name: "进度为 1(100%)", - body: `[{"id":"task-3","status":"completing","progress_pct":1.0}]`, - taskID: "task-3", - wantFound: true, - wantStatus: "completing", - wantProgress: 100, - }, - { - name: "空数组", - body: `[]`, - taskID: "task-1", - wantFound: false, - }, - { - name: "无匹配 id", - body: `[{"id":"task-other","status":"processing","progress_pct":0.3}]`, - taskID: "task-1", - wantFound: false, - }, - { - name: "多个任务精准匹配", - body: `[{"id":"task-a","status":"processing","progress_pct":0.2},{"id":"task-b","status":"queued","progress_pct":0},{"id":"task-c","status":"processing","progress_pct":0.8}]`, - taskID: "task-c", - wantFound: true, - wantStatus: "processing", - wantProgress: 80, - }, - { - name: "非数组 JSON", - body: `{"id":"task-1","status":"processing"}`, - taskID: "task-1", - wantFound: false, - }, - { - name: "无 progress_pct 字段", - body: `[{"id":"task-4","status":"pending"}]`, - taskID: "task-4", - wantFound: true, - wantStatus: "pending", - wantProgress: 0, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - status, found := testParseGetVideoTaskPending([]byte(tt.body), tt.taskID) - require.Equal(t, tt.wantFound, found, "found 不匹配") - if tt.wantFound { - require.NotNil(t, status) - require.Equal(t, tt.taskID, status.ID) - require.Equal(t, tt.wantStatus, status.Status) - require.Equal(t, tt.wantProgress, status.ProgressPct) - } - }) - } -} - -// ===================== Test 5: TestSoraParseGetVideoTaskDrafts ===================== - -func TestSoraParseGetVideoTaskDrafts(t *testing.T) { - tests := []struct { - name string - body string - taskID string - wantFound bool - wantStatus string - wantURLs []string - wantErr string - }{ - { - name: "正常完成的视频", - body: `{"items":[{"task_id":"task-1","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`, - taskID: "task-1", - wantFound: true, - wantStatus: "completed", - wantURLs: []string{"https://example.com/video.mp4"}, - }, - { - name: "使用 url 字段回退", - body: `{"items":[{"task_id":"task-2","kind":"video","url":"https://example.com/fallback.mp4"}]}`, - taskID: "task-2", - wantFound: true, - wantStatus: "completed", - wantURLs: []string{"https://example.com/fallback.mp4"}, - }, - { - name: "内容违规", - body: `{"items":[{"task_id":"task-3","kind":"sora_content_violation","reason_str":"Content policy violation"}]}`, - taskID: "task-3", - wantFound: true, - wantStatus: "failed", - wantErr: "Content policy violation", - }, - { - name: "内容违规 - markdown_reason_str 回退", - body: `{"items":[{"task_id":"task-4","kind":"sora_content_violation","markdown_reason_str":"Markdown reason"}]}`, - taskID: "task-4", - wantFound: true, - wantStatus: "failed", - wantErr: "Markdown reason", - }, - { - name: "内容违规 - 无 reason 使用默认消息", - body: `{"items":[{"task_id":"task-5","kind":"sora_content_violation"}]}`, - taskID: "task-5", - wantFound: true, - wantStatus: "failed", - wantErr: "Content violates guardrails", - }, - { - name: "有 reason_str 但非 violation kind(仍判定失败)", - body: `{"items":[{"task_id":"task-6","kind":"video","reason_str":"Some error occurred"}]}`, - taskID: "task-6", - wantFound: true, - wantStatus: "failed", - wantErr: "Some error occurred", - }, - { - name: "空 URL 判定为失败", - body: `{"items":[{"task_id":"task-7","kind":"video","downloadable_url":"","url":""}]}`, - taskID: "task-7", - wantFound: true, - wantStatus: "failed", - wantErr: "Content violates guardrails", - }, - { - name: "无匹配 task_id", - body: `{"items":[{"task_id":"task-other","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`, - taskID: "task-1", - wantFound: false, - }, - { - name: "空 items", - body: `{"items":[]}`, - taskID: "task-1", - wantFound: false, - }, - { - name: "缺少 items 字段", - body: `{"other":"data"}`, - taskID: "task-1", - wantFound: false, - }, - { - name: "多个 items 精准匹配", - body: `{"items":[{"task_id":"task-a","kind":"video","downloadable_url":"https://a.com/a.mp4"},{"task_id":"task-b","kind":"sora_content_violation","reason_str":"Bad content"},{"task_id":"task-c","kind":"video","downloadable_url":"https://c.com/c.mp4"}]}`, - taskID: "task-b", - wantFound: true, - wantStatus: "failed", - wantErr: "Bad content", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - status, found := testParseGetVideoTaskDrafts([]byte(tt.body), tt.taskID) - require.Equal(t, tt.wantFound, found, "found 不匹配") - if !tt.wantFound { - return - } - require.NotNil(t, status) - require.Equal(t, tt.taskID, status.ID) - require.Equal(t, tt.wantStatus, status.Status) - if tt.wantErr != "" { - require.Equal(t, tt.wantErr, status.ErrorMsg) - } - if tt.wantURLs != nil { - require.Equal(t, tt.wantURLs, status.URLs) - } - }) - } -} diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go deleted file mode 100644 index cffe8a35..00000000 --- a/backend/internal/service/sora_client_test.go +++ /dev/null @@ -1,1075 +0,0 @@ -//go:build unit - -package service - -import ( - "context" - "encoding/base64" - "encoding/json" - "errors" - "io" - "net/http" - "net/http/httptest" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/stretchr/testify/require" -) - -func TestSoraDirectClient_DoRequestSuccess(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"ok":true}`)) - })) - defer server.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{BaseURL: server.URL}, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - - body, _, err := client.doRequest(context.Background(), &Account{ID: 1}, http.MethodGet, server.URL, http.Header{}, nil, false) - require.NoError(t, err) - require.Contains(t, string(body), "ok") -} - -func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) { - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - Headers: map[string]string{ - "X-Test": "yes", - "Authorization": "should-ignore", - "openai-sentinel-token": "skip", - }, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - - headers := client.buildBaseHeaders("token-123", "UA") - require.Equal(t, "Bearer token-123", headers.Get("Authorization")) - require.Equal(t, "UA", headers.Get("User-Agent")) - require.Equal(t, "yes", headers.Get("X-Test")) - require.Empty(t, headers.Get("openai-sentinel-token")) -} - -func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - limit := r.URL.Query().Get("limit") - w.Header().Set("Content-Type", "application/json") - switch limit { - case "1": - _, _ = w.Write([]byte(`{"task_responses":[]}`)) - case "2": - _, _ = w.Write([]byte(`{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1,"generations":[{"url":"https://example.com/a.png"}]}]}`)) - default: - _, _ = w.Write([]byte(`{"task_responses":[]}`)) - } - })) - defer server.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: server.URL, - RecentTaskLimit: 1, - RecentTaskLimitMax: 2, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - account := &Account{Credentials: map[string]any{"access_token": "token"}} - - status, err := client.GetImageTask(context.Background(), account, "task-1") - require.NoError(t, err) - require.Equal(t, "completed", status.Status) - require.Equal(t, []string{"https://example.com/a.png"}, status.URLs) -} - -func TestNormalizeSoraBaseURL(t *testing.T) { - t.Parallel() - tests := []struct { - name string - raw string - want string - }{ - { - name: "empty", - raw: "", - want: "", - }, - { - name: "append_backend_for_sora_host", - raw: "https://sora.chatgpt.com", - want: "https://sora.chatgpt.com/backend", - }, - { - name: "convert_backend_api_to_backend", - raw: "https://sora.chatgpt.com/backend-api", - want: "https://sora.chatgpt.com/backend", - }, - { - name: "keep_backend", - raw: "https://sora.chatgpt.com/backend", - want: "https://sora.chatgpt.com/backend", - }, - { - name: "keep_custom_host", - raw: "https://example.com/custom-path", - want: "https://example.com/custom-path", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := normalizeSoraBaseURL(tt.raw) - require.Equal(t, tt.want, got) - }) - } -} - -func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) { - t.Parallel() - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com", - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen")) -} - -func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) { - t.Parallel() - client := NewSoraDirectClient(&config.Config{}, nil, nil) - - err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen") - var upstreamErr *SoraUpstreamError - require.ErrorAs(t, err, &upstreamErr) - require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url") - - errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen") - require.ErrorAs(t, errNoHint, &upstreamErr) - require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url") -} - -func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) { - t.Parallel() - headers := http.Header{} - headers.Set("Authorization", "Bearer secret-token") - headers.Set("openai-sentinel-token", "sentinel-secret") - headers.Set("X-Test", "ok") - - out := formatSoraHeaders(headers) - require.Contains(t, out, `"Authorization":"***"`) - require.Contains(t, out, `Sentinel-Token":"***"`) - require.Contains(t, out, `"X-Test":"ok"`) - require.NotContains(t, out, "secret-token") - require.NotContains(t, out, "sentinel-secret") -} - -func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) { - t.Parallel() - body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`) - out := summarizeSoraResponseBody(body, 512) - require.Contains(t, out, `"access_token":"***"`) - require.NotContains(t, out, "abc123") -} - -func TestSummarizeSoraResponseBody_Truncates(t *testing.T) { - t.Parallel() - body := []byte(strings.Repeat("x", 100)) - out := summarizeSoraResponseBody(body, 10) - require.Contains(t, out, "(truncated)") -} - -func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) { - t.Parallel() - cache := newOpenAITokenCacheStub() - provider := NewOpenAITokenProvider(nil, cache, nil) - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - }, - }, - } - client := NewSoraDirectClient(cfg, nil, provider) - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "access_token": "sora-credential-token", - }, - } - - token, err := client.getAccessToken(context.Background(), account) - require.NoError(t, err) - require.Equal(t, "sora-credential-token", token) - require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled)) -} - -func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) { - t.Parallel() - cache := newOpenAITokenCacheStub() - account := &Account{ - ID: 2, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "access_token": "sora-credential-token", - }, - } - cache.tokens[OpenAITokenCacheKey(account)] = "provider-token" - provider := NewOpenAITokenProvider(nil, cache, nil) - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - UseOpenAITokenProvider: true, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, provider) - - token, err := client.getAccessToken(context.Background(), account) - require.NoError(t, err) - require.Equal(t, "provider-token", token) - require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0)) -} - -func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token") - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "accessToken": "session-access-token", - "expires": "2099-01-01T00:00:00Z", - }) - })) - defer server.Close() - - origin := soraSessionAuthURL - soraSessionAuthURL = server.URL - defer func() { soraSessionAuthURL = origin }() - - client := NewSoraDirectClient(&config.Config{}, nil, nil) - account := &Account{ - ID: 10, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "session_token": "session-token", - }, - } - - token, err := client.getAccessToken(context.Background(), account) - require.NoError(t, err) - require.Equal(t, "session-access-token", token) - require.Equal(t, "session-access-token", account.GetCredential("access_token")) -} - -func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "/oauth/token", r.URL.Path) - require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) - require.NoError(t, r.ParseForm()) - require.Equal(t, "refresh_token", r.FormValue("grant_type")) - require.Equal(t, "refresh-token-old", r.FormValue("refresh_token")) - require.NotEmpty(t, r.FormValue("client_id")) - require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri")) - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "access_token": "refresh-access-token", - "refresh_token": "refresh-token-new", - "expires_in": 3600, - }) - })) - defer server.Close() - - origin := soraOAuthTokenURL - soraOAuthTokenURL = server.URL + "/oauth/token" - defer func() { soraOAuthTokenURL = origin }() - - client := NewSoraDirectClient(&config.Config{}, nil, nil) - account := &Account{ - ID: 11, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "refresh_token": "refresh-token-old", - }, - } - - token, err := client.getAccessToken(context.Background(), account) - require.NoError(t, err) - require.Equal(t, "refresh-access-token", token) - require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token")) - require.NotNil(t, account.GetCredentialAsTime("expires_at")) -} - -func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) { - t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Equal(t, "/nf/check", r.URL.Path) - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "rate_limit_and_credit_balance": map[string]any{ - "estimated_num_videos_remaining": 0, - "rate_limit_reached": true, - }, - }) - })) - defer server.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: server.URL, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - account := &Account{ - ID: 12, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "access_token": "ok", - "expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339), - }, - } - err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"}) - require.Error(t, err) - var upstreamErr *SoraUpstreamError - require.ErrorAs(t, err, &upstreamErr) - require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode) -} - -func TestShouldAttemptSoraTokenRecover(t *testing.T) { - t.Parallel() - - require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen")) - require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen")) - require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session")) - require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token")) - require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen")) -} - -type soraClientRequestCall struct { - Path string - UserAgent string - ProxyURL string -} - -type soraClientRecordingUpstream struct { - calls []soraClientRequestCall -} - -func (u *soraClientRecordingUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - return nil, errors.New("unexpected Do call") -} - -func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL string, _ int64, _ int, _ bool) (*http.Response, error) { - u.calls = append(u.calls, soraClientRequestCall{ - Path: req.URL.Path, - UserAgent: req.Header.Get("User-Agent"), - ProxyURL: proxyURL, - }) - switch req.URL.Path { - case "/backend-api/sentinel/req": - return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil - case "/backend/nf/create": - return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil - case "/backend/nf/create/storyboard": - return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil - case "/backend/uploads": - return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil - case "/backend/nf/check": - return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil - case "/backend/characters/upload": - return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil - case "/backend/project_y/cameos/in_progress/cameo-123": - return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil - case "/backend/project_y/file/upload": - return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil - case "/backend/characters/finalize": - return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil - case "/backend/project_y/post": - return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil - default: - return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil - } -} - -func newSoraClientMockResponse(statusCode int, body string) *http.Response { - return &http.Response{ - StatusCode: statusCode, - Header: make(http.Header), - Body: io.NopCloser(strings.NewReader(body)), - } -} - -func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) { - client := NewSoraDirectClient(&config.Config{}, nil, nil) - ua := client.taskUserAgent() - require.NotEmpty(t, ua) - allowed := append([]string{}, soraMobileUserAgents...) - allowed = append(allowed, soraDesktopUserAgents...) - require.Contains(t, allowed, ua) -} - -func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) { - originPowTokenGenerator := soraPowTokenGenerator - soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } - defer func() { - soraPowTokenGenerator = originPowTokenGenerator - }() - - upstream := &soraClientRecordingUpstream{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - }, - }, - } - client := NewSoraDirectClient(cfg, upstream, nil) - proxyID := int64(9) - account := &Account{ - ID: 21, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - ProxyID: &proxyID, - Proxy: &Proxy{ - Protocol: "http", - Host: "127.0.0.1", - Port: 8080, - }, - Credentials: map[string]any{ - "access_token": "access-token", - "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), - }, - } - - taskID, err := client.CreateVideoTask(context.Background(), account, SoraVideoRequest{Prompt: "test"}) - require.NoError(t, err) - require.Equal(t, "task-123", taskID) - require.Len(t, upstream.calls, 2) - - sentinelCall := upstream.calls[0] - createCall := upstream.calls[1] - require.Equal(t, "/backend-api/sentinel/req", sentinelCall.Path) - require.Equal(t, "/backend/nf/create", createCall.Path) - require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL) - require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL) - require.NotEmpty(t, sentinelCall.UserAgent) - require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent) -} - -func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) { - upstream := &soraClientRecordingUpstream{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - }, - }, - } - client := NewSoraDirectClient(cfg, upstream, nil) - proxyID := int64(3) - account := &Account{ - ID: 31, - ProxyID: &proxyID, - Proxy: &Proxy{ - Protocol: "http", - Host: "127.0.0.1", - Port: 8080, - }, - Credentials: map[string]any{ - "access_token": "access-token", - "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), - }, - } - - uploadID, err := client.UploadImage(context.Background(), account, []byte("mock-image"), "a.png") - require.NoError(t, err) - require.Equal(t, "upload-123", uploadID) - require.Len(t, upstream.calls, 1) - require.Equal(t, "/backend/uploads", upstream.calls[0].Path) - require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) - require.NotEmpty(t, upstream.calls[0].UserAgent) -} - -func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) { - upstream := &soraClientRecordingUpstream{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - }, - }, - } - client := NewSoraDirectClient(cfg, upstream, nil) - proxyID := int64(7) - account := &Account{ - ID: 41, - ProxyID: &proxyID, - Proxy: &Proxy{ - Protocol: "http", - Host: "127.0.0.1", - Port: 8080, - }, - Credentials: map[string]any{ - "access_token": "access-token", - "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), - }, - } - - err := client.PreflightCheck(context.Background(), account, "sora2", SoraModelConfig{Type: "video"}) - require.NoError(t, err) - require.Len(t, upstream.calls, 1) - require.Equal(t, "/backend/nf/check", upstream.calls[0].Path) - require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) - require.NotEmpty(t, upstream.calls[0].UserAgent) -} - -func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) { - originPowTokenGenerator := soraPowTokenGenerator - soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } - defer func() { soraPowTokenGenerator = originPowTokenGenerator }() - - upstream := &soraClientRecordingUpstream{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - }, - }, - } - client := NewSoraDirectClient(cfg, upstream, nil) - account := &Account{ - ID: 51, - Credentials: map[string]any{ - "access_token": "access-token", - "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), - }, - } - - taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{ - Prompt: "Shot 1:\nduration: 5sec\nScene: cat", - }) - require.NoError(t, err) - require.Equal(t, "storyboard-123", taskID) - require.Len(t, upstream.calls, 2) - require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path) - require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path) -} - -func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - switch r.URL.Path { - case "/nf/pending/v2": - _, _ = w.Write([]byte(`[]`)) - case "/project_y/profile/drafts": - _, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`)) - default: - http.NotFound(w, r) - } - })) - defer server.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: server.URL, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - account := &Account{Credentials: map[string]any{"access_token": "token"}} - - status, err := client.GetVideoTask(context.Background(), account, "task-1") - require.NoError(t, err) - require.Equal(t, "completed", status.Status) - require.Equal(t, "gen_1", status.GenerationID) - require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs) -} - -func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) { - originPowTokenGenerator := soraPowTokenGenerator - soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } - defer func() { soraPowTokenGenerator = originPowTokenGenerator }() - - upstream := &soraClientRecordingUpstream{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - }, - }, - } - client := NewSoraDirectClient(cfg, upstream, nil) - account := &Account{ - ID: 52, - Credentials: map[string]any{ - "access_token": "access-token", - "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), - }, - } - - postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1") - require.NoError(t, err) - require.Equal(t, "s_post", postID) - require.Len(t, upstream.calls, 2) - require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path) - require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path) -} - -type soraClientFallbackUpstream struct { - doWithTLSCalls int32 - respBody string - respStatusCode int - err error -} - -func (u *soraClientFallbackUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - return nil, errors.New("unexpected Do call") -} - -func (u *soraClientFallbackUpstream) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { - atomic.AddInt32(&u.doWithTLSCalls, 1) - if u.err != nil { - return nil, u.err - } - statusCode := u.respStatusCode - if statusCode <= 0 { - statusCode = http.StatusOK - } - body := u.respBody - if body == "" { - body = `{"ok":true}` - } - return newSoraClientMockResponse(statusCode, body), nil -} - -func TestSoraDirectClient_DoHTTP_UsesCurlCFFISidecarWhenEnabled(t *testing.T) { - var captured soraCurlCFFISidecarRequest - sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "/request", r.URL.Path) - raw, err := io.ReadAll(r.Body) - require.NoError(t, err) - require.NoError(t, json.Unmarshal(raw, &captured)) - _ = json.NewEncoder(w).Encode(map[string]any{ - "status_code": http.StatusOK, - "headers": map[string]any{ - "Content-Type": "application/json", - "X-Sidecar": []string{"yes"}, - }, - "body_base64": base64.StdEncoding.EncodeToString([]byte(`{"ok":true}`)), - }) - })) - defer sidecar.Close() - - upstream := &soraClientFallbackUpstream{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: true, - BaseURL: sidecar.URL, - Impersonate: "chrome131", - TimeoutSeconds: 15, - SessionReuseEnabled: true, - }, - }, - }, - } - client := NewSoraDirectClient(cfg, upstream, nil) - req, err := http.NewRequest(http.MethodPost, "https://sora.chatgpt.com/backend/me", strings.NewReader("hello-sidecar")) - require.NoError(t, err) - req.Header.Set("User-Agent", "test-ua") - - resp, err := client.doHTTP(req, "http://127.0.0.1:18080", &Account{ID: 1}) - require.NoError(t, err) - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - require.JSONEq(t, `{"ok":true}`, string(body)) - require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls)) - require.Equal(t, "http://127.0.0.1:18080", captured.ProxyURL) - require.NotEmpty(t, captured.SessionKey) - require.Equal(t, "chrome131", captured.Impersonate) - require.Equal(t, "https://sora.chatgpt.com/backend/me", captured.URL) - decodedReqBody, err := base64.StdEncoding.DecodeString(captured.BodyBase64) - require.NoError(t, err) - require.Equal(t, "hello-sidecar", string(decodedReqBody)) -} - -func TestSoraDirectClient_DoHTTP_CurlCFFISidecarFailureReturnsError(t *testing.T) { - sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) - _, _ = w.Write([]byte(`{"error":"boom"}`)) - })) - defer sidecar.Close() - - upstream := &soraClientFallbackUpstream{respBody: `{"fallback":true}`} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: true, - BaseURL: sidecar.URL, - }, - }, - }, - } - client := NewSoraDirectClient(cfg, upstream, nil) - req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) - require.NoError(t, err) - - _, err = client.doHTTP(req, "", &Account{ID: 2}) - require.Error(t, err) - require.Contains(t, err.Error(), "sora curl_cffi sidecar") - require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls)) -} - -func TestSoraDirectClient_DoHTTP_CurlCFFISidecarDisabledUsesLegacyStack(t *testing.T) { - upstream := &soraClientFallbackUpstream{respBody: `{"legacy":true}`} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: false, - BaseURL: "http://127.0.0.1:18080", - }, - }, - }, - } - client := NewSoraDirectClient(cfg, upstream, nil) - req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) - require.NoError(t, err) - - resp, err := client.doHTTP(req, "", &Account{ID: 3}) - require.NoError(t, err) - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.JSONEq(t, `{"legacy":true}`, string(body)) - require.Equal(t, int32(1), atomic.LoadInt32(&upstream.doWithTLSCalls)) -} - -func TestConvertSidecarHeaderValue_NilAndSlice(t *testing.T) { - require.Nil(t, convertSidecarHeaderValue(nil)) - require.Equal(t, []string{"a", "b"}, convertSidecarHeaderValue([]any{"a", " ", "b"})) -} - -func TestSoraDirectClient_DoHTTP_SidecarSessionKeyStableForSameAccountProxy(t *testing.T) { - var captured []soraCurlCFFISidecarRequest - sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - raw, err := io.ReadAll(r.Body) - require.NoError(t, err) - var reqPayload soraCurlCFFISidecarRequest - require.NoError(t, json.Unmarshal(raw, &reqPayload)) - captured = append(captured, reqPayload) - _ = json.NewEncoder(w).Encode(map[string]any{ - "status_code": http.StatusOK, - "headers": map[string]any{ - "Content-Type": "application/json", - }, - "body": `{"ok":true}`, - }) - })) - defer sidecar.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: true, - BaseURL: sidecar.URL, - SessionReuseEnabled: true, - SessionTTLSeconds: 3600, - }, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - account := &Account{ID: 1001} - - req1, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) - require.NoError(t, err) - _, err = client.doHTTP(req1, "http://127.0.0.1:18080", account) - require.NoError(t, err) - - req2, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) - require.NoError(t, err) - _, err = client.doHTTP(req2, "http://127.0.0.1:18080", account) - require.NoError(t, err) - - require.Len(t, captured, 2) - require.NotEmpty(t, captured[0].SessionKey) - require.Equal(t, captured[0].SessionKey, captured[1].SessionKey) -} - -func TestSoraDirectClient_DoRequestWithProxy_CloudflareChallengeSetsCooldownAfterSingleRetry(t *testing.T) { - var sidecarCalls int32 - sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&sidecarCalls, 1) - _ = json.NewEncoder(w).Encode(map[string]any{ - "status_code": http.StatusForbidden, - "headers": map[string]any{ - "cf-ray": "9d05d73dec4d8c8e-GRU", - "content-type": "text/html", - }, - "body": `Just a moment...`, - }) - })) - defer sidecar.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - MaxRetries: 3, - CloudflareChallengeCooldownSeconds: 60, - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: true, - BaseURL: sidecar.URL, - Impersonate: "chrome131", - }, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - headers := http.Header{} - - _, _, err := client.doRequestWithProxy( - context.Background(), - &Account{ID: 99}, - "http://127.0.0.1:18080", - http.MethodGet, - "https://sora.chatgpt.com/backend/me", - headers, - nil, - true, - ) - require.Error(t, err) - var upstreamErr *SoraUpstreamError - require.ErrorAs(t, err, &upstreamErr) - require.Equal(t, http.StatusForbidden, upstreamErr.StatusCode) - require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "challenge should trigger exactly one same-proxy retry") - - _, _, err = client.doRequestWithProxy( - context.Background(), - &Account{ID: 99}, - "http://127.0.0.1:18080", - http.MethodGet, - "https://sora.chatgpt.com/backend/me", - headers, - nil, - true, - ) - require.Error(t, err) - require.ErrorAs(t, err, &upstreamErr) - require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode) - require.Contains(t, upstreamErr.Message, "cooling down") - require.Contains(t, upstreamErr.Message, "cf-ray") - require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "cooldown should block outbound request") -} - -func TestSoraDirectClient_DoRequestWithProxy_CloudflareRetrySuccessClearsCooldown(t *testing.T) { - var sidecarCalls int32 - sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - call := atomic.AddInt32(&sidecarCalls, 1) - if call == 1 { - _ = json.NewEncoder(w).Encode(map[string]any{ - "status_code": http.StatusForbidden, - "headers": map[string]any{ - "cf-ray": "9d05d73dec4d8c8e-GRU", - "content-type": "text/html", - }, - "body": `Just a moment...`, - }) - return - } - _ = json.NewEncoder(w).Encode(map[string]any{ - "status_code": http.StatusOK, - "headers": map[string]any{ - "content-type": "application/json", - }, - "body": `{"ok":true}`, - }) - })) - defer sidecar.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - BaseURL: "https://sora.chatgpt.com/backend", - MaxRetries: 3, - CloudflareChallengeCooldownSeconds: 60, - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: true, - BaseURL: sidecar.URL, - Impersonate: "chrome131", - }, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - headers := http.Header{} - account := &Account{ID: 109} - proxyURL := "http://127.0.0.1:18080" - - body, _, err := client.doRequestWithProxy( - context.Background(), - account, - proxyURL, - http.MethodGet, - "https://sora.chatgpt.com/backend/me", - headers, - nil, - true, - ) - require.NoError(t, err) - require.Contains(t, string(body), `"ok":true`) - require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls)) - - _, _, err = client.doRequestWithProxy( - context.Background(), - account, - proxyURL, - http.MethodGet, - "https://sora.chatgpt.com/backend/me", - headers, - nil, - true, - ) - require.NoError(t, err) - require.Equal(t, int32(3), atomic.LoadInt32(&sidecarCalls), "cooldown should be cleared after retry succeeds") -} - -func TestSoraComputeChallengeCooldownSeconds(t *testing.T) { - require.Equal(t, 0, soraComputeChallengeCooldownSeconds(0, 3)) - require.Equal(t, 10, soraComputeChallengeCooldownSeconds(10, 1)) - require.Equal(t, 20, soraComputeChallengeCooldownSeconds(10, 2)) - require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 4)) - require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 9), "streak should cap at x4") - require.Equal(t, 3600, soraComputeChallengeCooldownSeconds(1200, 9), "cooldown should cap at 3600s") -} - -func TestSoraDirectClient_RecordCloudflareChallengeCooldown_EscalatesStreak(t *testing.T) { - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - CloudflareChallengeCooldownSeconds: 10, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - account := &Account{ID: 201} - proxyURL := "http://127.0.0.1:18080" - - client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8e-GRU"}}, nil) - client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8f-GRU"}}, nil) - - key := soraAccountProxyKey(account, proxyURL) - entry, ok := client.challengeCooldowns[key] - require.True(t, ok) - require.Equal(t, 2, entry.ConsecutiveChallenges) - require.Equal(t, "9d05d73dec4d8c8f-GRU", entry.CFRay) - remain := int(entry.Until.Sub(entry.LastChallengeAt).Seconds()) - require.GreaterOrEqual(t, remain, 19) -} - -func TestSoraDirectClient_SidecarSessionKey_SkipsWhenAccountMissing(t *testing.T) { - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: true, - SessionReuseEnabled: true, - SessionTTLSeconds: 3600, - }, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - require.Equal(t, "", client.sidecarSessionKey(nil, "http://127.0.0.1:18080")) - require.Empty(t, client.sidecarSessions) -} - -func TestSoraDirectClient_SidecarSessionKey_PrunesExpiredAndRecreates(t *testing.T) { - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: true, - SessionReuseEnabled: true, - SessionTTLSeconds: 3600, - }, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - account := &Account{ID: 123} - key := soraAccountProxyKey(account, "http://127.0.0.1:18080") - client.sidecarSessions[key] = soraSidecarSessionEntry{ - SessionKey: "sora-expired", - ExpiresAt: time.Now().Add(-time.Minute), - LastUsedAt: time.Now().Add(-2 * time.Minute), - } - - sessionKey := client.sidecarSessionKey(account, "http://127.0.0.1:18080") - require.NotEmpty(t, sessionKey) - require.NotEqual(t, "sora-expired", sessionKey) - require.Len(t, client.sidecarSessions, 1) -} - -func TestSoraDirectClient_SidecarSessionKey_TTLZeroKeepsLongLivedSession(t *testing.T) { - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ - Enabled: true, - SessionReuseEnabled: true, - SessionTTLSeconds: 0, - }, - }, - }, - } - client := NewSoraDirectClient(cfg, nil, nil) - account := &Account{ID: 456} - - first := client.sidecarSessionKey(account, "http://127.0.0.1:18080") - second := client.sidecarSessionKey(account, "http://127.0.0.1:18080") - require.NotEmpty(t, first) - require.Equal(t, first, second) - - key := soraAccountProxyKey(account, "http://127.0.0.1:18080") - entry, ok := client.sidecarSessions[key] - require.True(t, ok) - require.True(t, entry.ExpiresAt.After(time.Now().Add(300*24*time.Hour))) -} diff --git a/backend/internal/service/sora_curl_cffi_sidecar.go b/backend/internal/service/sora_curl_cffi_sidecar.go deleted file mode 100644 index 40f5c017..00000000 --- a/backend/internal/service/sora_curl_cffi_sidecar.go +++ /dev/null @@ -1,260 +0,0 @@ -package service - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/util/logredact" -) - -const soraCurlCFFISidecarDefaultTimeoutSeconds = 60 - -type soraCurlCFFISidecarRequest struct { - Method string `json:"method"` - URL string `json:"url"` - Headers map[string][]string `json:"headers,omitempty"` - BodyBase64 string `json:"body_base64,omitempty"` - ProxyURL string `json:"proxy_url,omitempty"` - SessionKey string `json:"session_key,omitempty"` - Impersonate string `json:"impersonate,omitempty"` - TimeoutSeconds int `json:"timeout_seconds,omitempty"` -} - -type soraCurlCFFISidecarResponse struct { - StatusCode int `json:"status_code"` - Status int `json:"status"` - Headers map[string]any `json:"headers"` - BodyBase64 string `json:"body_base64"` - Body string `json:"body"` - Error string `json:"error"` -} - -func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { - if req == nil || req.URL == nil { - return nil, errors.New("request url is nil") - } - if c == nil || c.cfg == nil { - return nil, errors.New("sora curl_cffi sidecar config is nil") - } - if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled { - return nil, errors.New("sora curl_cffi sidecar is disabled") - } - endpoint := c.curlCFFISidecarEndpoint() - if endpoint == "" { - return nil, errors.New("sora curl_cffi sidecar base_url is empty") - } - - bodyBytes, err := readAndRestoreRequestBody(req) - if err != nil { - return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err) - } - - headers := make(map[string][]string, len(req.Header)+1) - for key, vals := range req.Header { - copied := make([]string, len(vals)) - copy(copied, vals) - headers[key] = copied - } - if strings.TrimSpace(req.Host) != "" { - if _, ok := headers["Host"]; !ok { - headers["Host"] = []string{req.Host} - } - } - - payload := soraCurlCFFISidecarRequest{ - Method: req.Method, - URL: req.URL.String(), - Headers: headers, - ProxyURL: strings.TrimSpace(proxyURL), - SessionKey: c.sidecarSessionKey(account, proxyURL), - Impersonate: c.curlCFFIImpersonate(), - TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(), - } - if len(bodyBytes) > 0 { - payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes) - } - - encoded, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err) - } - - sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded)) - if err != nil { - return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err) - } - sidecarReq.Header.Set("Content-Type", "application/json") - sidecarReq.Header.Set("Accept", "application/json") - - httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second} - sidecarResp, err := httpClient.Do(sidecarReq) - if err != nil { - return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err) - } - defer func() { - _ = sidecarResp.Body.Close() - }() - - sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20)) - if err != nil { - return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err) - } - if sidecarResp.StatusCode != http.StatusOK { - redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512) - return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted) - } - - var payloadResp soraCurlCFFISidecarResponse - if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil { - return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err) - } - if msg := strings.TrimSpace(payloadResp.Error); msg != "" { - return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg) - } - statusCode := payloadResp.StatusCode - if statusCode <= 0 { - statusCode = payloadResp.Status - } - if statusCode <= 0 { - return nil, errors.New("sora curl_cffi sidecar response missing status code") - } - - responseBody := []byte(payloadResp.Body) - if strings.TrimSpace(payloadResp.BodyBase64) != "" { - decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64) - if err != nil { - return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err) - } - responseBody = decoded - } - - respHeaders := make(http.Header) - for key, rawVal := range payloadResp.Headers { - for _, v := range convertSidecarHeaderValue(rawVal) { - respHeaders.Add(key, v) - } - } - - return &http.Response{ - StatusCode: statusCode, - Header: respHeaders, - Body: io.NopCloser(bytes.NewReader(responseBody)), - ContentLength: int64(len(responseBody)), - Request: req, - }, nil -} - -func readAndRestoreRequestBody(req *http.Request) ([]byte, error) { - if req == nil || req.Body == nil { - return nil, nil - } - bodyBytes, err := io.ReadAll(req.Body) - if err != nil { - return nil, err - } - _ = req.Body.Close() - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.ContentLength = int64(len(bodyBytes)) - return bodyBytes, nil -} - -func (c *SoraDirectClient) curlCFFISidecarEndpoint() string { - if c == nil || c.cfg == nil { - return "" - } - raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL) - if raw == "" { - return "" - } - parsed, err := url.Parse(raw) - if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" { - return raw - } - if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" { - parsed.Path = "/request" - } - return parsed.String() -} - -func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int { - if c == nil || c.cfg == nil { - return soraCurlCFFISidecarDefaultTimeoutSeconds - } - timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds - if timeoutSeconds <= 0 { - return soraCurlCFFISidecarDefaultTimeoutSeconds - } - return timeoutSeconds -} - -func (c *SoraDirectClient) curlCFFIImpersonate() string { - if c == nil || c.cfg == nil { - return "chrome131" - } - impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate) - if impersonate == "" { - return "chrome131" - } - return impersonate -} - -func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool { - if c == nil || c.cfg == nil { - return true - } - return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled -} - -func (c *SoraDirectClient) sidecarSessionTTLSeconds() int { - if c == nil || c.cfg == nil { - return 3600 - } - ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds - if ttl < 0 { - return 3600 - } - return ttl -} - -func convertSidecarHeaderValue(raw any) []string { - switch val := raw.(type) { - case nil: - return nil - case string: - if strings.TrimSpace(val) == "" { - return nil - } - return []string{val} - case []any: - out := make([]string, 0, len(val)) - for _, item := range val { - s := strings.TrimSpace(fmt.Sprint(item)) - if s != "" { - out = append(out, s) - } - } - return out - case []string: - out := make([]string, 0, len(val)) - for _, item := range val { - if strings.TrimSpace(item) != "" { - out = append(out, item) - } - } - return out - default: - s := strings.TrimSpace(fmt.Sprint(val)) - if s == "" { - return nil - } - return []string{s} - } -} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index ac29ae0d..04c4a9a9 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -9,6 +9,7 @@ import ( "io" "log" "math" + "math/rand" "mime" "net" "net/http" @@ -669,7 +670,7 @@ func processSoraCharacterUsername(usernameHint string) string { if usernameHint == "" { usernameHint = "character" } - return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100) + return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100) } func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go index 8b83cb76..eb363c4f 100644 --- a/backend/internal/service/sora_media_storage.go +++ b/backend/internal/service/sora_media_storage.go @@ -181,7 +181,7 @@ func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawU return relative, nil } if s.debug { - log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeSoraLogURL(rawURL), err) + log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeMediaLogURL(rawURL), err) } if attempt < retries { time.Sleep(time.Duration(attempt*attempt) * time.Second) @@ -252,7 +252,7 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra relative := path.Join("/", mediaType, datePath, filename) if s.debug { - log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeSoraLogURL(rawURL), relative) + log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeMediaLogURL(rawURL), relative) } return relative, nil } @@ -305,3 +305,19 @@ func removePartialDownload(root *os.Root, filePath string) { } _ = root.Remove(filePath) } + +// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息) +func sanitizeMediaLogURL(rawURL string) string { + parsed, err := url.Parse(rawURL) + if err != nil { + if len(rawURL) > 80 { + return rawURL[:80] + "..." + } + return rawURL + } + safe := parsed.Scheme + "://" + parsed.Host + parsed.Path + if len(safe) > 120 { + return safe[:120] + "..." + } + return safe +} diff --git a/backend/internal/service/sora_request_guard.go b/backend/internal/service/sora_request_guard.go deleted file mode 100644 index a118fe82..00000000 --- a/backend/internal/service/sora_request_guard.go +++ /dev/null @@ -1,266 +0,0 @@ -package service - -import ( - "fmt" - "math" - "net/http" - "net/url" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" - "github.com/google/uuid" -) - -type soraChallengeCooldownEntry struct { - Until time.Time - StatusCode int - CFRay string - ConsecutiveChallenges int - LastChallengeAt time.Time -} - -type soraSidecarSessionEntry struct { - SessionKey string - ExpiresAt time.Time - LastUsedAt time.Time -} - -func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int { - if c == nil || c.cfg == nil { - return 900 - } - cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds - if cooldown <= 0 { - return 0 - } - return cooldown -} - -func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error { - if c == nil { - return nil - } - if account == nil || account.ID <= 0 { - return nil - } - cooldownSeconds := c.cloudflareChallengeCooldownSeconds() - if cooldownSeconds <= 0 { - return nil - } - key := soraAccountProxyKey(account, proxyURL) - now := time.Now() - - c.challengeCooldownMu.RLock() - entry, ok := c.challengeCooldowns[key] - c.challengeCooldownMu.RUnlock() - if !ok { - return nil - } - if !entry.Until.After(now) { - c.challengeCooldownMu.Lock() - delete(c.challengeCooldowns, key) - c.challengeCooldownMu.Unlock() - return nil - } - - remaining := int(math.Ceil(entry.Until.Sub(now).Seconds())) - if remaining < 1 { - remaining = 1 - } - message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining) - if entry.ConsecutiveChallenges > 1 { - message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges) - } - if entry.CFRay != "" { - message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay) - } - return &SoraUpstreamError{ - StatusCode: http.StatusTooManyRequests, - Message: message, - Headers: make(http.Header), - } -} - -func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) { - if c == nil { - return - } - if account == nil || account.ID <= 0 { - return - } - cooldownSeconds := c.cloudflareChallengeCooldownSeconds() - if cooldownSeconds <= 0 { - return - } - key := soraAccountProxyKey(account, proxyURL) - now := time.Now() - cfRay := soraerror.ExtractCloudflareRayID(headers, body) - - c.challengeCooldownMu.Lock() - c.cleanupExpiredChallengeCooldownsLocked(now) - - streak := 1 - existing, ok := c.challengeCooldowns[key] - if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute { - streak = existing.ConsecutiveChallenges + 1 - } - effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak) - until := now.Add(time.Duration(effectiveCooldown) * time.Second) - if ok && existing.Until.After(until) { - until = existing.Until - if existing.ConsecutiveChallenges > streak { - streak = existing.ConsecutiveChallenges - } - if cfRay == "" { - cfRay = existing.CFRay - } - } - c.challengeCooldowns[key] = soraChallengeCooldownEntry{ - Until: until, - StatusCode: statusCode, - CFRay: cfRay, - ConsecutiveChallenges: streak, - LastChallengeAt: now, - } - c.challengeCooldownMu.Unlock() - - if c.debugEnabled() { - remain := int(math.Ceil(until.Sub(now).Seconds())) - if remain < 0 { - remain = 0 - } - c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay) - } -} - -func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int { - if baseSeconds <= 0 { - return 0 - } - if streak < 1 { - streak = 1 - } - multiplier := streak - if multiplier > 4 { - multiplier = 4 - } - cooldown := baseSeconds * multiplier - if cooldown > 3600 { - cooldown = 3600 - } - return cooldown -} - -func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) { - if c == nil { - return - } - if account == nil || account.ID <= 0 { - return - } - key := soraAccountProxyKey(account, proxyURL) - c.challengeCooldownMu.Lock() - _, existed := c.challengeCooldowns[key] - if existed { - delete(c.challengeCooldowns, key) - } - c.challengeCooldownMu.Unlock() - if existed && c.debugEnabled() { - c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key) - } -} - -func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string { - if c == nil || !c.sidecarSessionReuseEnabled() { - return "" - } - if account == nil || account.ID <= 0 { - return "" - } - key := soraAccountProxyKey(account, proxyURL) - now := time.Now() - ttlSeconds := c.sidecarSessionTTLSeconds() - - c.sidecarSessionMu.Lock() - defer c.sidecarSessionMu.Unlock() - c.cleanupExpiredSidecarSessionsLocked(now) - if existing, exists := c.sidecarSessions[key]; exists { - existing.LastUsedAt = now - c.sidecarSessions[key] = existing - return existing.SessionKey - } - - expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second) - if ttlSeconds <= 0 { - expiresAt = now.Add(365 * 24 * time.Hour) - } - newEntry := soraSidecarSessionEntry{ - SessionKey: "sora-" + uuid.NewString(), - ExpiresAt: expiresAt, - LastUsedAt: now, - } - c.sidecarSessions[key] = newEntry - - if c.debugEnabled() { - c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds) - } - return newEntry.SessionKey -} - -func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) { - if c == nil || len(c.challengeCooldowns) == 0 { - return - } - for key, entry := range c.challengeCooldowns { - if !entry.Until.After(now) { - delete(c.challengeCooldowns, key) - } - } -} - -func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) { - if c == nil || len(c.sidecarSessions) == 0 { - return - } - for key, entry := range c.sidecarSessions { - if !entry.ExpiresAt.After(now) { - delete(c.sidecarSessions, key) - } - } -} - -func soraAccountProxyKey(account *Account, proxyURL string) string { - accountID := int64(0) - if account != nil { - accountID = account.ID - } - return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL)) -} - -func normalizeSoraProxyKey(proxyURL string) string { - raw := strings.TrimSpace(proxyURL) - if raw == "" { - return "direct" - } - parsed, err := url.Parse(raw) - if err != nil { - return strings.ToLower(raw) - } - scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) - host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) - port := strings.TrimSpace(parsed.Port()) - if host == "" { - return strings.ToLower(raw) - } - if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") { - port = "" - } - if port != "" { - host = host + ":" + port - } - if scheme == "" { - scheme = "proxy" - } - return scheme + "://" + host -} diff --git a/backend/internal/service/sora_sdk_client.go b/backend/internal/service/sora_sdk_client.go new file mode 100644 index 00000000..0f452ed8 --- /dev/null +++ b/backend/internal/service/sora_sdk_client.go @@ -0,0 +1,803 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "strings" + "sync" + "time" + + "github.com/DouDOU-start/go-sora2api/sora" + "github.com/Wei-Shaw/sub2api/internal/config" + openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" + "github.com/tidwall/gjson" +) + +// SoraSDKClient 基于 go-sora2api SDK 的 Sora 客户端实现。 +// 它实现了 SoraClient 接口,用 SDK 替代原有的自建 HTTP/PoW/TLS 指纹逻辑。 +type SoraSDKClient struct { + cfg *config.Config + httpUpstream HTTPUpstream + tokenProvider *OpenAITokenProvider + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository + + // 每个 proxyURL 对应一个 SDK 客户端实例 + sdkClients sync.Map // key: proxyURL (string), value: *sora.Client +} + +// NewSoraSDKClient 创建基于 SDK 的 Sora 客户端 +func NewSoraSDKClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraSDKClient { + return &SoraSDKClient{ + cfg: cfg, + httpUpstream: httpUpstream, + tokenProvider: tokenProvider, + } +} + +// SetAccountRepositories 设置账号和 Sora 扩展仓库(用于 token 持久化) +func (c *SoraSDKClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) { + if c == nil { + return + } + c.accountRepo = accountRepo + c.soraAccountRepo = soraAccountRepo +} + +// Enabled 判断是否启用 Sora +func (c *SoraSDKClient) Enabled() bool { + if c == nil || c.cfg == nil { + return false + } + return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != "" +} + +// PreflightCheck 在创建任务前执行账号能力预检。 +// 当前仅对视频模型执行预检,用于提前识别额度耗尽或能力缺失。 +func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error { + if modelCfg.Type != "video" { + return nil + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return err + } + balance, err := sdkClient.GetCreditBalance(ctx, token) + if err != nil { + return &SoraUpstreamError{ + StatusCode: http.StatusForbidden, + Message: "当前账号未开通 Sora2 能力或无可用配额", + } + } + if balance.RateLimitReached || balance.RemainingCount <= 0 { + msg := "当前账号 Sora2 可用配额不足" + if requestedModel != "" { + msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel) + } + return &SoraUpstreamError{ + StatusCode: http.StatusTooManyRequests, + Message: msg, + } + } + return nil +} + +func (c *SoraSDKClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { + if len(data) == 0 { + return "", errors.New("empty image data") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + if filename == "" { + filename = "image.png" + } + mediaID, err := sdkClient.UploadImage(ctx, token, data, filename) + if err != nil { + return "", c.wrapSDKError(err, account) + } + return mediaID, nil +} + +func (c *SoraSDKClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + sentinel, err := sdkClient.GenerateSentinelToken(ctx, token) + if err != nil { + return "", c.wrapSDKError(err, account) + } + var taskID string + if strings.TrimSpace(req.MediaID) != "" { + taskID, err = sdkClient.CreateImageTaskWithImage(ctx, token, sentinel, req.Prompt, req.Width, req.Height, req.MediaID) + } else { + taskID, err = sdkClient.CreateImageTask(ctx, token, sentinel, req.Prompt, req.Width, req.Height) + } + if err != nil { + return "", c.wrapSDKError(err, account) + } + return taskID, nil +} + +func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + sentinel, err := sdkClient.GenerateSentinelToken(ctx, token) + if err != nil { + return "", c.wrapSDKError(err, account) + } + + orientation := req.Orientation + if orientation == "" { + orientation = "landscape" + } + nFrames := req.Frames + if nFrames <= 0 { + nFrames = 450 + } + model := req.Model + if model == "" { + model = "sy_8" + } + size := req.Size + if size == "" { + size = "small" + } + + // Remix 模式 + if strings.TrimSpace(req.RemixTargetID) != "" { + styleID := "" // SDK ExtractStyle 可从 prompt 中提取 + taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID) + if err != nil { + return "", c.wrapSDKError(err, account) + } + return taskID, nil + } + + // 普通视频(文生视频或图生视频) + taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "") + if err != nil { + return "", c.wrapSDKError(err, account) + } + return taskID, nil +} + +func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + sentinel, err := sdkClient.GenerateSentinelToken(ctx, token) + if err != nil { + return "", c.wrapSDKError(err, account) + } + + orientation := req.Orientation + if orientation == "" { + orientation = "landscape" + } + nFrames := req.Frames + if nFrames <= 0 { + nFrames = 450 + } + + taskID, err := sdkClient.CreateStoryboardTask(ctx, token, sentinel, req.Prompt, orientation, nFrames, req.MediaID, "") + if err != nil { + return "", c.wrapSDKError(err, account) + } + return taskID, nil +} + +func (c *SoraSDKClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("empty video data") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + cameoID, err := sdkClient.UploadCharacterVideo(ctx, token, data) + if err != nil { + return "", c.wrapSDKError(err, account) + } + return cameoID, nil +} + +func (c *SoraSDKClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return nil, err + } + status, err := sdkClient.GetCameoStatus(ctx, token, cameoID) + if err != nil { + return nil, c.wrapSDKError(err, account) + } + return &SoraCameoStatus{ + Status: status.Status, + DisplayNameHint: status.DisplayNameHint, + UsernameHint: status.UsernameHint, + ProfileAssetURL: status.ProfileAssetURL, + }, nil +} + +func (c *SoraSDKClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + sdkClient, err := c.getSDKClient(account) + if err != nil { + return nil, err + } + data, err := sdkClient.DownloadCharacterImage(ctx, imageURL) + if err != nil { + return nil, c.wrapSDKError(err, account) + } + return data, nil +} + +func (c *SoraSDKClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("empty character image") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + assetPointer, err := sdkClient.UploadCharacterImage(ctx, token, data) + if err != nil { + return "", c.wrapSDKError(err, account) + } + return assetPointer, nil +} + +func (c *SoraSDKClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + characterID, err := sdkClient.FinalizeCharacter(ctx, token, req.CameoID, req.Username, req.DisplayName, req.ProfileAssetPointer) + if err != nil { + return "", c.wrapSDKError(err, account) + } + return characterID, nil +} + +func (c *SoraSDKClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return err + } + if err := sdkClient.SetCharacterPublic(ctx, token, cameoID); err != nil { + return c.wrapSDKError(err, account) + } + return nil +} + +func (c *SoraSDKClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return err + } + if err := sdkClient.DeleteCharacter(ctx, token, characterID); err != nil { + return c.wrapSDKError(err, account) + } + return nil +} + +func (c *SoraSDKClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + sentinel, err := sdkClient.GenerateSentinelToken(ctx, token) + if err != nil { + return "", c.wrapSDKError(err, account) + } + postID, err := sdkClient.PublishVideo(ctx, token, sentinel, generationID) + if err != nil { + return "", c.wrapSDKError(err, account) + } + return postID, nil +} + +func (c *SoraSDKClient) DeletePost(ctx context.Context, account *Account, postID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return err + } + if err := sdkClient.DeletePost(ctx, token, postID); err != nil { + return c.wrapSDKError(err, account) + } + return nil +} + +// GetWatermarkFreeURLCustom 使用自定义第三方解析服务获取去水印链接。 +// SDK 不涉及此功能,保留自建实现。 +func (c *SoraSDKClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/") + if parseURL == "" { + return "", errors.New("custom parse url is required") + } + if strings.TrimSpace(parseToken) == "" { + return "", errors.New("custom parse token is required") + } + shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID) + payload := map[string]any{ + "url": shareURL, + "token": strings.TrimSpace(parseToken), + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + proxyURL := c.resolveProxyURL(account) + accountID := int64(0) + accountConcurrency := 0 + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + var resp *http.Response + if c.httpUpstream != nil { + resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency) + } else { + resp, err = http.DefaultClient.Do(req) + } + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return "", err + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256)) + } + downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String()) + if downloadLink == "" { + return "", errors.New("custom parse response missing download_link") + } + return downloadLink, nil +} + +func (c *SoraSDKClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + if strings.TrimSpace(expansionLevel) == "" { + expansionLevel = "medium" + } + if durationS <= 0 { + durationS = 10 + } + enhanced, err := sdkClient.EnhancePrompt(ctx, token, prompt, expansionLevel, durationS) + if err != nil { + return "", c.wrapSDKError(err, account) + } + return enhanced, nil +} + +func (c *SoraSDKClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return nil, err + } + result := sdkClient.QueryImageTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second)) + if result.Err != nil { + return &SoraImageTaskStatus{ + ID: taskID, + Status: "failed", + ErrorMsg: result.Err.Error(), + }, nil + } + if result.Done && result.ImageURL != "" { + return &SoraImageTaskStatus{ + ID: taskID, + Status: "succeeded", + URLs: []string{result.ImageURL}, + }, nil + } + status := result.Progress.Status + if status == "" { + status = "processing" + } + return &SoraImageTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: float64(result.Progress.Percent) / 100.0, + }, nil +} + +func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + sdkClient, err := c.getSDKClient(account) + if err != nil { + return nil, err + } + + // 先查询 pending 列表 + result := sdkClient.QueryVideoTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second), 0) + if result.Err != nil { + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "failed", + ErrorMsg: result.Err.Error(), + }, nil + } + if !result.Done { + return &SoraVideoTaskStatus{ + ID: taskID, + Status: result.Progress.Status, + ProgressPct: result.Progress.Percent, + }, nil + } + + // 任务不在 pending 中,查询 drafts 获取下载链接 + downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID) + if err != nil { + errMsg := err.Error() + if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") { + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "failed", + ErrorMsg: errMsg, + }, nil + } + // 可能还在处理中 + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "processing", + }, nil + } + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "completed", + URLs: []string{downloadURL}, + }, nil +} + +// --- 内部方法 --- + +// getSDKClient 获取或创建指定代理的 SDK 客户端实例 +func (c *SoraSDKClient) getSDKClient(account *Account) (*sora.Client, error) { + proxyURL := c.resolveProxyURL(account) + if v, ok := c.sdkClients.Load(proxyURL); ok { + return v.(*sora.Client), nil + } + client, err := sora.New(proxyURL) + if err != nil { + return nil, fmt.Errorf("创建 Sora SDK 客户端失败: %w", err) + } + actual, _ := c.sdkClients.LoadOrStore(proxyURL, client) + return actual.(*sora.Client), nil +} + +func (c *SoraSDKClient) resolveProxyURL(account *Account) string { + if account == nil || account.ProxyID == nil || account.Proxy == nil { + return "" + } + return strings.TrimSpace(account.Proxy.URL()) +} + +// getAccessToken 获取账号的 access_token,支持多种 token 来源和自动刷新。 +// 此方法保留了原 SoraDirectClient 的 token 管理逻辑。 +func (c *SoraSDKClient) getAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + + // 优先尝试 OpenAI Token Provider + allowProvider := c.allowOpenAITokenProvider(account) + var providerErr error + if allowProvider && c.tokenProvider != nil { + token, err := c.tokenProvider.GetAccessToken(ctx, account) + if err == nil && strings.TrimSpace(token) != "" { + c.debugLogf("token_selected account_id=%d source=openai_token_provider", account.ID) + return token, nil + } + providerErr = err + if err != nil && c.debugEnabled() { + c.debugLogf("token_provider_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } + } + + // 尝试直接使用 credentials 中的 access_token + token := strings.TrimSpace(account.GetCredential("access_token")) + if token != "" { + expiresAt := account.GetCredentialAsTime("expires_at") + if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute { + refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring") + if refreshErr == nil && strings.TrimSpace(refreshed) != "" { + return refreshed, nil + } + } + return token, nil + } + + // 尝试通过 session_token 或 refresh_token 恢复 + recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing") + if recoverErr == nil && strings.TrimSpace(recovered) != "" { + return recovered, nil + } + if providerErr != nil { + return "", providerErr + } + return "", errors.New("access_token not found") +} + +// recoverAccessToken 通过 session_token 或 refresh_token 恢复 access_token +func (c *SoraSDKClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + + // 先尝试 session_token + if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" { + accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken) + if err == nil && strings.TrimSpace(accessToken) != "" { + c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken) + return accessToken, nil + } + } + + // 再尝试 refresh_token + refreshToken := strings.TrimSpace(account.GetCredential("refresh_token")) + if refreshToken == "" { + return "", errors.New("session_token/refresh_token not found") + } + + sdkClient, err := c.getSDKClient(account) + if err != nil { + return "", err + } + + // 尝试多个 client_id + clientIDs := []string{ + strings.TrimSpace(account.GetCredential("client_id")), + openaioauth.SoraClientID, + openaioauth.ClientID, + } + tried := make(map[string]struct{}, len(clientIDs)) + var lastErr error + + for _, clientID := range clientIDs { + if clientID == "" { + continue + } + if _, ok := tried[clientID]; ok { + continue + } + tried[clientID] = struct{}{} + + newAccess, newRefresh, refreshErr := sdkClient.RefreshAccessToken(ctx, refreshToken, clientID) + if refreshErr != nil { + lastErr = refreshErr + continue + } + if strings.TrimSpace(newAccess) == "" { + lastErr = errors.New("refreshed access_token is empty") + continue + } + c.applyRecoveredToken(ctx, account, newAccess, newRefresh, "", "") + return newAccess, nil + } + + if lastErr != nil { + return "", lastErr + } + return "", errors.New("no available client_id for refresh_token exchange") +} + +// exchangeSessionToken 通过 session_token 换取 access_token +func (c *SoraSDKClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://sora.chatgpt.com/api/auth/session", nil) + if err != nil { + return "", "", err + } + req.Header.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + + proxyURL := c.resolveProxyURL(account) + accountID := int64(0) + accountConcurrency := 0 + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + + var resp *http.Response + if c.httpUpstream != nil { + resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency) + } else { + resp, err = http.DefaultClient.Do(req) + } + if err != nil { + return "", "", err + } + defer func() { _ = resp.Body.Close() }() + body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if err != nil { + return "", "", err + } + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("session exchange failed: %d", resp.StatusCode) + } + + accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String()) + if accessToken == "" { + return "", "", errors.New("session exchange missing accessToken") + } + expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String()) + return accessToken, expiresAt, nil +} + +// applyRecoveredToken 将恢复的 token 写入账号内存和数据库 +func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) { + if account == nil { + return + } + if account.Credentials == nil { + account.Credentials = make(map[string]any) + } + if strings.TrimSpace(accessToken) != "" { + account.Credentials["access_token"] = accessToken + } + if strings.TrimSpace(refreshToken) != "" { + account.Credentials["refresh_token"] = refreshToken + } + if strings.TrimSpace(expiresAt) != "" { + account.Credentials["expires_at"] = expiresAt + } + if strings.TrimSpace(sessionToken) != "" { + account.Credentials["session_token"] = sessionToken + } + + if c.accountRepo != nil { + if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() { + c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } + } + c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken) +} + +func (c *SoraSDKClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) { + if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 { + return + } + updates := make(map[string]any) + if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" { + updates["access_token"] = accessToken + updates["refresh_token"] = refreshToken + } + if strings.TrimSpace(sessionToken) != "" { + updates["session_token"] = sessionToken + } + if len(updates) == 0 { + return + } + if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() { + c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } +} + +func (c *SoraSDKClient) allowOpenAITokenProvider(account *Account) bool { + if c == nil || c.tokenProvider == nil { + return false + } + if account != nil && account.Platform == PlatformSora { + return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider + } + return true +} + +// wrapSDKError 将 SDK 错误包装为 SoraUpstreamError +func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error { + if err == nil { + return nil + } + msg := err.Error() + statusCode := http.StatusBadGateway + if strings.Contains(msg, "HTTP 401") || strings.Contains(msg, "HTTP 403") { + statusCode = http.StatusUnauthorized + } else if strings.Contains(msg, "HTTP 429") { + statusCode = http.StatusTooManyRequests + } else if strings.Contains(msg, "HTTP 404") { + statusCode = http.StatusNotFound + } + return &SoraUpstreamError{ + StatusCode: statusCode, + Message: msg, + } +} + +func (c *SoraSDKClient) debugEnabled() bool { + return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug +} + +func (c *SoraSDKClient) debugLogf(format string, args ...any) { + if c.debugEnabled() { + log.Printf("[SoraSDK] "+format, args...) + } +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index bd241566..f04acc00 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -206,14 +206,14 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage { return NewSoraMediaStorage(cfg) } -func ProvideSoraDirectClient( +func ProvideSoraSDKClient( cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider, accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, -) *SoraDirectClient { - client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider) +) *SoraSDKClient { + client := NewSoraSDKClient(cfg, httpUpstream, tokenProvider) client.SetAccountRepositories(accountRepo, soraAccountRepo) return client } @@ -306,8 +306,8 @@ var ProviderSet = wire.NewSet( NewGatewayService, ProvideSoraMediaStorage, ProvideSoraMediaCleanupService, - ProvideSoraDirectClient, - wire.Bind(new(SoraClient), new(*SoraDirectClient)), + ProvideSoraSDKClient, + wire.Bind(new(SoraClient), new(*SoraSDKClient)), NewSoraGatewayService, NewOpenAIGatewayService, NewOAuthService, From 65d4ca25634f9a5e374afcf4431483cad490d7cd Mon Sep 17 00:00:00 2001 From: huangenjun <1021217094@qq.com> Date: Wed, 25 Feb 2026 11:32:56 +0800 Subject: [PATCH 2/4] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E5=93=8D=E5=BA=94=E4=B8=AD=20URL=20=E7=9A=84=20&=20=E8=A2=AB?= =?UTF-8?q?=E8=BD=AC=E4=B9=89=E4=B8=BA=20\u0026=20=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 jsonMarshalRaw 使用 SetEscapeHTML(false) 替代 json.Marshal, 避免 HTML 字符转义导致客户端无法直接使用返回的 URL。 Co-Authored-By: Claude Opus 4.6 --- .../internal/service/sora_gateway_service.go | 22 +++++++++++++++++-- .../service/sora_gateway_streaming_legacy.go | 4 ++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 04c4a9a9..b8241eef 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "context" "encoding/base64" "encoding/json" @@ -830,7 +831,7 @@ func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content stri }, }, } - encoded, _ := json.Marshal(chunk) + encoded, _ := jsonMarshalRaw(chunk) if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { return nil, err } @@ -851,7 +852,7 @@ func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content stri }, }, } - finalEncoded, _ := json.Marshal(finalChunk) + finalEncoded, _ := jsonMarshalRaw(finalChunk) if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { return &ms, err } @@ -1052,6 +1053,23 @@ func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string { return output } +// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符, +// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。 +func jsonMarshalRaw(v any) ([]byte, error) { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + if err := enc.Encode(v); err != nil { + return nil, err + } + // Encode 会追加换行符,去掉它 + b := buf.Bytes() + if len(b) > 0 && b[len(b)-1] == '\n' { + b = b[:len(b)-1] + } + return b, nil +} + func buildSoraContent(mediaType string, urls []string) string { switch mediaType { case "image": diff --git a/backend/internal/service/sora_gateway_streaming_legacy.go b/backend/internal/service/sora_gateway_streaming_legacy.go index 8a38f181..d399ba1c 100644 --- a/backend/internal/service/sora_gateway_streaming_legacy.go +++ b/backend/internal/service/sora_gateway_streaming_legacy.go @@ -316,7 +316,7 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin } } - updatedData, err := json.Marshal(payload) + updatedData, err := jsonMarshalRaw(payload) if err != nil { return "data: " + data, contentDelta, nil } @@ -484,7 +484,7 @@ func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel if originalModel != "" { payload["model"] = originalModel } - updatedData, err := json.Marshal(payload) + updatedData, err := jsonMarshalRaw(payload) if err != nil { return "", "", err } From 26060e702f210d4c19115ea4559229927c67923f Mon Sep 17 00:00:00 2001 From: huangenjun <1021217094@qq.com> Date: Wed, 25 Feb 2026 11:33:07 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20Sora=20=E5=B9=B3=E5=8F=B0=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=89=8B=E5=8A=A8=E5=AF=BC=E5=85=A5=20Access=20Token?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 Access Token 输入方式,支持批量粘贴(每行一个)直接创建账号, 无需经过 OAuth 授权流程。 Co-Authored-By: Claude Opus 4.6 --- .../components/account/CreateAccountModal.vue | 79 +++++++++++++++++ .../account/OAuthAuthorizationFlow.vue | 87 ++++++++++++++++++- frontend/src/composables/useAccountOAuth.ts | 2 +- 3 files changed, 166 insertions(+), 2 deletions(-) diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 25100c82..72d74318 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1816,12 +1816,14 @@ :show-cookie-option="form.platform === 'anthropic'" :show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'" :show-session-token-option="form.platform === 'sora'" + :show-access-token-option="form.platform === 'sora'" :platform="form.platform" :show-project-id="geminiOAuthType === 'code_assist'" @generate-url="handleGenerateUrl" @cookie-auth="handleCookieAuth" @validate-refresh-token="handleValidateRefreshToken" @validate-session-token="handleValidateSessionToken" + @import-access-token="handleImportAccessToken" /> @@ -3188,6 +3190,83 @@ const handleValidateSessionToken = (sessionToken: string) => { } } +// Sora 手动 AT 批量导入 +const handleImportAccessToken = async (accessTokenInput: string) => { + const oauthClient = activeOpenAIOAuth.value + if (!accessTokenInput.trim()) return + + const accessTokens = accessTokenInput + .split('\n') + .map((at) => at.trim()) + .filter((at) => at) + + if (accessTokens.length === 0) { + oauthClient.error.value = 'Please enter at least one Access Token' + return + } + + oauthClient.loading.value = true + oauthClient.error.value = '' + + let successCount = 0 + let failedCount = 0 + const errors: string[] = [] + + try { + for (let i = 0; i < accessTokens.length; i++) { + try { + const credentials: Record = { + access_token: accessTokens[i], + } + const soraExtra = buildSoraExtra() + + const accountName = accessTokens.length > 1 ? `${form.name} #${i + 1}` : form.name + await adminAPI.accounts.create({ + name: accountName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + successCount++ + } catch (error: any) { + failedCount++ + const errMsg = error.response?.data?.detail || error.message || 'Unknown error' + errors.push(`#${i + 1}: ${errMsg}`) + } + } + + if (successCount > 0 && failedCount === 0) { + appStore.showSuccess( + accessTokens.length > 1 + ? t('admin.accounts.oauth.batchSuccess', { count: successCount }) + : t('admin.accounts.accountCreated') + ) + emit('created') + handleClose() + } else if (successCount > 0 && failedCount > 0) { + appStore.showWarning( + t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) + ) + oauthClient.error.value = errors.join('\n') + emit('created') + } else { + oauthClient.error.value = errors.join('\n') + appStore.showError(t('admin.accounts.oauth.batchFailed')) + } + } finally { + oauthClient.loading.value = false + } +} + const formatDateTimeLocal = formatDateTimeLocalInput const parseDateTimeLocal = parseDateTimeLocalInput diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 8e00d25b..94d417dc 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -59,6 +59,17 @@ t(getOAuthKey('sessionTokenAuth')) }} + @@ -227,6 +238,63 @@ + +
+
+

+ {{ t('admin.accounts.oauth.openai.accessTokenDesc', '直接粘贴 Access Token 创建账号,无需 OAuth 授权流程。支持批量导入(每行一个)。') }} +

+ +
+ + +

+ {{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedAccessTokenCount }) }} +

+
+ +
+

+ {{ error }} +

+
+ + +
+
+
(), { showCookieOption: true, showRefreshTokenOption: false, showSessionTokenOption: false, + showAccessTokenOption: false, platform: 'anthropic', showProjectId: true }) @@ -644,6 +714,7 @@ const emit = defineEmits<{ 'cookie-auth': [sessionKey: string] 'validate-refresh-token': [refreshToken: string] 'validate-session-token': [sessionToken: string] + 'import-access-token': [accessToken: string] 'update:inputMethod': [method: AuthInputMethod] }>() @@ -683,12 +754,13 @@ const authCodeInput = ref('') const sessionKeyInput = ref('') const refreshTokenInput = ref('') const sessionTokenInput = ref('') +const accessTokenInput = ref('') const showHelpDialog = ref(false) const oauthState = ref('') const projectId = ref('') // Computed: show method selection when either cookie or refresh token option is enabled -const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption) +const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption) // Clipboard const { copied, copyToClipboard } = useClipboard() @@ -716,6 +788,13 @@ const parsedSessionTokenCount = computed(() => { .filter((st) => st).length }) +const parsedAccessTokenCount = computed(() => { + return accessTokenInput.value + .split('\n') + .map((at) => at.trim()) + .filter((at) => at).length +}) + // Watchers watch(inputMethod, (newVal) => { emit('update:inputMethod', newVal) @@ -789,6 +868,12 @@ const handleValidateSessionToken = () => { } } +const handleImportAccessToken = () => { + if (accessTokenInput.value.trim()) { + emit('import-access-token', accessTokenInput.value.trim()) + } +} + // Expose methods and state defineExpose({ authCode: authCodeInput, diff --git a/frontend/src/composables/useAccountOAuth.ts b/frontend/src/composables/useAccountOAuth.ts index 6f53404c..b6f33186 100644 --- a/frontend/src/composables/useAccountOAuth.ts +++ b/frontend/src/composables/useAccountOAuth.ts @@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' export type AddMethod = 'oauth' | 'setup-token' -export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' +export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' | 'access_token' export interface OAuthState { authUrl: string From 935ea66681dd0690d6847089df6e878778114373 Mon Sep 17 00:00:00 2001 From: huangenjun <1021217094@qq.com> Date: Wed, 25 Feb 2026 11:43:08 +0800 Subject: [PATCH 4/4] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20sora=5Fsdk=5Fcl?= =?UTF-8?q?ient=20=E7=B1=BB=E5=9E=8B=E6=96=AD=E8=A8=80=E6=9C=AA=E6=A3=80?= =?UTF-8?q?=E6=9F=A5=E7=9A=84=20errcheck=20lint=20=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 使用安全的 comma-ok 模式替代裸类型断言,避免 golangci-lint errcheck 报错。 Co-Authored-By: Claude Opus 4.6 --- backend/internal/service/sora_sdk_client.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/backend/internal/service/sora_sdk_client.go b/backend/internal/service/sora_sdk_client.go index 0f452ed8..604c2749 100644 --- a/backend/internal/service/sora_sdk_client.go +++ b/backend/internal/service/sora_sdk_client.go @@ -541,14 +541,19 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task func (c *SoraSDKClient) getSDKClient(account *Account) (*sora.Client, error) { proxyURL := c.resolveProxyURL(account) if v, ok := c.sdkClients.Load(proxyURL); ok { - return v.(*sora.Client), nil + if cli, ok2 := v.(*sora.Client); ok2 { + return cli, nil + } } client, err := sora.New(proxyURL) if err != nil { return nil, fmt.Errorf("创建 Sora SDK 客户端失败: %w", err) } actual, _ := c.sdkClients.LoadOrStore(proxyURL, client) - return actual.(*sora.Client), nil + if cli, ok := actual.(*sora.Client); ok { + return cli, nil + } + return client, nil } func (c *SoraSDKClient) resolveProxyURL(account *Account) string {