feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/DouDOU-start/go-sora2api/sora"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/tidwall/gjson"
@@ -75,6 +76,17 @@ func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, re
}
balance, err := sdkClient.GetCreditBalance(ctx, token)
if err != nil {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora_sdk",
"[PreflightCheckRawError] account_id=%d model=%s op=get_credit_balance raw_err=%s",
accountID,
requestedModel,
logredact.RedactText(err.Error()),
)
return &SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "当前账号未开通 Sora2 能力或无可用配额",
@@ -170,9 +182,23 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r
if size == "" {
size = "small"
}
videoCount := req.VideoCount
if videoCount <= 0 {
videoCount = 1
}
if videoCount > 3 {
videoCount = 3
}
// Remix 模式
if strings.TrimSpace(req.RemixTargetID) != "" {
if videoCount > 1 {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
c.debugLogf("video_count_ignored_for_remix account_id=%d count=%d", accountID, videoCount)
}
styleID := "" // SDK ExtractStyle 可从 prompt 中提取
taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID)
if err != nil {
@@ -182,13 +208,60 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r
}
// 普通视频(文生视频或图生视频)
taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
var taskID string
if videoCount <= 1 {
taskID, err = sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
} else {
taskID, err = c.createVideoTaskWithVariants(ctx, account, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, videoCount)
}
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
func (c *SoraSDKClient) createVideoTaskWithVariants(
ctx context.Context,
account *Account,
accessToken string,
sentinelToken string,
prompt string,
orientation string,
nFrames int,
model string,
size string,
mediaID string,
videoCount int,
) (string, error) {
inpaintItems := make([]any, 0, 1)
if strings.TrimSpace(mediaID) != "" {
inpaintItems = append(inpaintItems, map[string]any{
"kind": "upload",
"upload_id": mediaID,
})
}
payload := map[string]any{
"kind": "video",
"prompt": prompt,
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"n_variants": videoCount,
"model": model,
"inpaint_items": inpaintItems,
"style_id": nil,
}
raw, err := c.doSoraBackendJSON(ctx, account, http.MethodPost, "/nf/create", accessToken, sentinelToken, payload)
if err != nil {
return "", err
}
taskID := strings.TrimSpace(gjson.GetBytes(raw, "id").String())
if taskID == "" {
return "", errors.New("create video task response missing id")
}
return taskID, nil
}
func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
@@ -512,7 +585,7 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task
}
// 任务不在 pending 中,查询 drafts 获取下载链接
downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID)
downloadURLs, err := c.getVideoTaskDownloadURLs(ctx, account, token, taskID)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") {
@@ -528,13 +601,147 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task
Status: "processing",
}, nil
}
if len(downloadURLs) == 0 {
return &SoraVideoTaskStatus{
ID: taskID,
Status: "processing",
}, nil
}
return &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{downloadURL},
URLs: downloadURLs,
}, nil
}
func (c *SoraSDKClient) getVideoTaskDownloadURLs(ctx context.Context, account *Account, accessToken, taskID string) ([]string, error) {
raw, err := c.doSoraBackendJSON(ctx, account, http.MethodGet, "/project_y/profile/drafts?limit=30", accessToken, "", nil)
if err != nil {
return nil, err
}
items := gjson.GetBytes(raw, "items")
if !items.Exists() || !items.IsArray() {
return nil, fmt.Errorf("drafts response missing items for task %s", taskID)
}
urlSet := make(map[string]struct{}, 4)
urls := make([]string, 0, 4)
items.ForEach(func(_, item gjson.Result) bool {
if strings.TrimSpace(item.Get("task_id").String()) != taskID {
return true
}
kind := strings.TrimSpace(item.Get("kind").String())
reason := strings.TrimSpace(item.Get("reason_str").String())
markdownReason := strings.TrimSpace(item.Get("markdown_reason_str").String())
if kind == "sora_content_violation" || reason != "" || markdownReason != "" {
if reason == "" {
reason = markdownReason
}
if reason == "" {
reason = "内容违规"
}
err = fmt.Errorf("内容违规: %s", reason)
return false
}
url := strings.TrimSpace(item.Get("downloadable_url").String())
if url == "" {
url = strings.TrimSpace(item.Get("url").String())
}
if url == "" {
return true
}
if _, exists := urlSet[url]; exists {
return true
}
urlSet[url] = struct{}{}
urls = append(urls, url)
return true
})
if err != nil {
return nil, err
}
if len(urls) > 0 {
return urls, nil
}
// 兼容旧 SDK 的兜底逻辑
sdkClient, sdkErr := c.getSDKClient(account)
if sdkErr != nil {
return nil, sdkErr
}
downloadURL, sdkErr := sdkClient.GetDownloadURL(ctx, accessToken, taskID)
if sdkErr != nil {
return nil, sdkErr
}
if strings.TrimSpace(downloadURL) == "" {
return nil, nil
}
return []string{downloadURL}, nil
}
func (c *SoraSDKClient) doSoraBackendJSON(
ctx context.Context,
account *Account,
method string,
path string,
accessToken string,
sentinelToken string,
payload map[string]any,
) ([]byte, error) {
endpoint := "https://sora.chatgpt.com/backend" + path
var body io.Reader
if payload != nil {
raw, err := json.Marshal(payload)
if err != nil {
return nil, err
}
body = bytes.NewReader(raw)
}
req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Origin", "https://sora.chatgpt.com")
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
if strings.TrimSpace(sentinelToken) != "" {
req.Header.Set("openai-sentinel-token", sentinelToken)
}
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncateForLog(raw, 256))
}
return raw, nil
}
// --- 内部方法 ---
// getSDKClient 获取或创建指定代理的 SDK 客户端实例
@@ -791,6 +998,17 @@ func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error {
} else if strings.Contains(msg, "HTTP 404") {
statusCode = http.StatusNotFound
}
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora_sdk",
"[WrapSDKError] account_id=%d mapped_status=%d raw_err=%s",
accountID,
statusCode,
logredact.RedactText(msg),
)
return &SoraUpstreamError{
StatusCode: statusCode,
Message: msg,