feat(Sora): 直连生成并移除sora2api依赖

实现直连 Sora 客户端、媒体落地与清理策略\n更新网关与前端配置以支持 Sora 平台\n补齐单元测试与契约测试,新增 curl 测试脚本\n\n测试: go test ./... -tags=unit
This commit is contained in:
yangjianbo
2026-02-01 21:37:10 +08:00
parent 78d0ca3775
commit 399dd78b2a
39 changed files with 3120 additions and 1189 deletions

View File

@@ -491,7 +491,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
return s.sendErrorAndEnd(c, "Failed to create request")
}
// 使用 Sora 客户端标准请求头(参考 sora2api
// 使用 Sora 客户端标准请求头
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")

View File

@@ -283,7 +283,6 @@ type adminServiceImpl struct {
groupRepo GroupRepository
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
soraSyncService *Sora2APISyncService // Sora2API 同步服务
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
@@ -299,7 +298,6 @@ func NewAdminService(
groupRepo GroupRepository,
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository,
soraSyncService *Sora2APISyncService,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
@@ -313,7 +311,6 @@ func NewAdminService(
groupRepo: groupRepo,
accountRepo: accountRepo,
soraAccountRepo: soraAccountRepo,
soraSyncService: soraSyncService,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
@@ -917,9 +914,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// 同步到 sora2api异步不阻塞创建
s.syncSoraAccountAsync(account)
return account, nil
}
@@ -1014,7 +1008,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if err != nil {
return nil, err
}
s.syncSoraAccountAsync(updated)
return updated, nil
}
@@ -1032,17 +1025,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
needSoraSync := s != nil && s.soraSyncService != nil
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{}
if needMixedChannelCheck || needSoraSync {
if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
if needMixedChannelCheck {
return nil, err
}
log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err)
} else {
for _, account := range accounts {
if account != nil {
@@ -1134,45 +1125,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Success++
result.SuccessIDs = append(result.SuccessIDs, accountID)
result.Results = append(result.Results, entry)
// 批量更新后同步 sora2api
if needSoraSync {
platform := platformByID[accountID]
if platform == "" {
updated, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
continue
}
if updated.Platform == PlatformSora {
s.syncSoraAccountAsync(updated)
}
continue
}
if platform == PlatformSora {
updated, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
continue
}
s.syncSoraAccountAsync(updated)
}
}
}
return result, nil
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return err
}
if err := s.accountRepo.Delete(ctx, id); err != nil {
return err
}
s.deleteSoraAccountAsync(account)
return nil
}
@@ -1210,44 +1171,9 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
if err != nil {
return nil, err
}
s.syncSoraAccountAsync(updated)
return updated, nil
}
func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) {
if s == nil || s.soraSyncService == nil || account == nil {
return
}
if account.Platform != PlatformSora {
return
}
syncAccount := *account
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil {
log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
}
func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) {
if s == nil || s.soraSyncService == nil || account == nil {
return
}
if account.Platform != PlatformSora {
return
}
syncAccount := *account
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil {
log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
}
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}

View File

@@ -105,31 +105,3 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3)
}
// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。
func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
{ID: 1, Platform: PlatformSora},
},
getByIDAccounts: map[int64]*Account{
1: {ID: 1, Platform: PlatformSora},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
soraSyncService: &Sora2APISyncService{},
}
schedulable := true
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
Schedulable: &schedulable,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 1, result.Success)
require.True(t, repo.getByIDsCalled)
require.ElementsMatch(t, []int64{1}, repo.getByIDCalled)
}

View File

@@ -1,351 +0,0 @@
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// Sora2APIModel represents a model entry returned by sora2api.
type Sora2APIModel struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by,omitempty"`
Description string `json:"description,omitempty"`
}
// Sora2APIModelList represents /v1/models response.
type Sora2APIModelList struct {
Object string `json:"object"`
Data []Sora2APIModel `json:"data"`
}
// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem.
type Sora2APIImportTokenItem struct {
Email string `json:"email"`
AccessToken string `json:"access_token,omitempty"`
SessionToken string `json:"session_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ClientID string `json:"client_id,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
Remark string `json:"remark,omitempty"`
IsActive bool `json:"is_active"`
ImageEnabled bool `json:"image_enabled"`
VideoEnabled bool `json:"video_enabled"`
ImageConcurrency int `json:"image_concurrency"`
VideoConcurrency int `json:"video_concurrency"`
}
// Sora2APIToken represents minimal fields for admin list.
type Sora2APIToken struct {
ID int64 `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Remark string `json:"remark"`
}
// Sora2APIService provides access to sora2api endpoints.
type Sora2APIService struct {
cfg *config.Config
baseURL string
apiKey string
adminUsername string
adminPassword string
adminTokenTTL time.Duration
tokenImportMode string
client *http.Client
adminClient *http.Client
adminToken string
adminTokenAt time.Time
adminMu sync.Mutex
modelCache []Sora2APIModel
modelMu sync.RWMutex
}
func NewSora2APIService(cfg *config.Config) *Sora2APIService {
if cfg == nil {
return &Sora2APIService{}
}
adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second
if adminTTL <= 0 {
adminTTL = 15 * time.Minute
}
adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second
if adminTimeout <= 0 {
adminTimeout = 10 * time.Second
}
return &Sora2APIService{
cfg: cfg,
baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"),
apiKey: strings.TrimSpace(cfg.Sora2API.APIKey),
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
adminTokenTTL: adminTTL,
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
client: &http.Client{},
adminClient: &http.Client{Timeout: adminTimeout},
}
}
func (s *Sora2APIService) Enabled() bool {
return s != nil && s.baseURL != "" && s.apiKey != ""
}
func (s *Sora2APIService) AdminEnabled() bool {
return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != ""
}
func (s *Sora2APIService) buildURL(path string) string {
if s.baseURL == "" {
return path
}
if strings.HasPrefix(path, "/") {
return s.baseURL + path
}
return s.baseURL + "/" + path
}
// BuildURL 返回完整的 sora2api URL用于代理媒体
func (s *Sora2APIService) BuildURL(path string) string {
return s.buildURL(path)
}
func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) {
if !s.Enabled() {
return nil, errors.New("sora2api not configured")
}
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+s.apiKey)
req.Header.Set("Content-Type", "application/json")
return req, nil
}
func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) {
if !s.Enabled() {
return nil, errors.New("sora2api not configured")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+s.apiKey)
resp, err := s.client.Do(req)
if err != nil {
return s.cachedModelsOnError(err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode))
}
var payload Sora2APIModelList
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return s.cachedModelsOnError(err)
}
models := payload.Data
if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance {
filtered := make([]Sora2APIModel, 0, len(models))
for _, m := range models {
if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") {
continue
}
filtered = append(filtered, m)
}
models = filtered
}
s.modelMu.Lock()
s.modelCache = models
s.modelMu.Unlock()
return models, nil
}
func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) {
s.modelMu.RLock()
cached := append([]Sora2APIModel(nil), s.modelCache...)
s.modelMu.RUnlock()
if len(cached) > 0 {
log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err)
return cached, nil
}
return nil, err
}
func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
mode := s.tokenImportMode
if mode == "" {
mode = "at"
}
payload := map[string]any{
"tokens": items,
"mode": mode,
}
_, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil)
return err
}
func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) {
if !s.AdminEnabled() {
return nil, errors.New("sora2api admin not configured")
}
var tokens []Sora2APIToken
_, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens)
return tokens, err
}
func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
path := fmt.Sprintf("/api/tokens/%d/disable", tokenID)
_, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil)
return err
}
func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error {
if !s.AdminEnabled() {
return errors.New("sora2api admin not configured")
}
path := fmt.Sprintf("/api/tokens/%d", tokenID)
_, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil)
return err
}
func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) {
if !s.AdminEnabled() {
return nil, errors.New("sora2api admin not configured")
}
token, err := s.getAdminToken(ctx)
if err != nil {
return nil, err
}
resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out)
if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized {
return resp, nil
}
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
s.invalidateAdminToken()
token, err = s.getAdminToken(ctx)
if err != nil {
return resp, err
}
return s.doAdminRequestWithToken(ctx, method, path, token, body, out)
}
return resp, err
}
func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) {
var reader *bytes.Reader
if body != nil {
buf, err := json.Marshal(body)
if err != nil {
return nil, err
}
reader = bytes.NewReader(buf)
} else {
reader = bytes.NewReader(nil)
}
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := s.adminClient.Do(req)
if err != nil {
return resp, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode)
}
if out != nil {
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
return resp, err
}
}
return resp, nil
}
func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) {
s.adminMu.Lock()
defer s.adminMu.Unlock()
if s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL {
return s.adminToken, nil
}
if !s.AdminEnabled() {
return "", errors.New("sora2api admin not configured")
}
payload := map[string]string{
"username": s.adminUsername,
"password": s.adminPassword,
}
buf, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.adminClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode)
}
var result struct {
Success bool `json:"success"`
Token string `json:"token"`
Message string `json:"message"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
if !result.Success || result.Token == "" {
if result.Message == "" {
result.Message = "sora2api login failed"
}
return "", errors.New(result.Message)
}
s.adminToken = result.Token
s.adminTokenAt = time.Now()
return result.Token, nil
}
func (s *Sora2APIService) invalidateAdminToken() {
s.adminMu.Lock()
defer s.adminMu.Unlock()
s.adminToken = ""
s.adminTokenAt = time.Time{}
}

