refactor: 使用 go-sora2api SDK 替代自建 Sora 客户端
使用 go-sora2api v1.1.0 SDK 替代原有 ~2000 行自建 HTTP/PoW/TLS 指纹代码, SDK 提供高并发性能优化(实例级 rand、PoW 缓冲区复用、context.Context 支持)。 - 新增 SoraSDKClient 适配器实现 SoraClient 接口 - 精简 sora_client.go 为仅保留接口和类型定义 - 更新 Wire 绑定使用 SoraSDKClient - 删除 SoraDirectClient、sora_curl_cffi_sidecar、sora_request_guard 等旧代码 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1,515 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ---------- 辅助解析函数(复制生产代码中的 gjson 解析逻辑,用于单元测试) ----------
|
||||
|
||||
// testParseUploadOrCreateTaskID 模拟 UploadImage / CreateImageTask / CreateVideoTask 中
|
||||
// 用 gjson.GetBytes(respBody, "id") 提取 id 的逻辑。
|
||||
func testParseUploadOrCreateTaskID(respBody []byte) (string, error) {
|
||||
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||||
if id == "" {
|
||||
return "", assert.AnError // 占位错误,表示 "missing id"
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// testParseFetchRecentImageTask 模拟 fetchRecentImageTask 中的 gjson.ForEach 解析逻辑。
|
||||
func testParseFetchRecentImageTask(respBody []byte, taskID string) (*SoraImageTaskStatus, bool) {
|
||||
var found *SoraImageTaskStatus
|
||||
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
|
||||
if item.Get("id").String() != taskID {
|
||||
return true // continue
|
||||
}
|
||||
status := strings.TrimSpace(item.Get("status").String())
|
||||
progress := item.Get("progress_pct").Float()
|
||||
var urls []string
|
||||
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
|
||||
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
|
||||
urls = append(urls, u)
|
||||
}
|
||||
return true
|
||||
})
|
||||
found = &SoraImageTaskStatus{
|
||||
ID: taskID,
|
||||
Status: status,
|
||||
ProgressPct: progress,
|
||||
URLs: urls,
|
||||
}
|
||||
return false // break
|
||||
})
|
||||
if found != nil {
|
||||
return found, true
|
||||
}
|
||||
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false
|
||||
}
|
||||
|
||||
// testParseGetVideoTaskPending 模拟 GetVideoTask 中解析 pending 列表的逻辑。
|
||||
func testParseGetVideoTaskPending(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
|
||||
pendingResult := gjson.ParseBytes(respBody)
|
||||
if !pendingResult.IsArray() {
|
||||
return nil, false
|
||||
}
|
||||
var pendingFound *SoraVideoTaskStatus
|
||||
pendingResult.ForEach(func(_, task gjson.Result) bool {
|
||||
if task.Get("id").String() != taskID {
|
||||
return true
|
||||
}
|
||||
progress := 0
|
||||
if v := task.Get("progress_pct"); v.Exists() {
|
||||
progress = int(v.Float() * 100)
|
||||
}
|
||||
status := strings.TrimSpace(task.Get("status").String())
|
||||
pendingFound = &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: status,
|
||||
ProgressPct: progress,
|
||||
}
|
||||
return false
|
||||
})
|
||||
if pendingFound != nil {
|
||||
return pendingFound, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// testParseGetVideoTaskDrafts 模拟 GetVideoTask 中解析 drafts 列表的逻辑。
|
||||
func testParseGetVideoTaskDrafts(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
|
||||
var draftFound *SoraVideoTaskStatus
|
||||
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
|
||||
if draft.Get("task_id").String() != taskID {
|
||||
return true
|
||||
}
|
||||
kind := strings.TrimSpace(draft.Get("kind").String())
|
||||
reason := strings.TrimSpace(draft.Get("reason_str").String())
|
||||
if reason == "" {
|
||||
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
|
||||
}
|
||||
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
|
||||
if urlStr == "" {
|
||||
urlStr = strings.TrimSpace(draft.Get("url").String())
|
||||
}
|
||||
|
||||
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
|
||||
msg := reason
|
||||
if msg == "" {
|
||||
msg = "Content violates guardrails"
|
||||
}
|
||||
draftFound = &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: msg,
|
||||
}
|
||||
} else {
|
||||
draftFound = &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "completed",
|
||||
URLs: []string{urlStr},
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
if draftFound != nil {
|
||||
return draftFound, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ===================== Test 1: TestSoraParseUploadResponse =====================
|
||||
|
||||
func TestSoraParseUploadResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常 id",
|
||||
body: `{"id":"file-abc123","status":"uploaded"}`,
|
||||
wantID: "file-abc123",
|
||||
},
|
||||
{
|
||||
name: "空 id",
|
||||
body: `{"id":"","status":"uploaded"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无 id 字段",
|
||||
body: `{"status":"uploaded"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "id 全为空白",
|
||||
body: `{"id":" ","status":"uploaded"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "id 前后有空白",
|
||||
body: `{"id":" file-trimmed ","status":"uploaded"}`,
|
||||
wantID: "file-trimmed",
|
||||
},
|
||||
{
|
||||
name: "空 JSON 对象",
|
||||
body: `{}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
|
||||
if tt.wantErr {
|
||||
require.Error(t, err, "应返回错误")
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantID, id)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== Test 2: TestSoraParseCreateTaskResponse =====================
|
||||
|
||||
func TestSoraParseCreateTaskResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常任务 id",
|
||||
body: `{"id":"task-123"}`,
|
||||
wantID: "task-123",
|
||||
},
|
||||
{
|
||||
name: "缺失 id",
|
||||
body: `{"status":"created"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "空 id",
|
||||
body: `{"id":" "}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "id 为数字(gjson 转字符串)",
|
||||
body: `{"id":123}`,
|
||||
wantID: "123",
|
||||
},
|
||||
{
|
||||
name: "id 含特殊字符",
|
||||
body: `{"id":"task-abc-def-456-ghi"}`,
|
||||
wantID: "task-abc-def-456-ghi",
|
||||
},
|
||||
{
|
||||
name: "额外字段不影响解析",
|
||||
body: `{"id":"task-999","type":"image_gen","extra":"data"}`,
|
||||
wantID: "task-999",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
|
||||
if tt.wantErr {
|
||||
require.Error(t, err, "应返回错误")
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantID, id)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== Test 3: TestSoraParseFetchRecentImageTask =====================
|
||||
|
||||
func TestSoraParseFetchRecentImageTask(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
taskID string
|
||||
wantFound bool
|
||||
wantStatus string
|
||||
wantProgress float64
|
||||
wantURLs []string
|
||||
}{
|
||||
{
|
||||
name: "匹配已完成任务",
|
||||
body: `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1.0,"generations":[{"url":"https://example.com/img.png"}]}]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: true,
|
||||
wantStatus: "completed",
|
||||
wantProgress: 1.0,
|
||||
wantURLs: []string{"https://example.com/img.png"},
|
||||
},
|
||||
{
|
||||
name: "匹配处理中任务",
|
||||
body: `{"task_responses":[{"id":"task-2","status":"processing","progress_pct":0.5,"generations":[]}]}`,
|
||||
taskID: "task-2",
|
||||
wantFound: true,
|
||||
wantStatus: "processing",
|
||||
wantProgress: 0.5,
|
||||
wantURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "无匹配任务",
|
||||
body: `{"task_responses":[{"id":"other","status":"completed"}]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
wantStatus: "processing",
|
||||
},
|
||||
{
|
||||
name: "空 task_responses",
|
||||
body: `{"task_responses":[]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
wantStatus: "processing",
|
||||
},
|
||||
{
|
||||
name: "缺少 task_responses 字段",
|
||||
body: `{"other":"data"}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
wantStatus: "processing",
|
||||
},
|
||||
{
|
||||
name: "多个任务中精准匹配",
|
||||
body: `{"task_responses":[{"id":"task-a","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"}]},{"id":"task-b","status":"processing","progress_pct":0.3,"generations":[]},{"id":"task-c","status":"failed","progress_pct":0}]}`,
|
||||
taskID: "task-b",
|
||||
wantFound: true,
|
||||
wantStatus: "processing",
|
||||
wantProgress: 0.3,
|
||||
wantURLs: nil,
|
||||
},
|
||||
{
|
||||
name: "多个 generations",
|
||||
body: `{"task_responses":[{"id":"task-m","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"},{"url":"https://a.com/2.png"},{"url":""}]}]}`,
|
||||
taskID: "task-m",
|
||||
wantFound: true,
|
||||
wantStatus: "completed",
|
||||
wantProgress: 1.0,
|
||||
wantURLs: []string{"https://a.com/1.png", "https://a.com/2.png"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
status, found := testParseFetchRecentImageTask([]byte(tt.body), tt.taskID)
|
||||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||
require.NotNil(t, status)
|
||||
require.Equal(t, tt.taskID, status.ID)
|
||||
require.Equal(t, tt.wantStatus, status.Status)
|
||||
if tt.wantFound {
|
||||
require.InDelta(t, tt.wantProgress, status.ProgressPct, 0.001, "进度不匹配")
|
||||
require.Equal(t, tt.wantURLs, status.URLs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== Test 4: TestSoraParseGetVideoTaskPending =====================
|
||||
|
||||
func TestSoraParseGetVideoTaskPending(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
taskID string
|
||||
wantFound bool
|
||||
wantStatus string
|
||||
wantProgress int
|
||||
}{
|
||||
{
|
||||
name: "匹配 pending 任务",
|
||||
body: `[{"id":"task-1","status":"processing","progress_pct":0.5}]`,
|
||||
taskID: "task-1",
|
||||
wantFound: true,
|
||||
wantStatus: "processing",
|
||||
wantProgress: 50,
|
||||
},
|
||||
{
|
||||
name: "进度为 0",
|
||||
body: `[{"id":"task-2","status":"queued","progress_pct":0}]`,
|
||||
taskID: "task-2",
|
||||
wantFound: true,
|
||||
wantStatus: "queued",
|
||||
wantProgress: 0,
|
||||
},
|
||||
{
|
||||
name: "进度为 1(100%)",
|
||||
body: `[{"id":"task-3","status":"completing","progress_pct":1.0}]`,
|
||||
taskID: "task-3",
|
||||
wantFound: true,
|
||||
wantStatus: "completing",
|
||||
wantProgress: 100,
|
||||
},
|
||||
{
|
||||
name: "空数组",
|
||||
body: `[]`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "无匹配 id",
|
||||
body: `[{"id":"task-other","status":"processing","progress_pct":0.3}]`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "多个任务精准匹配",
|
||||
body: `[{"id":"task-a","status":"processing","progress_pct":0.2},{"id":"task-b","status":"queued","progress_pct":0},{"id":"task-c","status":"processing","progress_pct":0.8}]`,
|
||||
taskID: "task-c",
|
||||
wantFound: true,
|
||||
wantStatus: "processing",
|
||||
wantProgress: 80,
|
||||
},
|
||||
{
|
||||
name: "非数组 JSON",
|
||||
body: `{"id":"task-1","status":"processing"}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "无 progress_pct 字段",
|
||||
body: `[{"id":"task-4","status":"pending"}]`,
|
||||
taskID: "task-4",
|
||||
wantFound: true,
|
||||
wantStatus: "pending",
|
||||
wantProgress: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
status, found := testParseGetVideoTaskPending([]byte(tt.body), tt.taskID)
|
||||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||
if tt.wantFound {
|
||||
require.NotNil(t, status)
|
||||
require.Equal(t, tt.taskID, status.ID)
|
||||
require.Equal(t, tt.wantStatus, status.Status)
|
||||
require.Equal(t, tt.wantProgress, status.ProgressPct)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== Test 5: TestSoraParseGetVideoTaskDrafts =====================
|
||||
|
||||
func TestSoraParseGetVideoTaskDrafts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
taskID string
|
||||
wantFound bool
|
||||
wantStatus string
|
||||
wantURLs []string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "正常完成的视频",
|
||||
body: `{"items":[{"task_id":"task-1","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: true,
|
||||
wantStatus: "completed",
|
||||
wantURLs: []string{"https://example.com/video.mp4"},
|
||||
},
|
||||
{
|
||||
name: "使用 url 字段回退",
|
||||
body: `{"items":[{"task_id":"task-2","kind":"video","url":"https://example.com/fallback.mp4"}]}`,
|
||||
taskID: "task-2",
|
||||
wantFound: true,
|
||||
wantStatus: "completed",
|
||||
wantURLs: []string{"https://example.com/fallback.mp4"},
|
||||
},
|
||||
{
|
||||
name: "内容违规",
|
||||
body: `{"items":[{"task_id":"task-3","kind":"sora_content_violation","reason_str":"Content policy violation"}]}`,
|
||||
taskID: "task-3",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Content policy violation",
|
||||
},
|
||||
{
|
||||
name: "内容违规 - markdown_reason_str 回退",
|
||||
body: `{"items":[{"task_id":"task-4","kind":"sora_content_violation","markdown_reason_str":"Markdown reason"}]}`,
|
||||
taskID: "task-4",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Markdown reason",
|
||||
},
|
||||
{
|
||||
name: "内容违规 - 无 reason 使用默认消息",
|
||||
body: `{"items":[{"task_id":"task-5","kind":"sora_content_violation"}]}`,
|
||||
taskID: "task-5",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Content violates guardrails",
|
||||
},
|
||||
{
|
||||
name: "有 reason_str 但非 violation kind(仍判定失败)",
|
||||
body: `{"items":[{"task_id":"task-6","kind":"video","reason_str":"Some error occurred"}]}`,
|
||||
taskID: "task-6",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Some error occurred",
|
||||
},
|
||||
{
|
||||
name: "空 URL 判定为失败",
|
||||
body: `{"items":[{"task_id":"task-7","kind":"video","downloadable_url":"","url":""}]}`,
|
||||
taskID: "task-7",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Content violates guardrails",
|
||||
},
|
||||
{
|
||||
name: "无匹配 task_id",
|
||||
body: `{"items":[{"task_id":"task-other","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "空 items",
|
||||
body: `{"items":[]}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "缺少 items 字段",
|
||||
body: `{"other":"data"}`,
|
||||
taskID: "task-1",
|
||||
wantFound: false,
|
||||
},
|
||||
{
|
||||
name: "多个 items 精准匹配",
|
||||
body: `{"items":[{"task_id":"task-a","kind":"video","downloadable_url":"https://a.com/a.mp4"},{"task_id":"task-b","kind":"sora_content_violation","reason_str":"Bad content"},{"task_id":"task-c","kind":"video","downloadable_url":"https://c.com/c.mp4"}]}`,
|
||||
taskID: "task-b",
|
||||
wantFound: true,
|
||||
wantStatus: "failed",
|
||||
wantErr: "Bad content",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
status, found := testParseGetVideoTaskDrafts([]byte(tt.body), tt.taskID)
|
||||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||
if !tt.wantFound {
|
||||
return
|
||||
}
|
||||
require.NotNil(t, status)
|
||||
require.Equal(t, tt.taskID, status.ID)
|
||||
require.Equal(t, tt.wantStatus, status.Status)
|
||||
if tt.wantErr != "" {
|
||||
require.Equal(t, tt.wantErr, status.ErrorMsg)
|
||||
}
|
||||
if tt.wantURLs != nil {
|
||||
require.Equal(t, tt.wantURLs, status.URLs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,260 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
const soraCurlCFFISidecarDefaultTimeoutSeconds = 60
|
||||
|
||||
type soraCurlCFFISidecarRequest struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers,omitempty"`
|
||||
BodyBase64 string `json:"body_base64,omitempty"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
SessionKey string `json:"session_key,omitempty"`
|
||||
Impersonate string `json:"impersonate,omitempty"`
|
||||
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
type soraCurlCFFISidecarResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status int `json:"status"`
|
||||
Headers map[string]any `json:"headers"`
|
||||
BodyBase64 string `json:"body_base64"`
|
||||
Body string `json:"body"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||
if req == nil || req.URL == nil {
|
||||
return nil, errors.New("request url is nil")
|
||||
}
|
||||
if c == nil || c.cfg == nil {
|
||||
return nil, errors.New("sora curl_cffi sidecar config is nil")
|
||||
}
|
||||
if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
return nil, errors.New("sora curl_cffi sidecar is disabled")
|
||||
}
|
||||
endpoint := c.curlCFFISidecarEndpoint()
|
||||
if endpoint == "" {
|
||||
return nil, errors.New("sora curl_cffi sidecar base_url is empty")
|
||||
}
|
||||
|
||||
bodyBytes, err := readAndRestoreRequestBody(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err)
|
||||
}
|
||||
|
||||
headers := make(map[string][]string, len(req.Header)+1)
|
||||
for key, vals := range req.Header {
|
||||
copied := make([]string, len(vals))
|
||||
copy(copied, vals)
|
||||
headers[key] = copied
|
||||
}
|
||||
if strings.TrimSpace(req.Host) != "" {
|
||||
if _, ok := headers["Host"]; !ok {
|
||||
headers["Host"] = []string{req.Host}
|
||||
}
|
||||
}
|
||||
|
||||
payload := soraCurlCFFISidecarRequest{
|
||||
Method: req.Method,
|
||||
URL: req.URL.String(),
|
||||
Headers: headers,
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
SessionKey: c.sidecarSessionKey(account, proxyURL),
|
||||
Impersonate: c.curlCFFIImpersonate(),
|
||||
TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(),
|
||||
}
|
||||
if len(bodyBytes) > 0 {
|
||||
payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes)
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err)
|
||||
}
|
||||
sidecarReq.Header.Set("Content-Type", "application/json")
|
||||
sidecarReq.Header.Set("Accept", "application/json")
|
||||
|
||||
httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second}
|
||||
sidecarResp, err := httpClient.Do(sidecarReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = sidecarResp.Body.Close()
|
||||
}()
|
||||
|
||||
sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err)
|
||||
}
|
||||
if sidecarResp.StatusCode != http.StatusOK {
|
||||
redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512)
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted)
|
||||
}
|
||||
|
||||
var payloadResp soraCurlCFFISidecarResponse
|
||||
if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err)
|
||||
}
|
||||
if msg := strings.TrimSpace(payloadResp.Error); msg != "" {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg)
|
||||
}
|
||||
statusCode := payloadResp.StatusCode
|
||||
if statusCode <= 0 {
|
||||
statusCode = payloadResp.Status
|
||||
}
|
||||
if statusCode <= 0 {
|
||||
return nil, errors.New("sora curl_cffi sidecar response missing status code")
|
||||
}
|
||||
|
||||
responseBody := []byte(payloadResp.Body)
|
||||
if strings.TrimSpace(payloadResp.BodyBase64) != "" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err)
|
||||
}
|
||||
responseBody = decoded
|
||||
}
|
||||
|
||||
respHeaders := make(http.Header)
|
||||
for key, rawVal := range payloadResp.Headers {
|
||||
for _, v := range convertSidecarHeaderValue(rawVal) {
|
||||
respHeaders.Add(key, v)
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Header: respHeaders,
|
||||
Body: io.NopCloser(bytes.NewReader(responseBody)),
|
||||
ContentLength: int64(len(responseBody)),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.ContentLength = int64(len(bodyBytes))
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFISidecarEndpoint() string {
|
||||
if c == nil || c.cfg == nil {
|
||||
return ""
|
||||
}
|
||||
raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" {
|
||||
return raw
|
||||
}
|
||||
if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" {
|
||||
parsed.Path = "/request"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return soraCurlCFFISidecarDefaultTimeoutSeconds
|
||||
}
|
||||
timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds
|
||||
if timeoutSeconds <= 0 {
|
||||
return soraCurlCFFISidecarDefaultTimeoutSeconds
|
||||
}
|
||||
return timeoutSeconds
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) curlCFFIImpersonate() string {
|
||||
if c == nil || c.cfg == nil {
|
||||
return "chrome131"
|
||||
}
|
||||
impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate)
|
||||
if impersonate == "" {
|
||||
return "chrome131"
|
||||
}
|
||||
return impersonate
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool {
|
||||
if c == nil || c.cfg == nil {
|
||||
return true
|
||||
}
|
||||
return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) sidecarSessionTTLSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return 3600
|
||||
}
|
||||
ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds
|
||||
if ttl < 0 {
|
||||
return 3600
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
func convertSidecarHeaderValue(raw any) []string {
|
||||
switch val := raw.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case string:
|
||||
if strings.TrimSpace(val) == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{val}
|
||||
case []any:
|
||||
out := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
s := strings.TrimSpace(fmt.Sprint(item))
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
out := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
if strings.TrimSpace(item) != "" {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
s := strings.TrimSpace(fmt.Sprint(val))
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{s}
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"math/rand"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -669,7 +670,7 @@ func processSoraCharacterUsername(usernameHint string) string {
|
||||
if usernameHint == "" {
|
||||
usernameHint = "character"
|
||||
}
|
||||
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
|
||||
return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
|
||||
|
||||
@@ -181,7 +181,7 @@ func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawU
|
||||
return relative, nil
|
||||
}
|
||||
if s.debug {
|
||||
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeSoraLogURL(rawURL), err)
|
||||
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeMediaLogURL(rawURL), err)
|
||||
}
|
||||
if attempt < retries {
|
||||
time.Sleep(time.Duration(attempt*attempt) * time.Second)
|
||||
@@ -252,7 +252,7 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
|
||||
|
||||
relative := path.Join("/", mediaType, datePath, filename)
|
||||
if s.debug {
|
||||
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeSoraLogURL(rawURL), relative)
|
||||
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeMediaLogURL(rawURL), relative)
|
||||
}
|
||||
return relative, nil
|
||||
}
|
||||
@@ -305,3 +305,19 @@ func removePartialDownload(root *os.Root, filePath string) {
|
||||
}
|
||||
_ = root.Remove(filePath)
|
||||
}
|
||||
|
||||
// sanitizeMediaLogURL 脱敏 URL 用于日志记录(去除 query 参数中可能的 token 信息)
|
||||
func sanitizeMediaLogURL(rawURL string) string {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
if len(rawURL) > 80 {
|
||||
return rawURL[:80] + "..."
|
||||
}
|
||||
return rawURL
|
||||
}
|
||||
safe := parsed.Scheme + "://" + parsed.Host + parsed.Path
|
||||
if len(safe) > 120 {
|
||||
return safe[:120] + "..."
|
||||
}
|
||||
return safe
|
||||
}
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type soraChallengeCooldownEntry struct {
|
||||
Until time.Time
|
||||
StatusCode int
|
||||
CFRay string
|
||||
ConsecutiveChallenges int
|
||||
LastChallengeAt time.Time
|
||||
}
|
||||
|
||||
type soraSidecarSessionEntry struct {
|
||||
SessionKey string
|
||||
ExpiresAt time.Time
|
||||
LastUsedAt time.Time
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return 900
|
||||
}
|
||||
cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds
|
||||
if cooldown <= 0 {
|
||||
return 0
|
||||
}
|
||||
return cooldown
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return nil
|
||||
}
|
||||
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
|
||||
if cooldownSeconds <= 0 {
|
||||
return nil
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
now := time.Now()
|
||||
|
||||
c.challengeCooldownMu.RLock()
|
||||
entry, ok := c.challengeCooldowns[key]
|
||||
c.challengeCooldownMu.RUnlock()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if !entry.Until.After(now) {
|
||||
c.challengeCooldownMu.Lock()
|
||||
delete(c.challengeCooldowns, key)
|
||||
c.challengeCooldownMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
remaining := int(math.Ceil(entry.Until.Sub(now).Seconds()))
|
||||
if remaining < 1 {
|
||||
remaining = 1
|
||||
}
|
||||
message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining)
|
||||
if entry.ConsecutiveChallenges > 1 {
|
||||
message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges)
|
||||
}
|
||||
if entry.CFRay != "" {
|
||||
message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay)
|
||||
}
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Message: message,
|
||||
Headers: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
|
||||
if cooldownSeconds <= 0 {
|
||||
return
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
now := time.Now()
|
||||
cfRay := soraerror.ExtractCloudflareRayID(headers, body)
|
||||
|
||||
c.challengeCooldownMu.Lock()
|
||||
c.cleanupExpiredChallengeCooldownsLocked(now)
|
||||
|
||||
streak := 1
|
||||
existing, ok := c.challengeCooldowns[key]
|
||||
if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute {
|
||||
streak = existing.ConsecutiveChallenges + 1
|
||||
}
|
||||
effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak)
|
||||
until := now.Add(time.Duration(effectiveCooldown) * time.Second)
|
||||
if ok && existing.Until.After(until) {
|
||||
until = existing.Until
|
||||
if existing.ConsecutiveChallenges > streak {
|
||||
streak = existing.ConsecutiveChallenges
|
||||
}
|
||||
if cfRay == "" {
|
||||
cfRay = existing.CFRay
|
||||
}
|
||||
}
|
||||
c.challengeCooldowns[key] = soraChallengeCooldownEntry{
|
||||
Until: until,
|
||||
StatusCode: statusCode,
|
||||
CFRay: cfRay,
|
||||
ConsecutiveChallenges: streak,
|
||||
LastChallengeAt: now,
|
||||
}
|
||||
c.challengeCooldownMu.Unlock()
|
||||
|
||||
if c.debugEnabled() {
|
||||
remain := int(math.Ceil(until.Sub(now).Seconds()))
|
||||
if remain < 0 {
|
||||
remain = 0
|
||||
}
|
||||
c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay)
|
||||
}
|
||||
}
|
||||
|
||||
func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int {
|
||||
if baseSeconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
if streak < 1 {
|
||||
streak = 1
|
||||
}
|
||||
multiplier := streak
|
||||
if multiplier > 4 {
|
||||
multiplier = 4
|
||||
}
|
||||
cooldown := baseSeconds * multiplier
|
||||
if cooldown > 3600 {
|
||||
cooldown = 3600
|
||||
}
|
||||
return cooldown
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
c.challengeCooldownMu.Lock()
|
||||
_, existed := c.challengeCooldowns[key]
|
||||
if existed {
|
||||
delete(c.challengeCooldowns, key)
|
||||
}
|
||||
c.challengeCooldownMu.Unlock()
|
||||
if existed && c.debugEnabled() {
|
||||
c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string {
|
||||
if c == nil || !c.sidecarSessionReuseEnabled() {
|
||||
return ""
|
||||
}
|
||||
if account == nil || account.ID <= 0 {
|
||||
return ""
|
||||
}
|
||||
key := soraAccountProxyKey(account, proxyURL)
|
||||
now := time.Now()
|
||||
ttlSeconds := c.sidecarSessionTTLSeconds()
|
||||
|
||||
c.sidecarSessionMu.Lock()
|
||||
defer c.sidecarSessionMu.Unlock()
|
||||
c.cleanupExpiredSidecarSessionsLocked(now)
|
||||
if existing, exists := c.sidecarSessions[key]; exists {
|
||||
existing.LastUsedAt = now
|
||||
c.sidecarSessions[key] = existing
|
||||
return existing.SessionKey
|
||||
}
|
||||
|
||||
expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second)
|
||||
if ttlSeconds <= 0 {
|
||||
expiresAt = now.Add(365 * 24 * time.Hour)
|
||||
}
|
||||
newEntry := soraSidecarSessionEntry{
|
||||
SessionKey: "sora-" + uuid.NewString(),
|
||||
ExpiresAt: expiresAt,
|
||||
LastUsedAt: now,
|
||||
}
|
||||
c.sidecarSessions[key] = newEntry
|
||||
|
||||
if c.debugEnabled() {
|
||||
c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds)
|
||||
}
|
||||
return newEntry.SessionKey
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) {
|
||||
if c == nil || len(c.challengeCooldowns) == 0 {
|
||||
return
|
||||
}
|
||||
for key, entry := range c.challengeCooldowns {
|
||||
if !entry.Until.After(now) {
|
||||
delete(c.challengeCooldowns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) {
|
||||
if c == nil || len(c.sidecarSessions) == 0 {
|
||||
return
|
||||
}
|
||||
for key, entry := range c.sidecarSessions {
|
||||
if !entry.ExpiresAt.After(now) {
|
||||
delete(c.sidecarSessions, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func soraAccountProxyKey(account *Account, proxyURL string) string {
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL))
|
||||
}
|
||||
|
||||
func normalizeSoraProxyKey(proxyURL string) string {
|
||||
raw := strings.TrimSpace(proxyURL)
|
||||
if raw == "" {
|
||||
return "direct"
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return strings.ToLower(raw)
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
port := strings.TrimSpace(parsed.Port())
|
||||
if host == "" {
|
||||
return strings.ToLower(raw)
|
||||
}
|
||||
if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") {
|
||||
port = ""
|
||||
}
|
||||
if port != "" {
|
||||
host = host + ":" + port
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = "proxy"
|
||||
}
|
||||
return scheme + "://" + host
|
||||
}
|
||||
803
backend/internal/service/sora_sdk_client.go
Normal file
803
backend/internal/service/sora_sdk_client.go
Normal file
@@ -0,0 +1,803 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/DouDOU-start/go-sora2api/sora"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// SoraSDKClient 基于 go-sora2api SDK 的 Sora 客户端实现。
|
||||
// 它实现了 SoraClient 接口,用 SDK 替代原有的自建 HTTP/PoW/TLS 指纹逻辑。
|
||||
type SoraSDKClient struct {
|
||||
cfg *config.Config
|
||||
httpUpstream HTTPUpstream
|
||||
tokenProvider *OpenAITokenProvider
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository
|
||||
|
||||
// 每个 proxyURL 对应一个 SDK 客户端实例
|
||||
sdkClients sync.Map // key: proxyURL (string), value: *sora.Client
|
||||
}
|
||||
|
||||
// NewSoraSDKClient 创建基于 SDK 的 Sora 客户端
|
||||
func NewSoraSDKClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraSDKClient {
|
||||
return &SoraSDKClient{
|
||||
cfg: cfg,
|
||||
httpUpstream: httpUpstream,
|
||||
tokenProvider: tokenProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// SetAccountRepositories 设置账号和 Sora 扩展仓库(用于 token 持久化)
|
||||
func (c *SoraSDKClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.accountRepo = accountRepo
|
||||
c.soraAccountRepo = soraAccountRepo
|
||||
}
|
||||
|
||||
// Enabled 判断是否启用 Sora
|
||||
func (c *SoraSDKClient) Enabled() bool {
|
||||
if c == nil || c.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
|
||||
}
|
||||
|
||||
// PreflightCheck 在创建任务前执行账号能力预检。
|
||||
// 当前仅对视频模型执行预检,用于提前识别额度耗尽或能力缺失。
|
||||
func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
|
||||
if modelCfg.Type != "video" {
|
||||
return nil
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
balance, err := sdkClient.GetCreditBalance(ctx, token)
|
||||
if err != nil {
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "当前账号未开通 Sora2 能力或无可用配额",
|
||||
}
|
||||
}
|
||||
if balance.RateLimitReached || balance.RemainingCount <= 0 {
|
||||
msg := "当前账号 Sora2 可用配额不足"
|
||||
if requestedModel != "" {
|
||||
msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
|
||||
}
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("empty image data")
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if filename == "" {
|
||||
filename = "image.png"
|
||||
}
|
||||
mediaID, err := sdkClient.UploadImage(ctx, token, data, filename)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return mediaID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
var taskID string
|
||||
if strings.TrimSpace(req.MediaID) != "" {
|
||||
taskID, err = sdkClient.CreateImageTaskWithImage(ctx, token, sentinel, req.Prompt, req.Width, req.Height, req.MediaID)
|
||||
} else {
|
||||
taskID, err = sdkClient.CreateImageTask(ctx, token, sentinel, req.Prompt, req.Width, req.Height)
|
||||
}
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
|
||||
orientation := req.Orientation
|
||||
if orientation == "" {
|
||||
orientation = "landscape"
|
||||
}
|
||||
nFrames := req.Frames
|
||||
if nFrames <= 0 {
|
||||
nFrames = 450
|
||||
}
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = "sy_8"
|
||||
}
|
||||
size := req.Size
|
||||
if size == "" {
|
||||
size = "small"
|
||||
}
|
||||
|
||||
// Remix 模式
|
||||
if strings.TrimSpace(req.RemixTargetID) != "" {
|
||||
styleID := "" // SDK ExtractStyle 可从 prompt 中提取
|
||||
taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
// 普通视频(文生视频或图生视频)
|
||||
taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
|
||||
orientation := req.Orientation
|
||||
if orientation == "" {
|
||||
orientation = "landscape"
|
||||
}
|
||||
nFrames := req.Frames
|
||||
if nFrames <= 0 {
|
||||
nFrames = 450
|
||||
}
|
||||
|
||||
taskID, err := sdkClient.CreateStoryboardTask(ctx, token, sentinel, req.Prompt, orientation, nFrames, req.MediaID, "")
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("empty video data")
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cameoID, err := sdkClient.UploadCharacterVideo(ctx, token, data)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return cameoID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status, err := sdkClient.GetCameoStatus(ctx, token, cameoID)
|
||||
if err != nil {
|
||||
return nil, c.wrapSDKError(err, account)
|
||||
}
|
||||
return &SoraCameoStatus{
|
||||
Status: status.Status,
|
||||
DisplayNameHint: status.DisplayNameHint,
|
||||
UsernameHint: status.UsernameHint,
|
||||
ProfileAssetURL: status.ProfileAssetURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := sdkClient.DownloadCharacterImage(ctx, imageURL)
|
||||
if err != nil {
|
||||
return nil, c.wrapSDKError(err, account)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("empty character image")
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
assetPointer, err := sdkClient.UploadCharacterImage(ctx, token, data)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return assetPointer, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
characterID, err := sdkClient.FinalizeCharacter(ctx, token, req.CameoID, req.Username, req.DisplayName, req.ProfileAssetPointer)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return characterID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sdkClient.SetCharacterPublic(ctx, token, cameoID); err != nil {
|
||||
return c.wrapSDKError(err, account)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sdkClient.DeleteCharacter(ctx, token, characterID); err != nil {
|
||||
return c.wrapSDKError(err, account)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := sdkClient.GenerateSentinelToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
postID, err := sdkClient.PublishVideo(ctx, token, sentinel, generationID)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return postID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) DeletePost(ctx context.Context, account *Account, postID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sdkClient.DeletePost(ctx, token, postID); err != nil {
|
||||
return c.wrapSDKError(err, account)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWatermarkFreeURLCustom 使用自定义第三方解析服务获取去水印链接。
|
||||
// SDK 不涉及此功能,保留自建实现。
|
||||
func (c *SoraSDKClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
|
||||
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
|
||||
if parseURL == "" {
|
||||
return "", errors.New("custom parse url is required")
|
||||
}
|
||||
if strings.TrimSpace(parseToken) == "" {
|
||||
return "", errors.New("custom parse token is required")
|
||||
}
|
||||
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
|
||||
payload := map[string]any{
|
||||
"url": shareURL,
|
||||
"token": strings.TrimSpace(parseToken),
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
accountID := int64(0)
|
||||
accountConcurrency := 0
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountConcurrency = account.Concurrency
|
||||
}
|
||||
var resp *http.Response
|
||||
if c.httpUpstream != nil {
|
||||
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
|
||||
}
|
||||
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
|
||||
if downloadLink == "" {
|
||||
return "", errors.New("custom parse response missing download_link")
|
||||
}
|
||||
return downloadLink, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(expansionLevel) == "" {
|
||||
expansionLevel = "medium"
|
||||
}
|
||||
if durationS <= 0 {
|
||||
durationS = 10
|
||||
}
|
||||
enhanced, err := sdkClient.EnhancePrompt(ctx, token, prompt, expansionLevel, durationS)
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return enhanced, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := sdkClient.QueryImageTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second))
|
||||
if result.Err != nil {
|
||||
return &SoraImageTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: result.Err.Error(),
|
||||
}, nil
|
||||
}
|
||||
if result.Done && result.ImageURL != "" {
|
||||
return &SoraImageTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "succeeded",
|
||||
URLs: []string{result.ImageURL},
|
||||
}, nil
|
||||
}
|
||||
status := result.Progress.Status
|
||||
if status == "" {
|
||||
status = "processing"
|
||||
}
|
||||
return &SoraImageTaskStatus{
|
||||
ID: taskID,
|
||||
Status: status,
|
||||
ProgressPct: float64(result.Progress.Percent) / 100.0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先查询 pending 列表
|
||||
result := sdkClient.QueryVideoTaskOnce(ctx, token, taskID, time.Now().Add(-10*time.Second), 0)
|
||||
if result.Err != nil {
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: result.Err.Error(),
|
||||
}, nil
|
||||
}
|
||||
if !result.Done {
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: result.Progress.Status,
|
||||
ProgressPct: result.Progress.Percent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 任务不在 pending 中,查询 drafts 获取下载链接
|
||||
downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") {
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: errMsg,
|
||||
}, nil
|
||||
}
|
||||
// 可能还在处理中
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "processing",
|
||||
}, nil
|
||||
}
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "completed",
|
||||
URLs: []string{downloadURL},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- 内部方法 ---
|
||||
|
||||
// getSDKClient 获取或创建指定代理的 SDK 客户端实例
|
||||
func (c *SoraSDKClient) getSDKClient(account *Account) (*sora.Client, error) {
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
if v, ok := c.sdkClients.Load(proxyURL); ok {
|
||||
return v.(*sora.Client), nil
|
||||
}
|
||||
client, err := sora.New(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 Sora SDK 客户端失败: %w", err)
|
||||
}
|
||||
actual, _ := c.sdkClients.LoadOrStore(proxyURL, client)
|
||||
return actual.(*sora.Client), nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) resolveProxyURL(account *Account) string {
|
||||
if account == nil || account.ProxyID == nil || account.Proxy == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(account.Proxy.URL())
|
||||
}
|
||||
|
||||
// getAccessToken 获取账号的 access_token,支持多种 token 来源和自动刷新。
|
||||
// 此方法保留了原 SoraDirectClient 的 token 管理逻辑。
|
||||
func (c *SoraSDKClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
|
||||
// 优先尝试 OpenAI Token Provider
|
||||
allowProvider := c.allowOpenAITokenProvider(account)
|
||||
var providerErr error
|
||||
if allowProvider && c.tokenProvider != nil {
|
||||
token, err := c.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err == nil && strings.TrimSpace(token) != "" {
|
||||
c.debugLogf("token_selected account_id=%d source=openai_token_provider", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
providerErr = err
|
||||
if err != nil && c.debugEnabled() {
|
||||
c.debugLogf("token_provider_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试直接使用 credentials 中的 access_token
|
||||
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if token != "" {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
|
||||
refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
|
||||
if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
|
||||
return refreshed, nil
|
||||
}
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 尝试通过 session_token 或 refresh_token 恢复
|
||||
recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
|
||||
if recoverErr == nil && strings.TrimSpace(recovered) != "" {
|
||||
return recovered, nil
|
||||
}
|
||||
if providerErr != nil {
|
||||
return "", providerErr
|
||||
}
|
||||
return "", errors.New("access_token not found")
|
||||
}
|
||||
|
||||
// recoverAccessToken 通过 session_token 或 refresh_token 恢复 access_token
|
||||
func (c *SoraSDKClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
|
||||
// 先尝试 session_token
|
||||
if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
|
||||
accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
|
||||
if err == nil && strings.TrimSpace(accessToken) != "" {
|
||||
c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
|
||||
return accessToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 再尝试 refresh_token
|
||||
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
||||
if refreshToken == "" {
|
||||
return "", errors.New("session_token/refresh_token not found")
|
||||
}
|
||||
|
||||
sdkClient, err := c.getSDKClient(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 尝试多个 client_id
|
||||
clientIDs := []string{
|
||||
strings.TrimSpace(account.GetCredential("client_id")),
|
||||
openaioauth.SoraClientID,
|
||||
openaioauth.ClientID,
|
||||
}
|
||||
tried := make(map[string]struct{}, len(clientIDs))
|
||||
var lastErr error
|
||||
|
||||
for _, clientID := range clientIDs {
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := tried[clientID]; ok {
|
||||
continue
|
||||
}
|
||||
tried[clientID] = struct{}{}
|
||||
|
||||
newAccess, newRefresh, refreshErr := sdkClient.RefreshAccessToken(ctx, refreshToken, clientID)
|
||||
if refreshErr != nil {
|
||||
lastErr = refreshErr
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(newAccess) == "" {
|
||||
lastErr = errors.New("refreshed access_token is empty")
|
||||
continue
|
||||
}
|
||||
c.applyRecoveredToken(ctx, account, newAccess, newRefresh, "", "")
|
||||
return newAccess, nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return "", lastErr
|
||||
}
|
||||
return "", errors.New("no available client_id for refresh_token exchange")
|
||||
}
|
||||
|
||||
// exchangeSessionToken 通过 session_token 换取 access_token
|
||||
func (c *SoraSDKClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://sora.chatgpt.com/api/auth/session", nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
accountID := int64(0)
|
||||
accountConcurrency := 0
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountConcurrency = account.Concurrency
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
if c.httpUpstream != nil {
|
||||
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", "", fmt.Errorf("session exchange failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("session exchange missing accessToken")
|
||||
}
|
||||
expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
|
||||
return accessToken, expiresAt, nil
|
||||
}
|
||||
|
||||
// applyRecoveredToken 将恢复的 token 写入账号内存和数据库
|
||||
func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
if strings.TrimSpace(accessToken) != "" {
|
||||
account.Credentials["access_token"] = accessToken
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) != "" {
|
||||
account.Credentials["refresh_token"] = refreshToken
|
||||
}
|
||||
if strings.TrimSpace(expiresAt) != "" {
|
||||
account.Credentials["expires_at"] = expiresAt
|
||||
}
|
||||
if strings.TrimSpace(sessionToken) != "" {
|
||||
account.Credentials["session_token"] = sessionToken
|
||||
}
|
||||
|
||||
if c.accountRepo != nil {
|
||||
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() {
|
||||
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
|
||||
if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
|
||||
return
|
||||
}
|
||||
updates := make(map[string]any)
|
||||
if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
|
||||
updates["access_token"] = accessToken
|
||||
updates["refresh_token"] = refreshToken
|
||||
}
|
||||
if strings.TrimSpace(sessionToken) != "" {
|
||||
updates["session_token"] = sessionToken
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
|
||||
c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) allowOpenAITokenProvider(account *Account) bool {
|
||||
if c == nil || c.tokenProvider == nil {
|
||||
return false
|
||||
}
|
||||
if account != nil && account.Platform == PlatformSora {
|
||||
return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// wrapSDKError 将 SDK 错误包装为 SoraUpstreamError
|
||||
func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
msg := err.Error()
|
||||
statusCode := http.StatusBadGateway
|
||||
if strings.Contains(msg, "HTTP 401") || strings.Contains(msg, "HTTP 403") {
|
||||
statusCode = http.StatusUnauthorized
|
||||
} else if strings.Contains(msg, "HTTP 429") {
|
||||
statusCode = http.StatusTooManyRequests
|
||||
} else if strings.Contains(msg, "HTTP 404") {
|
||||
statusCode = http.StatusNotFound
|
||||
}
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: statusCode,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) debugEnabled() bool {
|
||||
return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) debugLogf(format string, args ...any) {
|
||||
if c.debugEnabled() {
|
||||
log.Printf("[SoraSDK] "+format, args...)
|
||||
}
|
||||
}
|
||||
@@ -206,14 +206,14 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||
return NewSoraMediaStorage(cfg)
|
||||
}
|
||||
|
||||
func ProvideSoraDirectClient(
|
||||
func ProvideSoraSDKClient(
|
||||
cfg *config.Config,
|
||||
httpUpstream HTTPUpstream,
|
||||
tokenProvider *OpenAITokenProvider,
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
) *SoraDirectClient {
|
||||
client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
|
||||
) *SoraSDKClient {
|
||||
client := NewSoraSDKClient(cfg, httpUpstream, tokenProvider)
|
||||
client.SetAccountRepositories(accountRepo, soraAccountRepo)
|
||||
return client
|
||||
}
|
||||
@@ -306,8 +306,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewGatewayService,
|
||||
ProvideSoraMediaStorage,
|
||||
ProvideSoraMediaCleanupService,
|
||||
ProvideSoraDirectClient,
|
||||
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
||||
ProvideSoraSDKClient,
|
||||
wire.Bind(new(SoraClient), new(*SoraSDKClient)),
|
||||
NewSoraGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
NewOAuthService,
|
||||
|
||||
Reference in New Issue
Block a user