feat(Sora): 完成Sora网关接入与媒体能力
新增 Sora 网关路由、账号调度与同步服务\n补充媒体代理与签名 URL、模型列表动态拉取\n完善计费配置、前端支持与相关测试
This commit is contained in:
@@ -102,11 +102,16 @@ type CreateGroupInput struct {
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
// Sora 按次计费配置
|
||||
SoraImagePrice360 *float64
|
||||
SoraImagePrice540 *float64
|
||||
SoraVideoPricePerRequest *float64
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
@@ -124,11 +129,16 @@ type UpdateGroupInput struct {
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
// Sora 按次计费配置
|
||||
SoraImagePrice360 *float64
|
||||
SoraImagePrice540 *float64
|
||||
SoraVideoPricePerRequest *float64
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
@@ -273,6 +283,7 @@ type adminServiceImpl struct {
|
||||
groupRepo GroupRepository
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
|
||||
soraSyncService *Sora2APISyncService // Sora2API 同步服务
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
@@ -288,6 +299,7 @@ func NewAdminService(
|
||||
groupRepo GroupRepository,
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
soraSyncService *Sora2APISyncService,
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
@@ -301,6 +313,7 @@ func NewAdminService(
|
||||
groupRepo: groupRepo,
|
||||
accountRepo: accountRepo,
|
||||
soraAccountRepo: soraAccountRepo,
|
||||
soraSyncService: soraSyncService,
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
@@ -567,6 +580,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||
soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
|
||||
soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
|
||||
soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
|
||||
soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
|
||||
|
||||
// 校验降级分组
|
||||
if input.FallbackGroupID != nil {
|
||||
@@ -576,22 +593,26 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
ModelRouting: input.ModelRouting,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
SoraImagePrice360: soraImagePrice360,
|
||||
SoraImagePrice540: soraImagePrice540,
|
||||
SoraVideoPricePerRequest: soraVideoPrice,
|
||||
SoraVideoPricePerRequestHD: soraVideoPriceHD,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
ModelRouting: input.ModelRouting,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -702,6 +723,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.ImagePrice4K != nil {
|
||||
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
||||
}
|
||||
if input.SoraImagePrice360 != nil {
|
||||
group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
|
||||
}
|
||||
if input.SoraImagePrice540 != nil {
|
||||
group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
|
||||
}
|
||||
if input.SoraVideoPricePerRequest != nil {
|
||||
group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
|
||||
}
|
||||
if input.SoraVideoPricePerRequestHD != nil {
|
||||
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
|
||||
}
|
||||
|
||||
// Claude Code 客户端限制
|
||||
if input.ClaudeCodeOnly != nil {
|
||||
@@ -884,6 +917,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
}
|
||||
|
||||
// 同步到 sora2api(异步,不阻塞创建)
|
||||
s.syncSoraAccountAsync(account)
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
@@ -974,7 +1010,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
|
||||
// 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象)
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
updated, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.syncSoraAccountAsync(updated)
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// BulkUpdateAccounts updates multiple accounts in one request.
|
||||
@@ -990,16 +1031,23 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Preload account platforms for mixed channel risk checks if group bindings are requested.
|
||||
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
|
||||
needSoraSync := s != nil && s.soraSyncService != nil
|
||||
|
||||
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
|
||||
platformByID := map[int64]string{}
|
||||
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
|
||||
if needMixedChannelCheck || needSoraSync {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, account := range accounts {
|
||||
if account != nil {
|
||||
platformByID[account.ID] = account.Platform
|
||||
if needMixedChannelCheck {
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err)
|
||||
} else {
|
||||
for _, account := range accounts {
|
||||
if account != nil {
|
||||
platformByID[account.ID] = account.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1086,13 +1134,46 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
result.Success++
|
||||
result.SuccessIDs = append(result.SuccessIDs, accountID)
|
||||
result.Results = append(result.Results, entry)
|
||||
|
||||
// 批量更新后同步 sora2api
|
||||
if needSoraSync {
|
||||
platform := platformByID[accountID]
|
||||
if platform == "" {
|
||||
updated, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
|
||||
continue
|
||||
}
|
||||
if updated.Platform == PlatformSora {
|
||||
s.syncSoraAccountAsync(updated)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if platform == PlatformSora {
|
||||
updated, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
|
||||
continue
|
||||
}
|
||||
s.syncSoraAccountAsync(updated)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.Delete(ctx, id)
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
s.deleteSoraAccountAsync(account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
|
||||
@@ -1125,7 +1206,46 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
|
||||
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
updated, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.syncSoraAccountAsync(updated)
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) {
|
||||
if s == nil || s.soraSyncService == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if account.Platform != PlatformSora {
|
||||
return
|
||||
}
|
||||
syncAccount := *account
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil {
|
||||
log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) {
|
||||
if s == nil || s.soraSyncService == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if account.Platform != PlatformSora {
|
||||
return
|
||||
}
|
||||
syncAccount := *account
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil {
|
||||
log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Proxy management implementations
|
||||
|
||||
@@ -15,6 +15,13 @@ type accountRepoStubForBulkUpdate struct {
|
||||
bulkUpdateErr error
|
||||
bulkUpdateIDs []int64
|
||||
bindGroupErrByID map[int64]error
|
||||
getByIDsAccounts []*Account
|
||||
getByIDsErr error
|
||||
getByIDsCalled bool
|
||||
getByIDsIDs []int64
|
||||
getByIDAccounts map[int64]*Account
|
||||
getByIDErrByID map[int64]error
|
||||
getByIDCalled []int64
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
@@ -32,6 +39,26 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
||||
s.getByIDsCalled = true
|
||||
s.getByIDsIDs = append([]int64{}, ids...)
|
||||
if s.getByIDsErr != nil {
|
||||
return nil, s.getByIDsErr
|
||||
}
|
||||
return s.getByIDsAccounts, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) {
|
||||
s.getByIDCalled = append(s.getByIDCalled, id)
|
||||
if err, ok := s.getByIDErrByID[id]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if account, ok := s.getByIDAccounts[id]; ok {
|
||||
return account, nil
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
@@ -78,3 +105,31 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。
|
||||
func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
getByIDsAccounts: []*Account{
|
||||
{ID: 1, Platform: PlatformSora},
|
||||
},
|
||||
getByIDAccounts: map[int64]*Account{
|
||||
1: {ID: 1, Platform: PlatformSora},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{
|
||||
accountRepo: repo,
|
||||
soraSyncService: &Sora2APISyncService{},
|
||||
}
|
||||
|
||||
schedulable := true
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1},
|
||||
Schedulable: &schedulable,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, result.Success)
|
||||
require.True(t, repo.getByIDsCalled)
|
||||
require.ElementsMatch(t, []int64{1}, repo.getByIDCalled)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,10 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
|
||||
|
||||
@@ -235,6 +235,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
@@ -279,6 +283,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
|
||||
@@ -303,6 +303,14 @@ type ImagePriceConfig struct {
|
||||
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
|
||||
}
|
||||
|
||||
// SoraPriceConfig Sora 按次计费配置
|
||||
type SoraPriceConfig struct {
|
||||
ImagePrice360 *float64
|
||||
ImagePrice540 *float64
|
||||
VideoPricePerRequest *float64
|
||||
VideoPricePerRequestHD *float64
|
||||
}
|
||||
|
||||
// CalculateImageCost 计算图片生成费用
|
||||
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
||||
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
||||
@@ -332,6 +340,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateSoraImageCost 计算 Sora 图片按次费用
|
||||
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
if imageCount <= 0 {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
|
||||
unitPrice := 0.0
|
||||
if groupConfig != nil {
|
||||
switch imageSize {
|
||||
case "540":
|
||||
if groupConfig.ImagePrice540 != nil {
|
||||
unitPrice = *groupConfig.ImagePrice540
|
||||
}
|
||||
default:
|
||||
if groupConfig.ImagePrice360 != nil {
|
||||
unitPrice = *groupConfig.ImagePrice360
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateSoraVideoCost 计算 Sora 视频按次费用
|
||||
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
unitPrice := 0.0
|
||||
if groupConfig != nil {
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "sora2pro-hd") {
|
||||
if groupConfig.VideoPricePerRequestHD != nil {
|
||||
unitPrice = *groupConfig.VideoPricePerRequestHD
|
||||
}
|
||||
}
|
||||
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
|
||||
unitPrice = *groupConfig.VideoPricePerRequest
|
||||
}
|
||||
}
|
||||
|
||||
totalCost := unitPrice
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// getImageUnitPrice 获取图片单价
|
||||
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
||||
// 优先使用分组配置的价格
|
||||
|
||||
@@ -184,6 +184,10 @@ type ForwardResult struct {
|
||||
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
||||
|
||||
// Sora 媒体字段
|
||||
MediaType string // image / video / prompt
|
||||
MediaURL string // 生成后的媒体地址(可选)
|
||||
}
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
@@ -3461,7 +3465,22 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
var cost *CostBreakdown
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.ImageCount > 0 {
|
||||
if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" {
|
||||
var soraConfig *SoraPriceConfig
|
||||
if apiKey.Group != nil {
|
||||
soraConfig = &SoraPriceConfig{
|
||||
ImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
ImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
}
|
||||
}
|
||||
if result.MediaType == "image" {
|
||||
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||
} else {
|
||||
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
|
||||
}
|
||||
} else if result.ImageCount > 0 {
|
||||
// 图片生成计费
|
||||
var groupConfig *ImagePriceConfig
|
||||
if apiKey.Group != nil {
|
||||
@@ -3501,6 +3520,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if result.ImageSize != "" {
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
var mediaType *string
|
||||
if strings.TrimSpace(result.MediaType) != "" {
|
||||
mediaType = &result.MediaType
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
@@ -3526,6 +3549,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
MediaType: mediaType,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,12 @@ type Group struct {
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
|
||||
// Sora 按次计费配置(阶段 1)
|
||||
SoraImagePrice360 *float64
|
||||
SoraImagePrice540 *float64
|
||||
SoraVideoPricePerRequest *float64
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
@@ -83,6 +89,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
}
|
||||
}
|
||||
|
||||
// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
|
||||
func (g *Group) GetSoraImagePrice(imageSize string) *float64 {
|
||||
switch imageSize {
|
||||
case "360":
|
||||
return g.SoraImagePrice360
|
||||
case "540":
|
||||
return g.SoraImagePrice540
|
||||
default:
|
||||
return g.SoraImagePrice360
|
||||
}
|
||||
}
|
||||
|
||||
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||
func IsGroupContextValid(group *Group) bool {
|
||||
if group == nil {
|
||||
|
||||
@@ -41,8 +41,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai oauth account")
|
||||
if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai/sora oauth account")
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
@@ -157,7 +157,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
@@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Contains(t, err.Error(), "not an openai/sora oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
@@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Contains(t, err.Error(), "not an openai/sora oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
|
||||
355
backend/internal/service/sora2api_service.go
Normal file
355
backend/internal/service/sora2api_service.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// Sora2APIModel represents a model entry returned by sora2api.
|
||||
type Sora2APIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
OwnedBy string `json:"owned_by,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// Sora2APIModelList represents /v1/models response.
|
||||
type Sora2APIModelList struct {
|
||||
Object string `json:"object"`
|
||||
Data []Sora2APIModel `json:"data"`
|
||||
}
|
||||
|
||||
// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem.
|
||||
type Sora2APIImportTokenItem struct {
|
||||
Email string `json:"email"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
SessionToken string `json:"session_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
Remark string `json:"remark,omitempty"`
|
||||
IsActive bool `json:"is_active"`
|
||||
ImageEnabled bool `json:"image_enabled"`
|
||||
VideoEnabled bool `json:"video_enabled"`
|
||||
ImageConcurrency int `json:"image_concurrency"`
|
||||
VideoConcurrency int `json:"video_concurrency"`
|
||||
}
|
||||
|
||||
// Sora2APIToken represents minimal fields for admin list.
|
||||
type Sora2APIToken struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
// Sora2APIService provides access to sora2api endpoints.
|
||||
type Sora2APIService struct {
|
||||
cfg *config.Config
|
||||
|
||||
baseURL string
|
||||
apiKey string
|
||||
adminUsername string
|
||||
adminPassword string
|
||||
adminTokenTTL time.Duration
|
||||
adminTimeout time.Duration
|
||||
tokenImportMode string
|
||||
|
||||
client *http.Client
|
||||
adminClient *http.Client
|
||||
|
||||
adminToken string
|
||||
adminTokenAt time.Time
|
||||
adminMu sync.Mutex
|
||||
|
||||
modelCache []Sora2APIModel
|
||||
modelCacheAt time.Time
|
||||
modelMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSora2APIService(cfg *config.Config) *Sora2APIService {
|
||||
if cfg == nil {
|
||||
return &Sora2APIService{}
|
||||
}
|
||||
adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second
|
||||
if adminTTL <= 0 {
|
||||
adminTTL = 15 * time.Minute
|
||||
}
|
||||
adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second
|
||||
if adminTimeout <= 0 {
|
||||
adminTimeout = 10 * time.Second
|
||||
}
|
||||
return &Sora2APIService{
|
||||
cfg: cfg,
|
||||
baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"),
|
||||
apiKey: strings.TrimSpace(cfg.Sora2API.APIKey),
|
||||
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
|
||||
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
|
||||
adminTokenTTL: adminTTL,
|
||||
adminTimeout: adminTimeout,
|
||||
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
|
||||
client: &http.Client{},
|
||||
adminClient: &http.Client{Timeout: adminTimeout},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) Enabled() bool {
|
||||
return s != nil && s.baseURL != "" && s.apiKey != ""
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) AdminEnabled() bool {
|
||||
return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != ""
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) buildURL(path string) string {
|
||||
if s.baseURL == "" {
|
||||
return path
|
||||
}
|
||||
if strings.HasPrefix(path, "/") {
|
||||
return s.baseURL + path
|
||||
}
|
||||
return s.baseURL + "/" + path
|
||||
}
|
||||
|
||||
// BuildURL 返回完整的 sora2api URL(用于代理媒体)
|
||||
func (s *Sora2APIService) BuildURL(path string) string {
|
||||
return s.buildURL(path)
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) {
|
||||
if !s.Enabled() {
|
||||
return nil, errors.New("sora2api not configured")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) {
|
||||
if !s.Enabled() {
|
||||
return nil, errors.New("sora2api not configured")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return s.cachedModelsOnError(err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode))
|
||||
}
|
||||
|
||||
var payload Sora2APIModelList
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return s.cachedModelsOnError(err)
|
||||
}
|
||||
models := payload.Data
|
||||
if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance {
|
||||
filtered := make([]Sora2APIModel, 0, len(models))
|
||||
for _, m := range models {
|
||||
if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
models = filtered
|
||||
}
|
||||
|
||||
s.modelMu.Lock()
|
||||
s.modelCache = models
|
||||
s.modelCacheAt = time.Now()
|
||||
s.modelMu.Unlock()
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) {
|
||||
s.modelMu.RLock()
|
||||
cached := append([]Sora2APIModel(nil), s.modelCache...)
|
||||
s.modelMu.RUnlock()
|
||||
if len(cached) > 0 {
|
||||
log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err)
|
||||
return cached, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error {
|
||||
if !s.AdminEnabled() {
|
||||
return errors.New("sora2api admin not configured")
|
||||
}
|
||||
mode := s.tokenImportMode
|
||||
if mode == "" {
|
||||
mode = "at"
|
||||
}
|
||||
payload := map[string]any{
|
||||
"tokens": items,
|
||||
"mode": mode,
|
||||
}
|
||||
_, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) {
|
||||
if !s.AdminEnabled() {
|
||||
return nil, errors.New("sora2api admin not configured")
|
||||
}
|
||||
var tokens []Sora2APIToken
|
||||
_, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens)
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error {
|
||||
if !s.AdminEnabled() {
|
||||
return errors.New("sora2api admin not configured")
|
||||
}
|
||||
path := fmt.Sprintf("/api/tokens/%d/disable", tokenID)
|
||||
_, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error {
|
||||
if !s.AdminEnabled() {
|
||||
return errors.New("sora2api admin not configured")
|
||||
}
|
||||
path := fmt.Sprintf("/api/tokens/%d", tokenID)
|
||||
_, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) {
|
||||
if !s.AdminEnabled() {
|
||||
return nil, errors.New("sora2api admin not configured")
|
||||
}
|
||||
token, err := s.getAdminToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out)
|
||||
if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized {
|
||||
return resp, nil
|
||||
}
|
||||
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
|
||||
s.invalidateAdminToken()
|
||||
token, err = s.getAdminToken(ctx)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
return s.doAdminRequestWithToken(ctx, method, path, token, body, out)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) {
|
||||
var reader *bytes.Reader
|
||||
if body != nil {
|
||||
buf, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reader = bytes.NewReader(buf)
|
||||
} else {
|
||||
reader = bytes.NewReader(nil)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
resp, err := s.adminClient.Do(req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode)
|
||||
}
|
||||
if out != nil {
|
||||
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) {
|
||||
s.adminMu.Lock()
|
||||
defer s.adminMu.Unlock()
|
||||
|
||||
if s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL {
|
||||
return s.adminToken, nil
|
||||
}
|
||||
|
||||
if !s.AdminEnabled() {
|
||||
return "", errors.New("sora2api admin not configured")
|
||||
}
|
||||
|
||||
payload := map[string]string{
|
||||
"username": s.adminUsername,
|
||||
"password": s.adminPassword,
|
||||
}
|
||||
buf, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := s.adminClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode)
|
||||
}
|
||||
var result struct {
|
||||
Success bool `json:"success"`
|
||||
Token string `json:"token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !result.Success || result.Token == "" {
|
||||
if result.Message == "" {
|
||||
result.Message = "sora2api login failed"
|
||||
}
|
||||
return "", errors.New(result.Message)
|
||||
}
|
||||
s.adminToken = result.Token
|
||||
s.adminTokenAt = time.Now()
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
func (s *Sora2APIService) invalidateAdminToken() {
|
||||
s.adminMu.Lock()
|
||||
defer s.adminMu.Unlock()
|
||||
s.adminToken = ""
|
||||
s.adminTokenAt = time.Time{}
|
||||
}
|
||||
255
backend/internal/service/sora2api_sync_service.go
Normal file
255
backend/internal/service/sora2api_sync_service.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池
|
||||
type Sora2APISyncService struct {
|
||||
sora2api *Sora2APIService
|
||||
accountRepo AccountRepository
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService {
|
||||
return &Sora2APISyncService{
|
||||
sora2api: sora2api,
|
||||
accountRepo: accountRepo,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sora2APISyncService) Enabled() bool {
|
||||
return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled()
|
||||
}
|
||||
|
||||
// SyncAccount 将 Sora 账号同步到 sora2api(导入或更新)
|
||||
func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error {
|
||||
if !s.Enabled() {
|
||||
return nil
|
||||
}
|
||||
if account == nil || account.Platform != PlatformSora {
|
||||
return nil
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if accessToken == "" {
|
||||
return errors.New("sora 账号缺少 access_token")
|
||||
}
|
||||
|
||||
email, updated := s.resolveAccountEmail(ctx, account)
|
||||
if email == "" {
|
||||
return errors.New("无法解析 Sora 账号邮箱")
|
||||
}
|
||||
if updated && s.accountRepo != nil {
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
item := Sora2APIImportTokenItem{
|
||||
Email: email,
|
||||
AccessToken: accessToken,
|
||||
SessionToken: strings.TrimSpace(account.GetCredential("session_token")),
|
||||
RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")),
|
||||
ClientID: strings.TrimSpace(account.GetCredential("client_id")),
|
||||
Remark: account.Name,
|
||||
IsActive: account.IsActive() && account.Schedulable,
|
||||
ImageEnabled: true,
|
||||
VideoEnabled: true,
|
||||
ImageConcurrency: normalizeSoraConcurrency(account.Concurrency),
|
||||
VideoConcurrency: normalizeSoraConcurrency(account.Concurrency),
|
||||
}
|
||||
|
||||
if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableAccount 禁用 sora2api 中的 token
|
||||
func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error {
|
||||
if !s.Enabled() {
|
||||
return nil
|
||||
}
|
||||
if account == nil || account.Platform != PlatformSora {
|
||||
return nil
|
||||
}
|
||||
tokenID, err := s.resolveTokenID(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.sora2api.DisableToken(ctx, tokenID)
|
||||
}
|
||||
|
||||
// DeleteAccount 删除 sora2api 中的 token
|
||||
func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error {
|
||||
if !s.Enabled() {
|
||||
return nil
|
||||
}
|
||||
if account == nil || account.Platform != PlatformSora {
|
||||
return nil
|
||||
}
|
||||
tokenID, err := s.resolveTokenID(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.sora2api.DeleteToken(ctx, tokenID)
|
||||
}
|
||||
|
||||
func normalizeSoraConcurrency(value int) int {
|
||||
if value <= 0 {
|
||||
return -1
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) {
|
||||
if account == nil {
|
||||
return "", false
|
||||
}
|
||||
if email := strings.TrimSpace(account.GetCredential("email")); email != "" {
|
||||
return email, false
|
||||
}
|
||||
if email := strings.TrimSpace(account.GetExtraString("email")); email != "" {
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = map[string]any{}
|
||||
}
|
||||
account.Credentials["email"] = email
|
||||
return email, true
|
||||
}
|
||||
if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" {
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = map[string]any{}
|
||||
}
|
||||
account.Credentials["email"] = email
|
||||
return email, true
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if accessToken != "" {
|
||||
if email := extractEmailFromAccessToken(accessToken); email != "" {
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = map[string]any{}
|
||||
}
|
||||
account.Credentials["email"] = email
|
||||
return email, true
|
||||
}
|
||||
if email := s.fetchEmailFromSora(ctx, accessToken); email != "" {
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = map[string]any{}
|
||||
}
|
||||
account.Credentials["email"] = email
|
||||
return email, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) (int64, error) {
|
||||
if account == nil {
|
||||
return 0, errors.New("account is nil")
|
||||
}
|
||||
|
||||
if account.Extra != nil {
|
||||
if v, ok := account.Extra["sora2api_token_id"]; ok {
|
||||
if id, ok := v.(float64); ok && id > 0 {
|
||||
return int64(id), nil
|
||||
}
|
||||
if id, ok := v.(int64); ok && id > 0 {
|
||||
return id, nil
|
||||
}
|
||||
if id, ok := v.(int); ok && id > 0 {
|
||||
return int64(id), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(account.GetCredential("email"))
|
||||
if email == "" {
|
||||
email, _ = s.resolveAccountEmail(ctx, account)
|
||||
}
|
||||
if email == "" {
|
||||
return 0, errors.New("sora2api token email missing")
|
||||
}
|
||||
|
||||
tokenID, err := s.findTokenIDByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return tokenID, nil
|
||||
}
|
||||
|
||||
func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) {
|
||||
if !s.Enabled() {
|
||||
return 0, errors.New("sora2api admin not configured")
|
||||
}
|
||||
tokens, err := s.sora2api.ListTokens(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, token := range tokens {
|
||||
if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) {
|
||||
return token.ID, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("sora2api token not found for email: %s", email)
|
||||
}
|
||||
|
||||
func extractEmailFromAccessToken(accessToken string) string {
|
||||
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
|
||||
claims := jwt.MapClaims{}
|
||||
_, _, err := parser.ParseUnverified(accessToken, claims)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" {
|
||||
return email
|
||||
}
|
||||
if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok {
|
||||
if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" {
|
||||
return email
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string {
|
||||
if s.httpClient == nil {
|
||||
return ""
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return ""
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" {
|
||||
return email
|
||||
}
|
||||
return ""
|
||||
}
|
||||
660
backend/internal/service/sora_gateway_service.go
Normal file
660
backend/internal/service/sora_gateway_service.go
Normal file
@@ -0,0 +1,660 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||
|
||||
var soraImageSizeMap = map[string]string{
|
||||
"gpt-image": "360",
|
||||
"gpt-image-landscape": "540",
|
||||
"gpt-image-portrait": "540",
|
||||
}
|
||||
|
||||
type soraStreamingResult struct {
|
||||
content string
|
||||
mediaType string
|
||||
mediaURLs []string
|
||||
imageCount int
|
||||
imageSize string
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
// SoraGatewayService handles forwarding requests to sora2api.
|
||||
type SoraGatewayService struct {
|
||||
sora2api *Sora2APIService
|
||||
httpUpstream HTTPUpstream
|
||||
rateLimitService *RateLimitService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewSoraGatewayService(
|
||||
sora2api *Sora2APIService,
|
||||
httpUpstream HTTPUpstream,
|
||||
rateLimitService *RateLimitService,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayService {
|
||||
return &SoraGatewayService{
|
||||
sora2api: sora2api,
|
||||
httpUpstream: httpUpstream,
|
||||
rateLimitService: rateLimitService,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
if s.sora2api == nil || !s.sora2api.Enabled() {
|
||||
if c != nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "sora2api 未配置",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, errors.New("sora2api not configured")
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel && mappedModel != "" {
|
||||
reqBody["model"] = mappedModel
|
||||
if updated, err := json.Marshal(reqBody); err == nil {
|
||||
body = updated
|
||||
}
|
||||
}
|
||||
|
||||
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c != nil {
|
||||
if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
|
||||
upstreamReq.Header.Set("User-Agent", ua)
|
||||
}
|
||||
}
|
||||
if reqStream {
|
||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
if s.httpUpstream != nil {
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(upstreamReq)
|
||||
}
|
||||
if err != nil {
|
||||
s.setUpstreamRequestError(c, account, err)
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
|
||||
}
|
||||
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: streamResult.firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: streamResult.mediaType,
|
||||
MediaURL: firstMediaURL(streamResult.mediaURLs),
|
||||
ImageCount: streamResult.imageCount,
|
||||
ImageSize: streamResult.imageSize,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
||||
if s == nil || s.cfg == nil {
|
||||
return ctx, nil
|
||||
}
|
||||
timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds
|
||||
if stream {
|
||||
timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds
|
||||
}
|
||||
if timeoutSeconds <= 0 {
|
||||
return ctx, nil
|
||||
}
|
||||
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
if c != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 402, 403, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
if s.rateLimitService == nil || account == nil || resp == nil {
|
||||
return
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" {
|
||||
upstreamMsg = msg
|
||||
}
|
||||
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
if c != nil {
|
||||
responsePayload := s.buildErrorPayload(respBody, upstreamMsg)
|
||||
c.JSON(resp.StatusCode, responsePayload)
|
||||
}
|
||||
if upstreamMsg == "" {
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any {
|
||||
if len(respBody) > 0 {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(respBody, &payload); err == nil {
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if overrideMessage != "" {
|
||||
errObj["message"] = overrideMessage
|
||||
}
|
||||
payload["error"] = errObj
|
||||
return payload
|
||||
}
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "upstream_error",
|
||||
"message": overrideMessage,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) {
|
||||
if resp == nil {
|
||||
return nil, errors.New("empty response")
|
||||
}
|
||||
|
||||
if clientStream {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
contentBuilder := strings.Builder{}
|
||||
var firstTokenMs *int
|
||||
var upstreamError error
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
sendLine := func(line string) error {
|
||||
if !clientStream {
|
||||
return nil
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if soraSSEDataRe.MatchString(line) {
|
||||
data := soraSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "[DONE]" {
|
||||
if err := sendLine("data: [DONE]"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
break
|
||||
}
|
||||
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel)
|
||||
if errEvent != nil && upstreamError == nil {
|
||||
upstreamError = errEvent
|
||||
}
|
||||
if contentDelta != "" {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
contentBuilder.WriteString(contentDelta)
|
||||
}
|
||||
if err := sendLine(updatedLine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := sendLine(line); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
if clientStream {
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||
}
|
||||
if clientStream {
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n")
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := contentBuilder.String()
|
||||
mediaType, mediaURLs := s.extractSoraMedia(content)
|
||||
if mediaType == "" && isSoraPromptEnhanceModel(originalModel) {
|
||||
mediaType = "prompt"
|
||||
}
|
||||
imageSize := ""
|
||||
imageCount := 0
|
||||
if mediaType == "image" {
|
||||
imageSize = soraImageSizeFromModel(originalModel)
|
||||
imageCount = len(mediaURLs)
|
||||
}
|
||||
|
||||
if upstreamError != nil && !clientStream {
|
||||
if c != nil {
|
||||
c.JSON(http.StatusBadGateway, map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "upstream_error",
|
||||
"message": upstreamError.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, upstreamError
|
||||
}
|
||||
|
||||
if !clientStream {
|
||||
response := buildSoraNonStreamResponse(content, originalModel)
|
||||
if len(mediaURLs) > 0 {
|
||||
response["media_url"] = mediaURLs[0]
|
||||
if len(mediaURLs) > 1 {
|
||||
response["media_urls"] = mediaURLs
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
return &soraStreamingResult{
|
||||
content: content,
|
||||
mediaType: mediaType,
|
||||
mediaURLs: mediaURLs,
|
||||
imageCount: imageCount,
|
||||
imageSize: imageSize,
|
||||
firstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) {
|
||||
if strings.TrimSpace(data) == "" {
|
||||
return "data: ", "", nil
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
return "data: " + data, "", nil
|
||||
}
|
||||
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
return "data: " + data, "", errors.New(msg)
|
||||
}
|
||||
}
|
||||
|
||||
if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" {
|
||||
payload["model"] = originalModel
|
||||
}
|
||||
|
||||
contentDelta, updated := extractSoraContent(payload)
|
||||
if updated {
|
||||
rewritten := s.rewriteSoraContent(contentDelta)
|
||||
if rewritten != contentDelta {
|
||||
applySoraContent(payload, rewritten)
|
||||
contentDelta = rewritten
|
||||
}
|
||||
}
|
||||
|
||||
updatedData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "data: " + data, contentDelta, nil
|
||||
}
|
||||
return "data: " + string(updatedData), contentDelta, nil
|
||||
}
|
||||
|
||||
func extractSoraContent(payload map[string]any) (string, bool) {
|
||||
choices, ok := payload["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
return "", false
|
||||
}
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||
if content, ok := delta["content"].(string); ok {
|
||||
return content, true
|
||||
}
|
||||
}
|
||||
if message, ok := choice["message"].(map[string]any); ok {
|
||||
if content, ok := message["content"].(string); ok {
|
||||
return content, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func applySoraContent(payload map[string]any, content string) {
|
||||
choices, ok := payload["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
return
|
||||
}
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||
delta["content"] = content
|
||||
choice["delta"] = delta
|
||||
return
|
||||
}
|
||||
if message, ok := choice["message"].(map[string]any); ok {
|
||||
message["content"] = content
|
||||
choice["message"] = message
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
||||
if content == "" {
|
||||
return content
|
||||
}
|
||||
content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string {
|
||||
sub := soraImageMarkdownRe.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
return match
|
||||
}
|
||||
rewritten := s.rewriteSoraURL(sub[1])
|
||||
if rewritten == sub[1] {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, sub[1], rewritten, 1)
|
||||
})
|
||||
content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string {
|
||||
sub := soraVideoHTMLRe.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
return match
|
||||
}
|
||||
rewritten := s.rewriteSoraURL(sub[1])
|
||||
if rewritten == sub[1] {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, sub[1], rewritten, 1)
|
||||
})
|
||||
return content
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return raw
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
path := parsed.Path
|
||||
if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") {
|
||||
return raw
|
||||
}
|
||||
return s.buildSoraMediaURL(path, parsed.RawQuery)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) {
|
||||
if content == "" {
|
||||
return "", nil
|
||||
}
|
||||
if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 {
|
||||
return "video", []string{match[1]}
|
||||
}
|
||||
imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1)
|
||||
if len(imageMatches) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
urls := make([]string, 0, len(imageMatches))
|
||||
for _, match := range imageMatches {
|
||||
if len(match) > 1 {
|
||||
urls = append(urls, match[1])
|
||||
}
|
||||
}
|
||||
return "image", urls
|
||||
}
|
||||
|
||||
func buildSoraNonStreamResponse(content, model string) map[string]any {
|
||||
return map[string]any{
|
||||
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func soraImageSizeFromModel(model string) string {
|
||||
modelLower := strings.ToLower(model)
|
||||
if size, ok := soraImageSizeMap[modelLower]; ok {
|
||||
return size
|
||||
}
|
||||
if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") {
|
||||
return "540"
|
||||
}
|
||||
return "360"
|
||||
}
|
||||
|
||||
func isSoraPromptEnhanceModel(model string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance")
|
||||
}
|
||||
|
||||
func soraProErrorMessage(model, upstreamMsg string) string {
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "sora2pro-hd") {
|
||||
return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号"
|
||||
}
|
||||
if strings.Contains(modelLower, "sora2pro") {
|
||||
return "当前账号无法使用 Sora Pro 模型,请更换模型或账号"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstMediaURL(urls []string) string {
|
||||
if len(urls) == 0 {
|
||||
return ""
|
||||
}
|
||||
return urls[0]
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
}
|
||||
prefix := "/sora/media"
|
||||
values := url.Values{}
|
||||
if rawQuery != "" {
|
||||
if parsed, err := url.ParseQuery(rawQuery); err == nil {
|
||||
values = parsed
|
||||
}
|
||||
}
|
||||
|
||||
signKey := ""
|
||||
ttlSeconds := 0
|
||||
if s != nil && s.cfg != nil {
|
||||
signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey)
|
||||
ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds
|
||||
}
|
||||
values.Del("sig")
|
||||
values.Del("expires")
|
||||
signingQuery := values.Encode()
|
||||
if signKey != "" && ttlSeconds > 0 {
|
||||
expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix()
|
||||
signature := SignSoraMediaURL(path, signingQuery, expires, signKey)
|
||||
if signature != "" {
|
||||
values.Set("expires", strconv.FormatInt(expires, 10))
|
||||
values.Set("sig", signature)
|
||||
prefix = "/sora/media-signed"
|
||||
}
|
||||
}
|
||||
|
||||
encoded := values.Encode()
|
||||
if encoded == "" {
|
||||
return prefix + path
|
||||
}
|
||||
return prefix + path + "?" + encoded
|
||||
}
|
||||
42
backend/internal/service/sora_media_sign.go
Normal file
42
backend/internal/service/sora_media_sign.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SignSoraMediaURL 生成 Sora 媒体临时签名
|
||||
func SignSoraMediaURL(path string, query string, expires int64, key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac.Write([]byte(buildSoraMediaSignPayload(path, query)))
|
||||
mac.Write([]byte("|"))
|
||||
mac.Write([]byte(strconv.FormatInt(expires, 10)))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// VerifySoraMediaURL 校验 Sora 媒体签名
|
||||
func VerifySoraMediaURL(path string, query string, expires int64, signature string, key string) bool {
|
||||
signature = strings.TrimSpace(signature)
|
||||
if signature == "" {
|
||||
return false
|
||||
}
|
||||
expected := SignSoraMediaURL(path, query, expires, key)
|
||||
if expected == "" {
|
||||
return false
|
||||
}
|
||||
return hmac.Equal([]byte(signature), []byte(expected))
|
||||
}
|
||||
|
||||
func buildSoraMediaSignPayload(path string, query string) string {
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return path
|
||||
}
|
||||
return path + "?" + query
|
||||
}
|
||||
34
backend/internal/service/sora_media_sign_test.go
Normal file
34
backend/internal/service/sora_media_sign_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSoraMediaSignVerify(t *testing.T) {
|
||||
key := "test-key"
|
||||
path := "/tmp/abc.png"
|
||||
query := "a=1&b=2"
|
||||
expires := int64(1700000000)
|
||||
|
||||
signature := SignSoraMediaURL(path, query, expires, key)
|
||||
if signature == "" {
|
||||
t.Fatal("签名为空")
|
||||
}
|
||||
if !VerifySoraMediaURL(path, query, expires, signature, key) {
|
||||
t.Fatal("签名校验失败")
|
||||
}
|
||||
if VerifySoraMediaURL(path, "a=1", expires, signature, key) {
|
||||
t.Fatal("签名参数不同仍然通过")
|
||||
}
|
||||
if VerifySoraMediaURL(path, query, expires+1, signature, key) {
|
||||
t.Fatal("签名过期校验未失败")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraMediaSignWithEmptyKey(t *testing.T) {
|
||||
signature := SignSoraMediaURL("/tmp/a.png", "a=1", 1, "")
|
||||
if signature != "" {
|
||||
t.Fatalf("空密钥不应生成签名")
|
||||
}
|
||||
if VerifySoraMediaURL("/tmp/a.png", "a=1", 1, "sig", "") {
|
||||
t.Fatalf("空密钥不应通过校验")
|
||||
}
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
|
||||
// Antigravity 同样可能有两种缓存键
|
||||
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
|
||||
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
|
||||
case PlatformOpenAI:
|
||||
case PlatformOpenAI, PlatformSora:
|
||||
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
|
||||
case PlatformAnthropic:
|
||||
keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account))
|
||||
|
||||
@@ -19,6 +19,7 @@ type TokenRefreshService struct {
|
||||
refreshers []TokenRefresher
|
||||
cfg *config.TokenRefreshConfig
|
||||
cacheInvalidator TokenCacheInvalidator
|
||||
soraSyncService *Sora2APISyncService
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
@@ -65,6 +66,17 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetSoraSyncService 设置 Sora2API 同步服务
|
||||
// 需要在 Start() 之前调用
|
||||
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
|
||||
s.soraSyncService = svc
|
||||
for _, refresher := range s.refreshers {
|
||||
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
||||
openaiRefresher.SetSoraSyncService(svc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
|
||||
@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
||||
soraSyncService *Sora2APISyncService // Sora2API 同步服务
|
||||
}
|
||||
|
||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||
@@ -103,17 +104,22 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||
r.soraAccountRepo = repo
|
||||
}
|
||||
|
||||
// SetSoraSyncService 设置 Sora2API 同步服务
|
||||
func (r *OpenAITokenRefresher) SetSoraSyncService(svc *Sora2APISyncService) {
|
||||
r.soraSyncService = svc
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 openai 平台的 oauth 类型账号
|
||||
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformOpenAI &&
|
||||
return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
|
||||
account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
expiresAt := account.GetOpenAITokenExpiresAt()
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
@@ -145,6 +151,17 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
||||
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
||||
}
|
||||
|
||||
// 如果是 Sora 平台账号,同步到 sora2api(不阻塞主流程)
|
||||
if account.Platform == PlatformSora && r.soraSyncService != nil {
|
||||
syncAccount := *account
|
||||
syncAccount.Credentials = newCredentials
|
||||
go func() {
|
||||
if err := r.soraSyncService.SyncAccount(context.Background(), &syncAccount); err != nil {
|
||||
log.Printf("[TokenSync] 同步 Sora2API 失败: account_id=%d err=%v", syncAccount.ID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
@@ -201,6 +218,13 @@ func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, opena
|
||||
}
|
||||
}
|
||||
|
||||
// 2.3 同步到 sora2api(如果配置)
|
||||
if r.soraSyncService != nil {
|
||||
if err := r.soraSyncService.SyncAccount(ctx, &soraAccount); err != nil {
|
||||
log.Printf("[TokenSync] 同步 sora2api 失败: account_id=%d err=%v", soraAccount.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v",
|
||||
soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil)
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ type UsageLog struct {
|
||||
// 图片生成字段
|
||||
ImageCount int
|
||||
ImageSize *string
|
||||
MediaType *string
|
||||
|
||||
CreatedAt time.Time
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步
|
||||
soraSyncService *Sora2APISyncService,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
@@ -50,6 +51,9 @@ func ProvideTokenRefreshService(
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
||||
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
||||
svc.SetSoraAccountRepo(soraAccountRepo)
|
||||
if soraSyncService != nil {
|
||||
svc.SetSoraSyncService(soraSyncService)
|
||||
}
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@@ -224,6 +228,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewBillingCacheService,
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
NewSoraGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
@@ -237,6 +242,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewAntigravityTokenProvider,
|
||||
NewOpenAITokenProvider,
|
||||
NewClaudeTokenProvider,
|
||||
NewSora2APIService,
|
||||
NewSora2APISyncService,
|
||||
NewAntigravityGatewayService,
|
||||
ProvideRateLimitService,
|
||||
NewAccountUsageService,
|
||||
|
||||
Reference in New Issue
Block a user