feat(sync): full code sync from release
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/DouDOU-start/go-sora2api/sora"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -75,6 +76,17 @@ func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, re
|
||||
}
|
||||
balance, err := sdkClient.GetCreditBalance(ctx, token)
|
||||
if err != nil {
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.sora_sdk",
|
||||
"[PreflightCheckRawError] account_id=%d model=%s op=get_credit_balance raw_err=%s",
|
||||
accountID,
|
||||
requestedModel,
|
||||
logredact.RedactText(err.Error()),
|
||||
)
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "当前账号未开通 Sora2 能力或无可用配额",
|
||||
@@ -170,9 +182,23 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r
|
||||
if size == "" {
|
||||
size = "small"
|
||||
}
|
||||
videoCount := req.VideoCount
|
||||
if videoCount <= 0 {
|
||||
videoCount = 1
|
||||
}
|
||||
if videoCount > 3 {
|
||||
videoCount = 3
|
||||
}
|
||||
|
||||
// Remix 模式
|
||||
if strings.TrimSpace(req.RemixTargetID) != "" {
|
||||
if videoCount > 1 {
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
c.debugLogf("video_count_ignored_for_remix account_id=%d count=%d", accountID, videoCount)
|
||||
}
|
||||
styleID := "" // SDK ExtractStyle 可从 prompt 中提取
|
||||
taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID)
|
||||
if err != nil {
|
||||
@@ -182,13 +208,60 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r
|
||||
}
|
||||
|
||||
// 普通视频(文生视频或图生视频)
|
||||
taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
|
||||
var taskID string
|
||||
if videoCount <= 1 {
|
||||
taskID, err = sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
|
||||
} else {
|
||||
taskID, err = c.createVideoTaskWithVariants(ctx, account, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, videoCount)
|
||||
}
|
||||
if err != nil {
|
||||
return "", c.wrapSDKError(err, account)
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) createVideoTaskWithVariants(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
accessToken string,
|
||||
sentinelToken string,
|
||||
prompt string,
|
||||
orientation string,
|
||||
nFrames int,
|
||||
model string,
|
||||
size string,
|
||||
mediaID string,
|
||||
videoCount int,
|
||||
) (string, error) {
|
||||
inpaintItems := make([]any, 0, 1)
|
||||
if strings.TrimSpace(mediaID) != "" {
|
||||
inpaintItems = append(inpaintItems, map[string]any{
|
||||
"kind": "upload",
|
||||
"upload_id": mediaID,
|
||||
})
|
||||
}
|
||||
payload := map[string]any{
|
||||
"kind": "video",
|
||||
"prompt": prompt,
|
||||
"orientation": orientation,
|
||||
"size": size,
|
||||
"n_frames": nFrames,
|
||||
"n_variants": videoCount,
|
||||
"model": model,
|
||||
"inpaint_items": inpaintItems,
|
||||
"style_id": nil,
|
||||
}
|
||||
raw, err := c.doSoraBackendJSON(ctx, account, http.MethodPost, "/nf/create", accessToken, sentinelToken, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
taskID := strings.TrimSpace(gjson.GetBytes(raw, "id").String())
|
||||
if taskID == "" {
|
||||
return "", errors.New("create video task response missing id")
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -512,7 +585,7 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task
|
||||
}
|
||||
|
||||
// 任务不在 pending 中,查询 drafts 获取下载链接
|
||||
downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID)
|
||||
downloadURLs, err := c.getVideoTaskDownloadURLs(ctx, account, token, taskID)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") {
|
||||
@@ -528,13 +601,147 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task
|
||||
Status: "processing",
|
||||
}, nil
|
||||
}
|
||||
if len(downloadURLs) == 0 {
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "processing",
|
||||
}, nil
|
||||
}
|
||||
return &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "completed",
|
||||
URLs: []string{downloadURL},
|
||||
URLs: downloadURLs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) getVideoTaskDownloadURLs(ctx context.Context, account *Account, accessToken, taskID string) ([]string, error) {
|
||||
raw, err := c.doSoraBackendJSON(ctx, account, http.MethodGet, "/project_y/profile/drafts?limit=30", accessToken, "", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items := gjson.GetBytes(raw, "items")
|
||||
if !items.Exists() || !items.IsArray() {
|
||||
return nil, fmt.Errorf("drafts response missing items for task %s", taskID)
|
||||
}
|
||||
urlSet := make(map[string]struct{}, 4)
|
||||
urls := make([]string, 0, 4)
|
||||
items.ForEach(func(_, item gjson.Result) bool {
|
||||
if strings.TrimSpace(item.Get("task_id").String()) != taskID {
|
||||
return true
|
||||
}
|
||||
kind := strings.TrimSpace(item.Get("kind").String())
|
||||
reason := strings.TrimSpace(item.Get("reason_str").String())
|
||||
markdownReason := strings.TrimSpace(item.Get("markdown_reason_str").String())
|
||||
if kind == "sora_content_violation" || reason != "" || markdownReason != "" {
|
||||
if reason == "" {
|
||||
reason = markdownReason
|
||||
}
|
||||
if reason == "" {
|
||||
reason = "内容违规"
|
||||
}
|
||||
err = fmt.Errorf("内容违规: %s", reason)
|
||||
return false
|
||||
}
|
||||
url := strings.TrimSpace(item.Get("downloadable_url").String())
|
||||
if url == "" {
|
||||
url = strings.TrimSpace(item.Get("url").String())
|
||||
}
|
||||
if url == "" {
|
||||
return true
|
||||
}
|
||||
if _, exists := urlSet[url]; exists {
|
||||
return true
|
||||
}
|
||||
urlSet[url] = struct{}{}
|
||||
urls = append(urls, url)
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(urls) > 0 {
|
||||
return urls, nil
|
||||
}
|
||||
|
||||
// 兼容旧 SDK 的兜底逻辑
|
||||
sdkClient, sdkErr := c.getSDKClient(account)
|
||||
if sdkErr != nil {
|
||||
return nil, sdkErr
|
||||
}
|
||||
downloadURL, sdkErr := sdkClient.GetDownloadURL(ctx, accessToken, taskID)
|
||||
if sdkErr != nil {
|
||||
return nil, sdkErr
|
||||
}
|
||||
if strings.TrimSpace(downloadURL) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{downloadURL}, nil
|
||||
}
|
||||
|
||||
func (c *SoraSDKClient) doSoraBackendJSON(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
method string,
|
||||
path string,
|
||||
accessToken string,
|
||||
sentinelToken string,
|
||||
payload map[string]any,
|
||||
) ([]byte, error) {
|
||||
endpoint := "https://sora.chatgpt.com/backend" + path
|
||||
var body io.Reader
|
||||
if payload != nil {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body = bytes.NewReader(raw)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
if payload != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if strings.TrimSpace(sentinelToken) != "" {
|
||||
req.Header.Set("openai-sentinel-token", sentinelToken)
|
||||
}
|
||||
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
accountID := int64(0)
|
||||
accountConcurrency := 0
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountConcurrency = account.Concurrency
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
if c.httpUpstream != nil {
|
||||
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncateForLog(raw, 256))
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// --- 内部方法 ---
|
||||
|
||||
// getSDKClient 获取或创建指定代理的 SDK 客户端实例
|
||||
@@ -791,6 +998,17 @@ func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error {
|
||||
} else if strings.Contains(msg, "HTTP 404") {
|
||||
statusCode = http.StatusNotFound
|
||||
}
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.sora_sdk",
|
||||
"[WrapSDKError] account_id=%d mapped_status=%d raw_err=%s",
|
||||
accountID,
|
||||
statusCode,
|
||||
logredact.RedactText(msg),
|
||||
)
|
||||
return &SoraUpstreamError{
|
||||
StatusCode: statusCode,
|
||||
Message: msg,
|
||||
|
||||
Reference in New Issue
Block a user