View File

@@ -1,255 +0,0 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池
type Sora2APISyncService struct {
sora2api *Sora2APIService
accountRepo AccountRepository
httpClient *http.Client
}
func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService {
return &Sora2APISyncService{
sora2api: sora2api,
accountRepo: accountRepo,
httpClient: &http.Client{Timeout: 10 * time.Second},
}
}
func (s *Sora2APISyncService) Enabled() bool {
return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled()
}
// SyncAccount 将 Sora 账号同步到 sora2api导入或更新
func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken == "" {
return errors.New("sora 账号缺少 access_token")
}
email, updated := s.resolveAccountEmail(ctx, account)
if email == "" {
return errors.New("无法解析 Sora 账号邮箱")
}
if updated && s.accountRepo != nil {
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err)
}
}
item := Sora2APIImportTokenItem{
Email: email,
AccessToken: accessToken,
SessionToken: strings.TrimSpace(account.GetCredential("session_token")),
RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")),
ClientID: strings.TrimSpace(account.GetCredential("client_id")),
Remark: account.Name,
IsActive: account.IsActive() && account.Schedulable,
ImageEnabled: true,
VideoEnabled: true,
ImageConcurrency: normalizeSoraConcurrency(account.Concurrency),
VideoConcurrency: normalizeSoraConcurrency(account.Concurrency),
}
if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil {
return err
}
return nil
}
// DisableAccount 禁用 sora2api 中的 token
func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
tokenID, err := s.resolveTokenID(ctx, account)
if err != nil {
return err
}
return s.sora2api.DisableToken(ctx, tokenID)
}
// DeleteAccount 删除 sora2api 中的 token
func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error {
if !s.Enabled() {
return nil
}
if account == nil || account.Platform != PlatformSora {
return nil
}
tokenID, err := s.resolveTokenID(ctx, account)
if err != nil {
return err
}
return s.sora2api.DeleteToken(ctx, tokenID)
}
func normalizeSoraConcurrency(value int) int {
if value <= 0 {
return -1
}
return value
}
func (s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) {
if account == nil {
return "", false
}
if email := strings.TrimSpace(account.GetCredential("email")); email != "" {
return email, false
}
if email := strings.TrimSpace(account.GetExtraString("email")); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken != "" {
if email := extractEmailFromAccessToken(accessToken); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
if email := s.fetchEmailFromSora(ctx, accessToken); email != "" {
if account.Credentials == nil {
account.Credentials = map[string]any{}
}
account.Credentials["email"] = email
return email, true
}
}
return "", false
}
func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) (int64, error) {
if account == nil {
return 0, errors.New("account is nil")
}
if account.Extra != nil {
if v, ok := account.Extra["sora2api_token_id"]; ok {
if id, ok := v.(float64); ok && id > 0 {
return int64(id), nil
}
if id, ok := v.(int64); ok && id > 0 {
return id, nil
}
if id, ok := v.(int); ok && id > 0 {
return int64(id), nil
}
}
}
email := strings.TrimSpace(account.GetCredential("email"))
if email == "" {
email, _ = s.resolveAccountEmail(ctx, account)
}
if email == "" {
return 0, errors.New("sora2api token email missing")
}
tokenID, err := s.findTokenIDByEmail(ctx, email)
if err != nil {
return 0, err
}
return tokenID, nil
}
func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) {
if !s.Enabled() {
return 0, errors.New("sora2api admin not configured")
}
tokens, err := s.sora2api.ListTokens(ctx)
if err != nil {
return 0, err
}
for _, token := range tokens {
if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) {
return token.ID, nil
}
}
return 0, fmt.Errorf("sora2api token not found for email: %s", email)
}
func extractEmailFromAccessToken(accessToken string) string {
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
claims := jwt.MapClaims{}
_, _, err := parser.ParseUnverified(accessToken, claims)
if err != nil {
return ""
}
if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok {
if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
}
return ""
}
func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string {
if s.httpClient == nil {
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil)
if err != nil {
return ""
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
resp, err := s.httpClient.Do(req)
if err != nil {
return ""
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return ""
}
var payload map[string]any
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return ""
}
if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" {
return email
}
return ""
}

View File

@@ -0,0 +1,884 @@
package service
import (
"bytes"
"context"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/rand"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"net/url"
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"golang.org/x/crypto/sha3"
)
const (
soraChatGPTBaseURL = "https://chatgpt.com"
soraSentinelFlow = "sora_2_create_task"
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
const (
soraPowMaxIteration = 500000
)
var soraPowCores = []int{8, 16, 24, 32}
var soraPowScripts = []string{
"https://cdn.oaistatic.com/_next/static/cXh69klOLzS0Gy2joLDRS/_ssgManifest.js?dpl=453ebaec0d44c2decab71692e1bfe39be35a24b3",
}
var soraPowDPL = []string{
"prod-f501fe933b3edf57aea882da888e1a544df99840",
}
var soraPowNavigatorKeys = []string{
"registerProtocolHandlerfunction registerProtocolHandler() { [native code] }",
"storage[object StorageManager]",
"locks[object LockManager]",
"appCodeNameMozilla",
"permissions[object Permissions]",
"webdriverfalse",
"vendorGoogle Inc.",
"mediaDevices[object MediaDevices]",
"cookieEnabledtrue",
"productGecko",
"productSub20030107",
"hardwareConcurrency32",
"onLinetrue",
}
var soraPowDocumentKeys = []string{
"_reactListeningo743lnnpvdg",
"location",
}
var soraPowWindowKeys = []string{
"0", "window", "self", "document", "name", "location",
"navigator", "screen", "innerWidth", "innerHeight",
"localStorage", "sessionStorage", "crypto", "performance",
"fetch", "setTimeout", "setInterval", "console",
}
var soraDesktopUserAgents = []string{
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
var soraRandMu sync.Mutex
var soraPerfStart = time.Now()
// SoraClient 定义直连 Sora 的任务操作接口。
type SoraClient interface {
Enabled() bool
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
}
// SoraImageRequest 图片生成请求参数
type SoraImageRequest struct {
Prompt string
Width int
Height int
MediaID string
}
// SoraVideoRequest 视频生成请求参数
type SoraVideoRequest struct {
Prompt string
Orientation string
Frames int
Model string
Size string
MediaID string
RemixTargetID string
}
// SoraImageTaskStatus 图片任务状态
type SoraImageTaskStatus struct {
ID string
Status string
ProgressPct float64
URLs []string
ErrorMsg string
}
// SoraVideoTaskStatus 视频任务状态
type SoraVideoTaskStatus struct {
ID string
Status string
ProgressPct int
URLs []string
ErrorMsg string
}
// SoraUpstreamError 上游错误
type SoraUpstreamError struct {
StatusCode int
Message string
Headers http.Header
Body []byte
}
func (e *SoraUpstreamError) Error() string {
if e == nil {
return "sora upstream error"
}
if e.Message != "" {
return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message)
}
return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
}
// SoraDirectClient 直连 Sora 实现
type SoraDirectClient struct {
cfg *config.Config
httpUpstream HTTPUpstream
tokenProvider *OpenAITokenProvider
}
// NewSoraDirectClient 创建 Sora 直连客户端
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
return &SoraDirectClient{
cfg: cfg,
httpUpstream: httpUpstream,
tokenProvider: tokenProvider,
}
}
// Enabled 判断是否启用 Sora 直连
func (c *SoraDirectClient) Enabled() bool {
if c == nil || c.cfg == nil {
return false
}
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
}
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
if len(data) == 0 {
return "", errors.New("empty image data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
if filename == "" {
filename = "image.png"
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
contentType := mime.TypeByExtension(path.Ext(filename))
if contentType == "" {
contentType = "application/octet-stream"
}
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; filename="%s"`, filename))
partHeader.Set("Content-Type", contentType)
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("file_name", filename); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
if err != nil {
return "", err
}
var payload map[string]any
if err := json.Unmarshal(respBody, &payload); err != nil {
return "", fmt.Errorf("parse upload response: %w", err)
}
id, _ := payload["id"].(string)
if strings.TrimSpace(id) == "" {
return "", errors.New("upload response missing id")
}
return id, nil
}
func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
operation := "simple_compose"
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
operation = "remix"
inpaintItems = append(inpaintItems, map[string]any{
"type": "image",
"frame_index": 0,
"upload_media_id": req.MediaID,
})
}
payload := map[string]any{
"type": "image_gen",
"operation": operation,
"prompt": req.Prompt,
"width": req.Width,
"height": req.Height,
"n_variants": 1,
"n_frames": 1,
"inpaint_items": inpaintItems,
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return "", err
}
taskID, _ := resp["id"].(string)
if strings.TrimSpace(taskID) == "" {
return "", errors.New("image task response missing id")
}
return taskID, nil
}
func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
model := req.Model
if model == "" {
model = "sy_8"
}
size := req.Size
if size == "" {
size = "small"
}
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
inpaintItems = append(inpaintItems, map[string]any{
"kind": "upload",
"upload_id": req.MediaID,
})
}
payload := map[string]any{
"kind": "video",
"prompt": req.Prompt,
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"model": model,
"inpaint_items": inpaintItems,
}
if strings.TrimSpace(req.RemixTargetID) != "" {
payload["remix_target_id"] = req.RemixTargetID
payload["cameo_ids"] = []string{}
payload["cameo_replacements"] = map[string]any{}
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return "", err
}
taskID, _ := resp["id"].(string)
if strings.TrimSpace(taskID) == "" {
return "", errors.New("video task response missing id")
}
return taskID, nil
}
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/v2/recent_tasks?limit=20"), headers, nil, false)
if err != nil {
return nil, err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return nil, err
}
taskResponses, _ := resp["task_responses"].([]any)
for _, item := range taskResponses {
taskResp, ok := item.(map[string]any)
if !ok {
continue
}
if id, _ := taskResp["id"].(string); id == taskID {
status := strings.TrimSpace(fmt.Sprintf("%v", taskResp["status"]))
progress := 0.0
if v, ok := taskResp["progress_pct"].(float64); ok {
progress = v
}
urls := []string{}
if generations, ok := taskResp["generations"].([]any); ok {
for _, genItem := range generations {
gen, ok := genItem.(map[string]any)
if !ok {
continue
}
if urlStr, ok := gen["url"].(string); ok && strings.TrimSpace(urlStr) != "" {
urls = append(urls, urlStr)
}
}
}
return &SoraImageTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
URLs: urls,
}, nil
}
}
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
}
func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
if err != nil {
return nil, err
}
var pending any
if err := json.Unmarshal(respBody, &pending); err == nil {
if list, ok := pending.([]any); ok {
for _, item := range list {
task, ok := item.(map[string]any)
if !ok {
continue
}
if id, _ := task["id"].(string); id == taskID {
progress := 0
if v, ok := task["progress_pct"].(float64); ok {
progress = int(v * 100)
}
status := strings.TrimSpace(fmt.Sprintf("%v", task["status"]))
return &SoraVideoTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
}, nil
}
}
}
}
respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
if err != nil {
return nil, err
}
var draftsResp map[string]any
if err := json.Unmarshal(respBody, &draftsResp); err != nil {
return nil, err
}
items, _ := draftsResp["items"].([]any)
for _, item := range items {
draft, ok := item.(map[string]any)
if !ok {
continue
}
if id, _ := draft["task_id"].(string); id == taskID {
kind := strings.TrimSpace(fmt.Sprintf("%v", draft["kind"]))
reason := strings.TrimSpace(fmt.Sprintf("%v", draft["reason_str"]))
if reason == "" {
reason = strings.TrimSpace(fmt.Sprintf("%v", draft["markdown_reason_str"]))
}
urlStr := strings.TrimSpace(fmt.Sprintf("%v", draft["downloadable_url"]))
if urlStr == "" {
urlStr = strings.TrimSpace(fmt.Sprintf("%v", draft["url"]))
}
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
msg := reason
if msg == "" {
msg = "Content violates guardrails"
}
return &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: msg,
}, nil
}
return &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{urlStr},
}, nil
}
}
return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil
}
func (c *SoraDirectClient) buildURL(endpoint string) string {
base := ""
if c != nil && c.cfg != nil {
base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/")
}
if base == "" {
return endpoint
}
if strings.HasPrefix(endpoint, "/") {
return base + endpoint
}
return base + "/" + endpoint
}
func (c *SoraDirectClient) defaultUserAgent() string {
if c == nil || c.cfg == nil {
return soraDefaultUserAgent
}
ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent)
if ua == "" {
return soraDefaultUserAgent
}
return ua
}
func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if c.tokenProvider != nil {
return c.tokenProvider.GetAccessToken(ctx, account)
}
token := strings.TrimSpace(account.GetCredential("access_token"))
if token == "" {
return "", errors.New("access_token not found")
}
return token, nil
}
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
headers := http.Header{}
if token != "" {
headers.Set("Authorization", "Bearer "+token)
}
if userAgent != "" {
headers.Set("User-Agent", userAgent)
}
if c != nil && c.cfg != nil {
for key, value := range c.cfg.Sora.Client.Headers {
if strings.EqualFold(key, "authorization") || strings.EqualFold(key, "openai-sentinel-token") {
continue
}
headers.Set(key, value)
}
}
return headers
}
func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) {
if strings.TrimSpace(urlStr) == "" {
return nil, nil, errors.New("empty upstream url")
}
timeout := 0
if c != nil && c.cfg != nil {
timeout = c.cfg.Sora.Client.TimeoutSeconds
}
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer cancel()
}
maxRetries := 0
if allowRetry && c != nil && c.cfg != nil {
maxRetries = c.cfg.Sora.Client.MaxRetries
}
if maxRetries < 0 {
maxRetries = 0
}
var bodyBytes []byte
if body != nil {
b, err := io.ReadAll(body)
if err != nil {
return nil, nil, err
}
bodyBytes = b
}
attempts := maxRetries + 1
for attempt := 1; attempt <= attempts; attempt++ {
var reader io.Reader
if bodyBytes != nil {
reader = bytes.NewReader(bodyBytes)
}
req, err := http.NewRequestWithContext(ctx, method, urlStr, reader)
if err != nil {
return nil, nil, err
}
req.Header = headers.Clone()
start := time.Now()
proxyURL := ""
if account != nil && account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := c.doHTTP(req, proxyURL, account)
if err != nil {
if attempt < attempts && allowRetry {
c.sleepRetry(attempt)
continue
}
return nil, nil, err
}
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if readErr != nil {
return nil, resp.Header, readErr
}
if c.cfg != nil && c.cfg.Sora.Client.Debug {
log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start))
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody)
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
c.sleepRetry(attempt)
continue
}
return nil, resp.Header, upstreamErr
}
return respBody, resp.Header, nil
}
return nil, nil, errors.New("upstream retries exhausted")
}
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
enableTLS := false
if c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint {
enableTLS = true
}
if c.httpUpstream != nil {
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
return c.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
}
return http.DefaultClient.Do(req)
}
func (c *SoraDirectClient) sleepRetry(attempt int) {
backoff := time.Duration(attempt*attempt) * time.Second
if backoff > 10*time.Second {
backoff = 10 * time.Second
}
time.Sleep(backoff)
}
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error {
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
msg = sanitizeUpstreamErrorMessage(msg)
if msg == "" {
msg = truncateForLog(body, 256)
}
return &SoraUpstreamError{
StatusCode: status,
Message: msg,
Headers: headers,
Body: body,
}
}
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
reqID := uuid.NewString()
userAgent := soraRandChoice(soraDesktopUserAgents)
powToken := soraGetPowToken(userAgent)
payload := map[string]any{
"p": powToken,
"flow": soraSentinelFlow,
"id": reqID,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := http.Header{}
headers.Set("Accept", "application/json, text/plain, */*")
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
headers.Set("User-Agent", userAgent)
if accessToken != "" {
headers.Set("Authorization", "Bearer "+accessToken)
}
urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req"
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return "", err
}
sentinel := soraBuildSentinelToken(soraSentinelFlow, reqID, powToken, resp, userAgent)
if sentinel == "" {
return "", errors.New("failed to build sentinel token")
}
return sentinel, nil
}
func soraRandChoice(items []string) string {
if len(items) == 0 {
return ""
}
soraRandMu.Lock()
idx := soraRand.Intn(len(items))
soraRandMu.Unlock()
return items[idx]
}
func soraGetPowToken(userAgent string) string {
configList := soraBuildPowConfig(userAgent)
seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64)
difficulty := "0fffff"
solution, _ := soraSolvePow(seed, difficulty, configList)
return "gAAAAAC" + solution
}
func soraRandFloat() float64 {
soraRandMu.Lock()
defer soraRandMu.Unlock()
return soraRand.Float64()
}
func soraBuildPowConfig(userAgent string) []any {
screen := soraRandChoice([]string{
strconv.Itoa(1920 + 1080),
strconv.Itoa(2560 + 1440),
strconv.Itoa(1920 + 1200),
strconv.Itoa(2560 + 1600),
})
screenVal, _ := strconv.Atoi(screen)
perfMs := float64(time.Since(soraPerfStart).Milliseconds())
wallMs := float64(time.Now().UnixNano()) / 1e6
diff := wallMs - perfMs
return []any{
screenVal,
soraPowParseTime(),
4294705152,
0,
userAgent,
soraRandChoice(soraPowScripts),
soraRandChoice(soraPowDPL),
"en-US",
"en-US,es-US,en,es",
0,
soraRandChoice(soraPowNavigatorKeys),
soraRandChoice(soraPowDocumentKeys),
soraRandChoice(soraPowWindowKeys),
perfMs,
uuid.NewString(),
"",
soraRandChoiceInt(soraPowCores),
diff,
}
}
func soraRandChoiceInt(items []int) int {
if len(items) == 0 {
return 0
}
soraRandMu.Lock()
idx := soraRand.Intn(len(items))
soraRandMu.Unlock()
return items[idx]
}
func soraPowParseTime() string {
loc := time.FixedZone("EST", -5*3600)
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)")
}
func soraSolvePow(seed, difficulty string, configList []any) (string, bool) {
diffLen := len(difficulty) / 2
target, err := hexDecodeString(difficulty)
if err != nil {
return "", false
}
seedBytes := []byte(seed)
part1 := mustMarshalJSON(configList[:3])
part2 := mustMarshalJSON(configList[4:9])
part3 := mustMarshalJSON(configList[10:])
staticPart1 := append(part1[:len(part1)-1], ',')
staticPart2 := append([]byte(","), append(part2[1:len(part2)-1], ',')...)
staticPart3 := append([]byte(","), part3[1:]...)
for i := 0; i < soraPowMaxIteration; i++ {
dynamicI := []byte(strconv.Itoa(i))
dynamicJ := []byte(strconv.Itoa(i >> 1))
finalJSON := make([]byte, 0, len(staticPart1)+len(dynamicI)+len(staticPart2)+len(dynamicJ)+len(staticPart3))
finalJSON = append(finalJSON, staticPart1...)
finalJSON = append(finalJSON, dynamicI...)
finalJSON = append(finalJSON, staticPart2...)
finalJSON = append(finalJSON, dynamicJ...)
finalJSON = append(finalJSON, staticPart3...)
b64 := base64.StdEncoding.EncodeToString(finalJSON)
hash := sha3.Sum512(append(seedBytes, []byte(b64)...))
if bytes.Compare(hash[:diffLen], target[:diffLen]) <= 0 {
return b64, true
}
}
errorToken := "wQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\"%s\"", seed)))
return errorToken, false
}
func soraBuildSentinelToken(flow, reqID, powToken string, resp map[string]any, userAgent string) string {
finalPow := powToken
proof, _ := resp["proofofwork"].(map[string]any)
if required, _ := proof["required"].(bool); required {
seed, _ := proof["seed"].(string)
difficulty, _ := proof["difficulty"].(string)
if seed != "" && difficulty != "" {
configList := soraBuildPowConfig(userAgent)
solution, _ := soraSolvePow(seed, difficulty, configList)
finalPow = "gAAAAAB" + solution
}
}
if !strings.HasSuffix(finalPow, "~S") {
finalPow += "~S"
}
turnstile, _ := resp["turnstile"].(map[string]any)
tokenPayload := map[string]any{
"p": finalPow,
"t": safeMapString(turnstile, "dx"),
"c": safeString(resp["token"]),
"id": reqID,
"flow": flow,
}
encoded, _ := json.Marshal(tokenPayload)
return string(encoded)
}
func safeMapString(m map[string]any, key string) string {
if m == nil {
return ""
}
if v, ok := m[key]; ok {
return safeString(v)
}
return ""
}
func safeString(v any) string {
switch val := v.(type) {
case string:
return val
default:
return fmt.Sprintf("%v", val)
}
}
func mustMarshalJSON(v any) []byte {
b, _ := json.Marshal(v)
return b
}
func hexDecodeString(s string) ([]byte, error) {
dst := make([]byte, len(s)/2)
_, err := hex.Decode(dst, []byte(s))
return dst, err
}
func sanitizeSoraLogURL(raw string) string {
parsed, err := url.Parse(raw)
if err != nil {
return raw
}
q := parsed.Query()
q.Del("sig")
q.Del("expires")
parsed.RawQuery = q.Encode()
return parsed.String()
}

View File

@@ -0,0 +1,54 @@
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraDirectClient_DoRequestSuccess(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{BaseURL: server.URL},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
body, _, err := client.doRequest(context.Background(), &Account{ID: 1}, http.MethodGet, server.URL, http.Header{}, nil, false)
require.NoError(t, err)
require.Contains(t, string(body), "ok")
}
func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
Headers: map[string]string{
"X-Test": "yes",
"Authorization": "should-ignore",
"openai-sentinel-token": "skip",
},
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
headers := client.buildBaseHeaders("token-123", "UA")
require.Equal(t, "Bearer token-123", headers.Get("Authorization"))
require.Equal(t, "UA", headers.Get("User-Agent"))
require.Equal(t, "yes", headers.Get("X-Test"))
require.Empty(t, headers.Get("openai-sentinel-token"))
}

View File

@@ -4,10 +4,12 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"net/http"
"net/url"
"regexp"
@@ -39,23 +41,23 @@ type soraStreamingResult struct {
firstTokenMs *int
}
// SoraGatewayService handles forwarding requests to sora2api.
// SoraGatewayService handles forwarding requests to Sora upstream.
type SoraGatewayService struct {
sora2api *Sora2APIService
httpUpstream HTTPUpstream
soraClient SoraClient
mediaStorage *SoraMediaStorage
rateLimitService *RateLimitService
cfg *config.Config
}
func NewSoraGatewayService(
sora2api *Sora2APIService,
httpUpstream HTTPUpstream,
soraClient SoraClient,
mediaStorage *SoraMediaStorage,
rateLimitService *RateLimitService,
cfg *config.Config,
) *SoraGatewayService {
return &SoraGatewayService{
sora2api: sora2api,
httpUpstream: httpUpstream,
soraClient: soraClient,
mediaStorage: mediaStorage,
rateLimitService: rateLimitService,
cfg: cfg,
}
@@ -64,31 +66,53 @@ func NewSoraGatewayService(
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
startTime := time.Now()
if s.sora2api == nil || !s.sora2api.Enabled() {
if s.soraClient == nil || !s.soraClient.Enabled() {
if c != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "sora2api 未配置",
"message": "Sora 上游未配置",
},
})
}
return nil, errors.New("sora2api not configured")
return nil, errors.New("sora upstream not configured")
}
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream)
return nil, fmt.Errorf("parse request: %w", err)
}
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
if strings.TrimSpace(reqModel) == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
return nil, errors.New("model is required")
}
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel && mappedModel != "" {
reqBody["model"] = mappedModel
if updated, err := json.Marshal(reqBody); err == nil {
body = updated
}
if mappedModel != "" && mappedModel != reqModel {
reqModel = mappedModel
}
modelCfg, ok := GetSoraModelConfig(reqModel)
if !ok {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
return nil, fmt.Errorf("unsupported model: %s", reqModel)
}
if modelCfg.Type == "prompt_enhance" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
return nil, fmt.Errorf("prompt-enhance not supported")
}
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
if strings.TrimSpace(prompt) == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
if strings.TrimSpace(videoInput) != "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
return nil, errors.New("video input not supported")
}
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
@@ -96,81 +120,122 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
defer cancel()
}
upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body)
if err != nil {
return nil, err
}
if c != nil {
if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
upstreamReq.Header.Set("User-Agent", ua)
var imageData []byte
imageFilename := ""
if strings.TrimSpace(imageInput) != "" {
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
if err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
return nil, err
}
}
if reqStream {
upstreamReq.Header.Set("Accept", "text/event-stream")
imageData = decoded
imageFilename = filename
}
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
mediaID := ""
if len(imageData) > 0 {
uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename)
if err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
mediaID = uploadID
}
proxyURL := ""
if account != nil && account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
taskID := ""
var err error
switch modelCfg.Type {
case "image":
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
Prompt: prompt,
Width: modelCfg.Width,
Height: modelCfg.Height,
MediaID: mediaID,
})
case "video":
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
Prompt: prompt,
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
RemixTargetID: remixTargetID,
})
default:
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
}
if err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
var resp *http.Response
if s.httpUpstream != nil {
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if clientStream && c != nil {
s.prepareSoraStream(c, taskID)
}
var mediaURLs []string
mediaType := modelCfg.Type
imageCount := 0
imageSize := ""
if modelCfg.Type == "image" {
urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
mediaURLs = urls
imageCount = len(urls)
imageSize = soraImageSizeFromModel(reqModel)
} else if modelCfg.Type == "video" {
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
mediaURLs = urls
} else {
resp, err = http.DefaultClient.Do(upstreamReq)
mediaType = "prompt"
}
if err != nil {
s.setUpstreamRequestError(c, account, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
finalURLs := mediaURLs
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
if storeErr != nil {
return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream)
}
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
finalURLs = s.normalizeSoraMediaURLs(stored)
} else {
finalURLs = s.normalizeSoraMediaURLs(mediaURLs)
}
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream)
if err != nil {
return nil, err
content := buildSoraContent(mediaType, finalURLs)
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
response := buildSoraNonStreamResponse(content, reqModel)
if len(finalURLs) > 0 {
response["media_url"] = finalURLs[0]
if len(finalURLs) > 1 {
response["media_urls"] = finalURLs
}
}
c.JSON(http.StatusOK, response)
}
result := &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
return &ForwardResult{
RequestID: taskID,
Model: reqModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: streamResult.firstTokenMs,
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: streamResult.mediaType,
MediaURL: firstMediaURL(streamResult.mediaURLs),
ImageCount: streamResult.imageCount,
ImageSize: streamResult.imageSize,
}
return result, nil
MediaType: mediaType,
MediaURL: firstMediaURL(finalURLs),
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
@@ -780,3 +845,414 @@ func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) str
}
return prefix + path + "?" + encoded
}
func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) {
if c == nil {
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if strings.TrimSpace(requestID) != "" {
c.Header("x-request-id", requestID)
}
}
func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) {
if c == nil {
return nil, nil
}
writer := c.Writer
flusher, _ := writer.(http.Flusher)
chunk := map[string]any{
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"delta": map[string]any{
"content": content,
},
},
},
}
encoded, _ := json.Marshal(chunk)
if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil {
return nil, err
}
if flusher != nil {
flusher.Flush()
}
ms := int(time.Since(startTime).Milliseconds())
finalChunk := map[string]any{
"id": chunk["id"],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"delta": map[string]any{},
"finish_reason": "stop",
},
},
}
finalEncoded, _ := json.Marshal(finalChunk)
if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil {
return &ms, err
}
if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil {
return &ms, err
}
if flusher != nil {
flusher.Flush()
}
return &ms, nil
}
func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) {
if c == nil {
return
}
if stream {
flusher, _ := c.Writer.(http.Flusher)
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
_, _ = fmt.Fprint(c.Writer, errorEvent)
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
if flusher != nil {
flusher.Flush()
}
return
}
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}
func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error {
if err == nil {
return nil
}
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) {
if s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
}
msg := upstreamErr.Message
if override := soraProErrorMessage(model, msg); override != "" {
msg = override
}
s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream)
return err
}
if errors.Is(err, context.DeadlineExceeded) {
s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream)
return err
}
s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream)
return err
}
func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
for attempt := 0; attempt < maxAttempts; attempt++ {
status, err := s.soraClient.GetImageTask(ctx, account, taskID)
if err != nil {
return nil, err
}
switch strings.ToLower(status.Status) {
case "succeeded", "completed":
return status.URLs, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
}
return nil, errors.New("Sora image generation failed")
}
if stream {
s.maybeSendPing(c, &lastPing)
}
if err := sleepWithContext(ctx, interval); err != nil {
return nil, err
}
}
return nil, errors.New("Sora image generation timeout")
}
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
for attempt := 0; attempt < maxAttempts; attempt++ {
status, err := s.soraClient.GetVideoTask(ctx, account, taskID)
if err != nil {
return nil, err
}
switch strings.ToLower(status.Status) {
case "completed", "succeeded":
return status.URLs, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
}
return nil, errors.New("Sora video generation failed")
}
if stream {
s.maybeSendPing(c, &lastPing)
}
if err := sleepWithContext(ctx, interval); err != nil {
return nil, err
}
}
return nil, errors.New("Sora video generation timeout")
}
func (s *SoraGatewayService) pollInterval() time.Duration {
if s == nil || s.cfg == nil {
return 2 * time.Second
}
interval := s.cfg.Sora.Client.PollIntervalSeconds
if interval <= 0 {
interval = 2
}
return time.Duration(interval) * time.Second
}
func (s *SoraGatewayService) pollMaxAttempts() int {
if s == nil || s.cfg == nil {
return 600
}
maxAttempts := s.cfg.Sora.Client.MaxPollAttempts
if maxAttempts <= 0 {
maxAttempts = 600
}
return maxAttempts
}
func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) {
if c == nil {
return
}
interval := 10 * time.Second
if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 {
interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second
}
if time.Since(*lastPing) < interval {
return
}
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil {
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
*lastPing = time.Now()
}
}
func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
if len(urls) == 0 {
return urls
}
output := make([]string, 0, len(urls))
for _, raw := range urls {
raw = strings.TrimSpace(raw)
if raw == "" {
continue
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
output = append(output, raw)
continue
}
pathVal := raw
if !strings.HasPrefix(pathVal, "/") {
pathVal = "/" + pathVal
}
output = append(output, s.buildSoraMediaURL(pathVal, ""))
}
return output
}
func buildSoraContent(mediaType string, urls []string) string {
switch mediaType {
case "image":
parts := make([]string, 0, len(urls))
for _, u := range urls {
parts = append(parts, fmt.Sprintf("![image](%s)", u))
}
return strings.Join(parts, "\n")
case "video":
if len(urls) == 0 {
return ""
}
return fmt.Sprintf("```html\n<video src='%s' controls></video>\n```", urls[0])
default:
return ""
}
}
func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) {
if body == nil {
return "", "", "", ""
}
if v, ok := body["remix_target_id"].(string); ok {
remixTargetID = v
}
if v, ok := body["image"].(string); ok {
imageInput = v
}
if v, ok := body["video"].(string); ok {
videoInput = v
}
if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" {
prompt = v
}
if messages, ok := body["messages"].([]any); ok {
builder := strings.Builder{}
for _, raw := range messages {
msg, ok := raw.(map[string]any)
if !ok {
continue
}
role, _ := msg["role"].(string)
if role != "" && role != "user" {
continue
}
content := msg["content"]
text, img, vid := parseSoraMessageContent(content)
if text != "" {
if builder.Len() > 0 {
builder.WriteString("\n")
}
builder.WriteString(text)
}
if imageInput == "" && img != "" {
imageInput = img
}
if videoInput == "" && vid != "" {
videoInput = vid
}
}
if prompt == "" {
prompt = builder.String()
}
}
return prompt, imageInput, videoInput, remixTargetID
}
func parseSoraMessageContent(content any) (text, imageInput, videoInput string) {
switch val := content.(type) {
case string:
return val, "", ""
case []any:
builder := strings.Builder{}
for _, item := range val {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
t, _ := itemMap["type"].(string)
switch t {
case "text":
if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" {
if builder.Len() > 0 {
builder.WriteString("\n")
}
builder.WriteString(txt)
}
case "image_url":
if imageInput == "" {
if urlVal, ok := itemMap["image_url"].(map[string]any); ok {
imageInput = fmt.Sprintf("%v", urlVal["url"])
} else if urlStr, ok := itemMap["image_url"].(string); ok {
imageInput = urlStr
}
}
case "video_url":
if videoInput == "" {
if urlVal, ok := itemMap["video_url"].(map[string]any); ok {
videoInput = fmt.Sprintf("%v", urlVal["url"])
} else if urlStr, ok := itemMap["video_url"].(string); ok {
videoInput = urlStr
}
}
}
}
return builder.String(), imageInput, videoInput
default:
return "", "", ""
}
}
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
raw := strings.TrimSpace(input)
if raw == "" {
return nil, "", errors.New("empty image input")
}
if strings.HasPrefix(raw, "data:") {
parts := strings.SplitN(raw, ",", 2)
if len(parts) != 2 {
return nil, "", errors.New("invalid data url")
}
meta := parts[0]
payload := parts[1]
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil, "", err
}
ext := ""
if strings.HasPrefix(meta, "data:") {
metaParts := strings.SplitN(meta[5:], ";", 2)
if len(metaParts) > 0 {
if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 {
ext = exts[0]
}
}
}
filename := "image" + ext
return decoded, filename, nil
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraImageInput(ctx, raw)
}
decoded, err := base64.StdEncoding.DecodeString(raw)
if err != nil {
return nil, "", errors.New("invalid base64 image")
}
return decoded, "image.png", nil
}
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return nil, "", err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20))
if err != nil {
return nil, "", err
}
ext := fileExtFromURL(rawURL)
if ext == "" {
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
}
filename := "image" + ext
return data, filename, nil
}

View File

@@ -0,0 +1,99 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type stubSoraClientForPoll struct {
imageStatus *SoraImageTaskStatus
videoStatus *SoraVideoTaskStatus
imageCalls int
videoCalls int
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
return "", nil
}
func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
s.imageCalls++
return s.imageStatus, nil
}
func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
s.videoCalls++
return s.videoStatus, nil
}
func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
client := &stubSoraClientForPoll{
imageStatus: &SoraImageTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/a.png"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
service := NewSoraGatewayService(client, nil, nil, cfg)
urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false)
require.NoError(t, err)
require.Equal(t, []string{"https://example.com/a.png"}, urls)
require.Equal(t, 1, client.imageCalls)
}
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "failed",
ErrorMsg: "reject",
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
service := NewSoraGatewayService(client, nil, nil, cfg)
urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
require.Error(t, err)
require.Empty(t, urls)
require.Contains(t, err.Error(), "reject")
require.Equal(t, 1, client.videoCalls)
}
func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
cfg := &config.Config{
Gateway: config.GatewayConfig{
SoraMediaSigningKey: "test-key",
SoraMediaSignedURLTTLSeconds: 600,
},
}
service := NewSoraGatewayService(nil, nil, nil, cfg)
url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "")
require.Contains(t, url, "/sora/media-signed")
require.Contains(t, url, "expires=")
require.Contains(t, url, "sig=")
}

View File

@@ -0,0 +1,117 @@
package service
import (
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/robfig/cron/v3"
)
var soraCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
// SoraMediaCleanupService 定期清理本地媒体文件
type SoraMediaCleanupService struct {
storage *SoraMediaStorage
cfg *config.Config
cron *cron.Cron
startOnce sync.Once
stopOnce sync.Once
}
func NewSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
return &SoraMediaCleanupService{
storage: storage,
cfg: cfg,
}
}
func (s *SoraMediaCleanupService) Start() {
if s == nil || s.cfg == nil {
return
}
if !s.cfg.Sora.Storage.Cleanup.Enabled {
log.Printf("[SoraCleanup] not started (disabled)")
return
}
if s.storage == nil || !s.storage.Enabled() {
log.Printf("[SoraCleanup] not started (storage disabled)")
return
}
s.startOnce.Do(func() {
schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule)
if schedule == "" {
log.Printf("[SoraCleanup] not started (empty schedule)")
return
}
loc := time.Local
if strings.TrimSpace(s.cfg.Timezone) != "" {
if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
loc = parsed
}
}
c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc))
if _, err := c.AddFunc(schedule, func() { s.runCleanup() }); err != nil {
log.Printf("[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err)
return
}
s.cron = c
s.cron.Start()
log.Printf("[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
})
}
func (s *SoraMediaCleanupService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.cron != nil {
ctx := s.cron.Stop()
select {
case <-ctx.Done():
case <-time.After(3 * time.Second):
log.Printf("[SoraCleanup] cron stop timed out")
}
}
})
}
func (s *SoraMediaCleanupService) runCleanup() {
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
if retention <= 0 {
log.Printf("[SoraCleanup] skipped (retention_days=%d)", retention)
return
}
cutoff := time.Now().AddDate(0, 0, -retention)
deleted := 0
roots := []string{s.storage.ImageRoot(), s.storage.VideoRoot()}
for _, root := range roots {
if root == "" {
continue
}
_ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error {
if err != nil {
return nil
}
if info.IsDir() {
return nil
}
if info.ModTime().Before(cutoff) {
if rmErr := os.Remove(p); rmErr == nil {
deleted++
}
}
return nil
})
}
log.Printf("[SoraCleanup] cleanup finished, deleted=%d", deleted)
}

View File

@@ -0,0 +1,46 @@
//go:build unit
package service
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraMediaCleanupService_RunCleanup(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
Cleanup: config.SoraStorageCleanupConfig{
Enabled: true,
RetentionDays: 1,
},
},
},
}
storage := NewSoraMediaStorage(cfg)
require.NoError(t, storage.EnsureLocalDirs())
oldImage := filepath.Join(storage.ImageRoot(), "old.png")
newVideo := filepath.Join(storage.VideoRoot(), "new.mp4")
require.NoError(t, os.WriteFile(oldImage, []byte("old"), 0o644))
require.NoError(t, os.WriteFile(newVideo, []byte("new"), 0o644))
oldTime := time.Now().Add(-48 * time.Hour)
require.NoError(t, os.Chtimes(oldImage, oldTime, oldTime))
cleanup := NewSoraMediaCleanupService(storage, cfg)
cleanup.runCleanup()
require.NoFileExists(t, oldImage)
require.FileExists(t, newVideo)
}

View File

@@ -0,0 +1,256 @@
package service
import (
"context"
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
)
const (
soraStorageDefaultRoot = "/app/data/sora"
)
// SoraMediaStorage 负责下载并落地 Sora 媒体
type SoraMediaStorage struct {
cfg *config.Config
root string
imageRoot string
videoRoot string
maxConcurrent int
fallbackToUpstream bool
debug bool
sem chan struct{}
ready bool
}
func NewSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
storage := &SoraMediaStorage{cfg: cfg}
storage.refreshConfig()
if storage.Enabled() {
if err := storage.EnsureLocalDirs(); err != nil {
log.Printf("[SoraStorage] 初始化失败: %v", err)
}
}
return storage
}
func (s *SoraMediaStorage) Enabled() bool {
if s == nil || s.cfg == nil {
return false
}
return strings.ToLower(strings.TrimSpace(s.cfg.Sora.Storage.Type)) == "local"
}
func (s *SoraMediaStorage) Root() string {
if s == nil {
return ""
}
return s.root
}
func (s *SoraMediaStorage) ImageRoot() string {
if s == nil {
return ""
}
return s.imageRoot
}
func (s *SoraMediaStorage) VideoRoot() string {
if s == nil {
return ""
}
return s.videoRoot
}
func (s *SoraMediaStorage) refreshConfig() {
if s == nil || s.cfg == nil {
return
}
root := strings.TrimSpace(s.cfg.Sora.Storage.LocalPath)
if root == "" {
root = soraStorageDefaultRoot
}
s.root = root
s.imageRoot = filepath.Join(root, "image")
s.videoRoot = filepath.Join(root, "video")
maxConcurrent := s.cfg.Sora.Storage.MaxConcurrentDownloads
if maxConcurrent <= 0 {
maxConcurrent = 4
}
s.maxConcurrent = maxConcurrent
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
s.debug = s.cfg.Sora.Storage.Debug
s.sem = make(chan struct{}, maxConcurrent)
}
// EnsureLocalDirs 创建并校验本地目录
func (s *SoraMediaStorage) EnsureLocalDirs() error {
if s == nil || !s.Enabled() {
return nil
}
if err := os.MkdirAll(s.imageRoot, 0o755); err != nil {
return fmt.Errorf("create image dir: %w", err)
}
if err := os.MkdirAll(s.videoRoot, 0o755); err != nil {
return fmt.Errorf("create video dir: %w", err)
}
s.ready = true
return nil
}
// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL
func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, urls []string) ([]string, error) {
if len(urls) == 0 {
return nil, nil
}
if s == nil || !s.Enabled() {
return urls, nil
}
if !s.ready {
if err := s.EnsureLocalDirs(); err != nil {
return nil, err
}
}
results := make([]string, 0, len(urls))
for _, raw := range urls {
relative, err := s.downloadAndStore(ctx, mediaType, raw)
if err != nil {
if s.fallbackToUpstream {
results = append(results, raw)
continue
}
return nil, err
}
results = append(results, relative)
}
return results, nil
}
func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) {
if strings.TrimSpace(rawURL) == "" {
return "", errors.New("empty url")
}
root := s.imageRoot
if mediaType == "video" {
root = s.videoRoot
}
if root == "" {
return "", errors.New("storage root not configured")
}
retries := 3
for attempt := 1; attempt <= retries; attempt++ {
release, err := s.acquire(ctx)
if err != nil {
return "", err
}
relative, err := s.downloadOnce(ctx, root, mediaType, rawURL)
release()
if err == nil {
return relative, nil
}
if s.debug {
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeSoraLogURL(rawURL), err)
}
if attempt < retries {
time.Sleep(time.Duration(attempt*attempt) * time.Second)
continue
}
return "", err
}
return "", errors.New("download retries exhausted")
}
func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, rawURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return "", err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body))
}
ext := fileExtFromURL(rawURL)
if ext == "" {
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
}
if ext == "" {
ext = ".bin"
}
datePath := time.Now().Format("2006/01/02")
destDir := filepath.Join(root, filepath.FromSlash(datePath))
if err := os.MkdirAll(destDir, 0o755); err != nil {
return "", err
}
filename := uuid.NewString() + ext
destPath := filepath.Join(destDir, filename)
out, err := os.Create(destPath)
if err != nil {
return "", err
}
defer func() { _ = out.Close() }()
if _, err := io.Copy(out, resp.Body); err != nil {
_ = os.Remove(destPath)
return "", err
}
relative := path.Join("/", mediaType, datePath, filename)
if s.debug {
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeSoraLogURL(rawURL), relative)
}
return relative, nil
}
func (s *SoraMediaStorage) acquire(ctx context.Context) (func(), error) {
if s.sem == nil {
return func() {}, nil
}
select {
case s.sem <- struct{}{}:
return func() { <-s.sem }, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func fileExtFromURL(raw string) string {
parsed, err := url.Parse(raw)
if err != nil {
return ""
}
ext := path.Ext(parsed.Path)
return strings.ToLower(ext)
}
func fileExtFromContentType(ct string) string {
if ct == "" {
return ""
}
if exts, err := mime.ExtensionsByType(ct); err == nil && len(exts) > 0 {
return strings.ToLower(exts[0])
}
return ""
}

View File

@@ -0,0 +1,69 @@
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSoraMediaStorage_StoreFromURLs(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("data"))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
MaxConcurrentDownloads: 1,
},
},
}
storage := NewSoraMediaStorage(cfg)
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
require.NoError(t, err)
require.Len(t, urls, 1)
require.True(t, strings.HasPrefix(urls[0], "/image/"))
require.True(t, strings.HasSuffix(urls[0], ".png"))
localPath := filepath.Join(tmpDir, filepath.FromSlash(strings.TrimPrefix(urls[0], "/")))
require.FileExists(t, localPath)
}
func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
FallbackToUpstream: true,
},
},
}
storage := NewSoraMediaStorage(cfg)
url := server.URL + "/broken.png"
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{url})
require.NoError(t, err)
require.Equal(t, []string{url}, urls)
}

View File

@@ -0,0 +1,252 @@
package service
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
// SoraModelConfig Sora 模型配置
type SoraModelConfig struct {
Type string
Width int
Height int
Orientation string
Frames int
Model string
Size string
RequirePro bool
}
var soraModelConfigs = map[string]SoraModelConfig{
"gpt-image": {
Type: "image",
Width: 360,
Height: 360,
},
"gpt-image-landscape": {
Type: "image",
Width: 540,
Height: 360,
},
"gpt-image-portrait": {
Type: "image",
Width: 360,
Height: 540,
},
"sora2-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_8",
Size: "small",
},
"sora2-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_8",
Size: "small",
},
"sora2-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_8",
Size: "small",
},
"sora2-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_8",
Size: "small",
},
"sora2-landscape-25s": {
Type: "video",
Orientation: "landscape",
Frames: 750,
Model: "sy_8",
Size: "small",
RequirePro: true,
},
"sora2-portrait-25s": {
Type: "video",
Orientation: "portrait",
Frames: 750,
Model: "sy_8",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-landscape-25s": {
Type: "video",
Orientation: "landscape",
Frames: 750,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-portrait-25s": {
Type: "video",
Orientation: "portrait",
Frames: 750,
Model: "sy_ore",
Size: "small",
RequirePro: true,
},
"sora2pro-hd-landscape-10s": {
Type: "video",
Orientation: "landscape",
Frames: 300,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-portrait-10s": {
Type: "video",
Orientation: "portrait",
Frames: 300,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-landscape-15s": {
Type: "video",
Orientation: "landscape",
Frames: 450,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"sora2pro-hd-portrait-15s": {
Type: "video",
Orientation: "portrait",
Frames: 450,
Model: "sy_ore",
Size: "large",
RequirePro: true,
},
"prompt-enhance-short-10s": {
Type: "prompt_enhance",
},
"prompt-enhance-short-15s": {
Type: "prompt_enhance",
},
"prompt-enhance-short-20s": {
Type: "prompt_enhance",
},
"prompt-enhance-medium-10s": {
Type: "prompt_enhance",
},
"prompt-enhance-medium-15s": {
Type: "prompt_enhance",
},
"prompt-enhance-medium-20s": {
Type: "prompt_enhance",
},
"prompt-enhance-long-10s": {
Type: "prompt_enhance",
},
"prompt-enhance-long-15s": {
Type: "prompt_enhance",
},
"prompt-enhance-long-20s": {
Type: "prompt_enhance",
},
}
var soraModelIDs = []string{
"gpt-image",
"gpt-image-landscape",
"gpt-image-portrait",
"sora2-landscape-10s",
"sora2-portrait-10s",
"sora2-landscape-15s",
"sora2-portrait-15s",
"sora2-landscape-25s",
"sora2-portrait-25s",
"sora2pro-landscape-10s",
"sora2pro-portrait-10s",
"sora2pro-landscape-15s",
"sora2pro-portrait-15s",
"sora2pro-landscape-25s",
"sora2pro-portrait-25s",
"sora2pro-hd-landscape-10s",
"sora2pro-hd-portrait-10s",
"sora2pro-hd-landscape-15s",
"sora2pro-hd-portrait-15s",
"prompt-enhance-short-10s",
"prompt-enhance-short-15s",
"prompt-enhance-short-20s",
"prompt-enhance-medium-10s",
"prompt-enhance-medium-15s",
"prompt-enhance-medium-20s",
"prompt-enhance-long-10s",
"prompt-enhance-long-15s",
"prompt-enhance-long-20s",
}
// GetSoraModelConfig 返回 Sora 模型配置
func GetSoraModelConfig(model string) (SoraModelConfig, bool) {
key := strings.ToLower(strings.TrimSpace(model))
cfg, ok := soraModelConfigs[key]
return cfg, ok
}
// DefaultSoraModels returns the default Sora model list.
func DefaultSoraModels(cfg *config.Config) []openai.Model {
models := make([]openai.Model, 0, len(soraModelIDs))
for _, id := range soraModelIDs {
models = append(models, openai.Model{
ID: id,
Object: "model",
OwnedBy: "openai",
Type: "model",
DisplayName: id,
})
}
if cfg != nil && cfg.Gateway.SoraModelFilters.HidePromptEnhance {
filtered := models[:0]
for _, model := range models {
if strings.HasPrefix(strings.ToLower(model.ID), "prompt-enhance") {
continue
}
filtered = append(filtered, model)
}
models = filtered
}
return models
}

View File

@@ -63,16 +63,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
}
}
// SetSoraSyncService 设置 Sora2API 同步服务
// 需要在 Start() 之前调用
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
for _, refresher := range s.refreshers {
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
openaiRefresher.SetSoraSyncService(svc)
}
}
}
// Start 启动后台刷新服务
func (s *TokenRefreshService) Start() {
if !s.cfg.Enabled {

View File

@@ -86,7 +86,6 @@ type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
soraSyncService *Sora2APISyncService // Sora2API 同步服务
}
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
@@ -104,11 +103,6 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
r.soraAccountRepo = repo
}
// SetSoraSyncService 设置 Sora2API 同步服务
func (r *OpenAITokenRefresher) SetSoraSyncService(svc *Sora2APISyncService) {
r.soraSyncService = svc
}
// CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
@@ -151,17 +145,6 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
}
// 如果是 Sora 平台账号,同步到 sora2api不阻塞主流程
if account.Platform == PlatformSora && r.soraSyncService != nil {
syncAccount := *account
syncAccount.Credentials = newCredentials
go func() {
if err := r.soraSyncService.SyncAccount(context.Background(), &syncAccount); err != nil {
log.Printf("[TokenSync] 同步 Sora2API 失败: account_id=%d err=%v", syncAccount.ID, err)
}
}()
}
return newCredentials, nil
}
@@ -218,13 +201,6 @@ func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, opena
}
}
// 2.3 同步到 sora2api如果配置
if r.soraSyncService != nil {
if err := r.soraSyncService.SyncAccount(ctx, &soraAccount); err != nil {
log.Printf("[TokenSync] 同步 sora2api 失败: account_id=%d err=%v", soraAccount.ID, err)
}
}
log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v",
soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil)
}

View File

@@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
func ProvideTokenRefreshService(
accountRepo AccountRepository,
soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步
soraSyncService *Sora2APISyncService,
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
@@ -51,7 +50,6 @@ func ProvideTokenRefreshService(
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc.SetSoraAccountRepo(soraAccountRepo)
svc.SetSoraSyncService(soraSyncService)
svc.Start()
return svc
}
@@ -187,6 +185,18 @@ func ProvideOpsCleanupService(
return svc
}
// ProvideSoraMediaStorage 初始化 Sora 媒体存储
func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
return NewSoraMediaStorage(cfg)
}
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
svc := NewSoraMediaCleanupService(storage, cfg)
svc.Start()
return svc
}
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
func ProvideOpsScheduledReportService(
opsService *OpsService,
@@ -226,6 +236,10 @@ var ProviderSet = wire.NewSet(
NewBillingCacheService,
NewAdminService,
NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,
NewSoraDirectClient,
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
NewSoraGatewayService,
NewOpenAIGatewayService,
NewOAuthService,