fix(sora): 修复流式重写与计费问题

This commit is contained in:
yangjianbo
2026-01-31 21:46:28 +08:00
parent 618a614cbf
commit 78d0ca3775
19 changed files with 325 additions and 210 deletions

View File

@@ -928,6 +928,7 @@ func setDefaults() {
viper.SetDefault("sora2api.admin_token_ttl_seconds", 900) viper.SetDefault("sora2api.admin_token_ttl_seconds", 900)
viper.SetDefault("sora2api.admin_timeout_seconds", 10) viper.SetDefault("sora2api.admin_timeout_seconds", 10)
viper.SetDefault("sora2api.token_import_mode", "at") viper.SetDefault("sora2api.token_import_mode", "at")
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
@@ -1263,20 +1264,6 @@ func (c *Config) Validate() error {
if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil { if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil {
return fmt.Errorf("sora2api.base_url invalid: %w", err) return fmt.Errorf("sora2api.base_url invalid: %w", err)
} }
warnIfInsecureURL("sora2api.base_url", c.Sora2API.BaseURL)
}
if mode := strings.TrimSpace(strings.ToLower(c.Sora2API.TokenImportMode)); mode != "" {
switch mode {
case "at", "offline":
default:
return fmt.Errorf("sora2api.token_import_mode must be one of: at/offline")
}
}
if c.Sora2API.AdminTokenTTLSeconds < 0 {
return fmt.Errorf("sora2api.admin_token_ttl_seconds must be non-negative")
}
if c.Sora2API.AdminTimeoutSeconds < 0 {
return fmt.Errorf("sora2api.admin_timeout_seconds must be non-negative")
} }
if c.Ops.MetricsCollectorCache.TTL < 0 { if c.Ops.MetricsCollectorCache.TTL < 0 {
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")

View File

@@ -35,15 +35,15 @@ type CreateGroupRequest struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置) // 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"` SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"` SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"` ModelRoutingEnabled bool `json:"model_routing_enabled"`
@@ -62,15 +62,15 @@ type UpdateGroupRequest struct {
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置) // 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"` SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"` SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"` ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用) // 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"` ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"` ModelRoutingEnabled *bool `json:"model_routing_enabled"`
@@ -163,26 +163,26 @@ func (h *GroupHandler) Create(c *gin.Context) {
} }
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Platform: req.Platform, Platform: req.Platform,
RateMultiplier: req.RateMultiplier, RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive, IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType, SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD, DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD, WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360, SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540, SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting, ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
@@ -208,27 +208,27 @@ func (h *GroupHandler) Update(c *gin.Context) {
} }
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Platform: req.Platform, Platform: req.Platform,
RateMultiplier: req.RateMultiplier, RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive, IsExclusive: req.IsExclusive,
Status: req.Status, Status: req.Status,
SubscriptionType: req.SubscriptionType, SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD, DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD, WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360, SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540, SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
ModelRouting: req.ModelRouting, ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)

View File

@@ -122,28 +122,28 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
func groupFromServiceBase(g *service.Group) Group { func groupFromServiceBase(g *service.Group) Group {
return Group{ return Group{
ID: g.ID, ID: g.ID,
Name: g.Name, Name: g.Name,
Description: g.Description, Description: g.Description,
Platform: g.Platform, Platform: g.Platform,
RateMultiplier: g.RateMultiplier, RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive, IsExclusive: g.IsExclusive,
Status: g.Status, Status: g.Status,
SubscriptionType: g.SubscriptionType, SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD, DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD, WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD, MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K, ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K, ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
SoraImagePrice360: g.SoraImagePrice360, SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540, SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
} }

View File

@@ -62,9 +62,9 @@ type Group struct {
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
// Sora 按次计费配置 // Sora 按次计费配置
SoraImagePrice360 *float64 `json:"sora_image_price_360"` SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"` SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
// Claude Code 客户端限制 // Claude Code 客户端限制

View File

@@ -33,6 +33,7 @@ type SoraGatewayHandler struct {
streamMode string streamMode string
sora2apiBaseURL string sora2apiBaseURL string
soraMediaSigningKey string soraMediaSigningKey string
mediaClient *http.Client
} }
// NewSoraGatewayHandler creates a new SoraGatewayHandler // NewSoraGatewayHandler creates a new SoraGatewayHandler
@@ -61,6 +62,10 @@ func NewSoraGatewayHandler(
if cfg != nil { if cfg != nil {
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/") baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
} }
mediaTimeout := 180 * time.Second
if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 {
mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second
}
return &SoraGatewayHandler{ return &SoraGatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
soraGatewayService: soraGatewayService, soraGatewayService: soraGatewayService,
@@ -70,6 +75,7 @@ func NewSoraGatewayHandler(
streamMode: strings.ToLower(streamMode), streamMode: strings.ToLower(streamMode),
sora2apiBaseURL: baseURL, sora2apiBaseURL: baseURL,
soraMediaSigningKey: signKey, soraMediaSigningKey: signKey,
mediaClient: &http.Client{Timeout: mediaTimeout},
} }
} }
@@ -457,7 +463,11 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo
} }
} }
resp, err := http.DefaultClient.Do(req) client := h.mediaClient
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil { if err != nil {
c.Status(http.StatusBadGateway) c.Status(http.StatusBadGateway)
return return

View File

@@ -1565,7 +1565,7 @@ func itoa(v int) string {
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). // Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
// //
// Use case: Finding Sora accounts linked via linked_openai_account_id. // Use case: Finding Sora accounts linked via linked_openai_account_id.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value interface{}) ([]service.Account, error) { func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
accounts, err := r.client.Account.Query(). accounts, err := r.client.Account.Query().
Where( Where(
dbaccount.PlatformEQ("sora"), // 限定平台为 sora dbaccount.PlatformEQ("sora"), // 限定平台为 sora

View File

@@ -410,32 +410,32 @@ func groupEntityToService(g *dbent.Group) *service.Group {
return nil return nil
} }
return &service.Group{ return &service.Group{
ID: g.ID, ID: g.ID,
Name: g.Name, Name: g.Name,
Description: derefString(g.Description), Description: derefString(g.Description),
Platform: g.Platform, Platform: g.Platform,
RateMultiplier: g.RateMultiplier, RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive, IsExclusive: g.IsExclusive,
Status: g.Status, Status: g.Status,
Hydrated: true, Hydrated: true,
SubscriptionType: g.SubscriptionType, SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd, DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd, MonthlyLimitUSD: g.MonthlyLimitUsd,
ImagePrice1K: g.ImagePrice1k, ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k, ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k, ImagePrice4K: g.ImagePrice4k,
SoraImagePrice360: g.SoraImagePrice360, SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540, SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
} }

View File

@@ -76,7 +76,7 @@ func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID in
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer func() { _ = rows.Close() }()
if !rows.Next() { if !rows.Next() {
return nil, nil // 记录不存在 return nil, nil // 记录不存在

View File

@@ -27,7 +27,7 @@ type AccountRepository interface {
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora' // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora'
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号 // 用于查找通过 linked_openai_account_id 关联的 Sora 账号
FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
Update(ctx context.Context, account *Account) error Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error

View File

@@ -23,24 +23,24 @@ type APIKeyAuthUserSnapshot struct {
// APIKeyAuthGroupSnapshot 分组快照 // APIKeyAuthGroupSnapshot 分组快照
type APIKeyAuthGroupSnapshot struct { type APIKeyAuthGroupSnapshot struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Platform string `json:"platform"` Platform string `json:"platform"`
Status string `json:"status"` Status string `json:"status"`
SubscriptionType string `json:"subscription_type"` SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"` RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty. // Only anthropic groups use these fields; others may leave them empty.

View File

@@ -223,26 +223,26 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
} }
if apiKey.Group != nil { if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{ snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID, ID: apiKey.Group.ID,
Name: apiKey.Group.Name, Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform, Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status, Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType, SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier, RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD, DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K, ImagePrice4K: apiKey.Group.ImagePrice4K,
SoraImagePrice360: apiKey.Group.SoraImagePrice360, SoraImagePrice360: apiKey.Group.SoraImagePrice360,
SoraImagePrice540: apiKey.Group.SoraImagePrice540, SoraImagePrice540: apiKey.Group.SoraImagePrice540,
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupID: apiKey.Group.FallbackGroupID,
ModelRouting: apiKey.Group.ModelRouting, ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
} }
} }
return snapshot return snapshot
@@ -270,27 +270,27 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
} }
if snapshot.Group != nil { if snapshot.Group != nil {
apiKey.Group = &Group{ apiKey.Group = &Group{
ID: snapshot.Group.ID, ID: snapshot.Group.ID,
Name: snapshot.Group.Name, Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform, Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status, Status: snapshot.Group.Status,
Hydrated: true, Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType, SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier, RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD, DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K, ImagePrice4K: snapshot.Group.ImagePrice4K,
SoraImagePrice360: snapshot.Group.SoraImagePrice360, SoraImagePrice360: snapshot.Group.SoraImagePrice360,
SoraImagePrice540: snapshot.Group.SoraImagePrice540, SoraImagePrice540: snapshot.Group.SoraImagePrice540,
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupID: snapshot.Group.FallbackGroupID,
ModelRouting: snapshot.Group.ModelRouting, ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
} }
} }
return apiKey return apiKey

View File

@@ -3465,7 +3465,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
var cost *CostBreakdown var cost *CostBreakdown
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" { if result.MediaType == "image" || result.MediaType == "video" {
var soraConfig *SoraPriceConfig var soraConfig *SoraPriceConfig
if apiKey.Group != nil { if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{ soraConfig = &SoraPriceConfig{
@@ -3480,6 +3480,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else { } else {
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
} }
} else if result.MediaType == "prompt" {
cost = &CostBreakdown{}
} else if result.ImageCount > 0 { } else if result.ImageCount > 0 {
// 图片生成计费 // 图片生成计费
var groupConfig *ImagePriceConfig var groupConfig *ImagePriceConfig

View File

@@ -27,9 +27,9 @@ type Group struct {
ImagePrice4K *float64 ImagePrice4K *float64
// Sora 按次计费配置(阶段 1 // Sora 按次计费配置(阶段 1
SoraImagePrice360 *float64 SoraImagePrice360 *float64
SoraImagePrice540 *float64 SoraImagePrice540 *float64
SoraVideoPricePerRequest *float64 SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64 SoraVideoPricePerRequestHD *float64
// Claude Code 客户端限制 // Claude Code 客户端限制

View File

@@ -62,7 +62,6 @@ type Sora2APIService struct {
adminUsername string adminUsername string
adminPassword string adminPassword string
adminTokenTTL time.Duration adminTokenTTL time.Duration
adminTimeout time.Duration
tokenImportMode string tokenImportMode string
client *http.Client client *http.Client
@@ -72,9 +71,8 @@ type Sora2APIService struct {
adminTokenAt time.Time adminTokenAt time.Time
adminMu sync.Mutex adminMu sync.Mutex
modelCache []Sora2APIModel modelCache []Sora2APIModel
modelCacheAt time.Time modelMu sync.RWMutex
modelMu sync.RWMutex
} }
func NewSora2APIService(cfg *config.Config) *Sora2APIService { func NewSora2APIService(cfg *config.Config) *Sora2APIService {
@@ -96,7 +94,6 @@ func NewSora2APIService(cfg *config.Config) *Sora2APIService {
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername), adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword), adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
adminTokenTTL: adminTTL, adminTokenTTL: adminTTL,
adminTimeout: adminTimeout,
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)), tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
client: &http.Client{}, client: &http.Client{},
adminClient: &http.Client{Timeout: adminTimeout}, adminClient: &http.Client{Timeout: adminTimeout},
@@ -176,7 +173,6 @@ func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, erro
s.modelMu.Lock() s.modelMu.Lock()
s.modelCache = models s.modelCache = models
s.modelCacheAt = time.Now()
s.modelMu.Unlock() s.modelMu.Unlock()
return models, nil return models, nil

View File

@@ -23,6 +23,8 @@ var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`) var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`) var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
const soraRewriteBufferLimit = 2048
var soraImageSizeMap = map[string]string{ var soraImageSizeMap = map[string]string{
"gpt-image": "360", "gpt-image": "360",
"gpt-image-landscape": "540", "gpt-image-landscape": "540",
@@ -30,7 +32,6 @@ var soraImageSizeMap = map[string]string{
} }
type soraStreamingResult struct { type soraStreamingResult struct {
content string
mediaType string mediaType string
mediaURLs []string mediaURLs []string
imageCount int imageCount int
@@ -307,6 +308,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
contentBuilder := strings.Builder{} contentBuilder := strings.Builder{}
var firstTokenMs *int var firstTokenMs *int
var upstreamError error var upstreamError error
rewriteBuffer := ""
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
@@ -333,12 +335,29 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
if soraSSEDataRe.MatchString(line) { if soraSSEDataRe.MatchString(line) {
data := soraSSEDataRe.ReplaceAllString(line, "") data := soraSSEDataRe.ReplaceAllString(line, "")
if data == "[DONE]" { if data == "[DONE]" {
if rewriteBuffer != "" {
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
if err != nil {
return nil, err
}
if flushLine != "" {
if flushContent != "" {
if _, err := contentBuilder.WriteString(flushContent); err != nil {
return nil, err
}
}
if err := sendLine(flushLine); err != nil {
return nil, err
}
}
rewriteBuffer = ""
}
if err := sendLine("data: [DONE]"); err != nil { if err := sendLine("data: [DONE]"); err != nil {
return nil, err return nil, err
} }
break break
} }
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel) updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
if errEvent != nil && upstreamError == nil { if errEvent != nil && upstreamError == nil {
upstreamError = errEvent upstreamError = errEvent
} }
@@ -347,7 +366,9 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms firstTokenMs = &ms
} }
contentBuilder.WriteString(contentDelta) if _, err := contentBuilder.WriteString(contentDelta); err != nil {
return nil, err
}
} }
if err := sendLine(updatedLine); err != nil { if err := sendLine(updatedLine); err != nil {
return nil, err return nil, err
@@ -417,7 +438,6 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
} }
return &soraStreamingResult{ return &soraStreamingResult{
content: content,
mediaType: mediaType, mediaType: mediaType,
mediaURLs: mediaURLs, mediaURLs: mediaURLs,
imageCount: imageCount, imageCount: imageCount,
@@ -426,7 +446,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
}, nil }, nil
} }
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) { func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
if strings.TrimSpace(data) == "" { if strings.TrimSpace(data) == "" {
return "data: ", "", nil return "data: ", "", nil
} }
@@ -448,7 +468,12 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin
contentDelta, updated := extractSoraContent(payload) contentDelta, updated := extractSoraContent(payload)
if updated { if updated {
rewritten := s.rewriteSoraContent(contentDelta) var rewritten string
if rewriteBuffer != nil {
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
} else {
rewritten = s.rewriteSoraContent(contentDelta)
}
if rewritten != contentDelta { if rewritten != contentDelta {
applySoraContent(payload, rewritten) applySoraContent(payload, rewritten)
contentDelta = rewritten contentDelta = rewritten
@@ -504,6 +529,78 @@ func applySoraContent(payload map[string]any, content string) {
} }
} }
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
if buffer == nil {
return s.rewriteSoraContent(contentDelta)
}
if contentDelta == "" && *buffer == "" {
return ""
}
combined := *buffer + contentDelta
rewritten := s.rewriteSoraContent(combined)
bufferStart := s.findSoraRewriteBufferStart(rewritten)
if bufferStart < 0 {
*buffer = ""
return rewritten
}
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
bufferStart = len(rewritten) - soraRewriteBufferLimit
}
output := rewritten[:bufferStart]
*buffer = rewritten[bufferStart:]
return output
}
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
minIndex := -1
start := 0
for {
idx := strings.Index(content[start:], "![")
if idx < 0 {
break
}
idx += start
if !hasSoraImageMatchAt(content, idx) {
if minIndex == -1 || idx < minIndex {
minIndex = idx
}
}
start = idx + 2
}
lower := strings.ToLower(content)
start = 0
for {
idx := strings.Index(lower[start:], "<video")
if idx < 0 {
break
}
idx += start
if !hasSoraVideoMatchAt(content, idx) {
if minIndex == -1 || idx < minIndex {
minIndex = idx
}
}
start = idx + len("<video")
}
return minIndex
}
func hasSoraImageMatchAt(content string, idx int) bool {
if idx < 0 || idx >= len(content) {
return false
}
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
return loc != nil && loc[0] == 0
}
func hasSoraVideoMatchAt(content string, idx int) bool {
if idx < 0 || idx >= len(content) {
return false
}
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
return loc != nil && loc[0] == 0
}
func (s *SoraGatewayService) rewriteSoraContent(content string) string { func (s *SoraGatewayService) rewriteSoraContent(content string) string {
if content == "" { if content == "" {
return content return content
@@ -533,6 +630,31 @@ func (s *SoraGatewayService) rewriteSoraContent(content string) string {
return content return content
} }
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
if buffer == "" {
return "", "", nil
}
rewritten := s.rewriteSoraContent(buffer)
payload := map[string]any{
"choices": []any{
map[string]any{
"delta": map[string]any{
"content": rewritten,
},
"index": 0,
},
},
}
if originalModel != "" {
payload["model"] = originalModel
}
updatedData, err := json.Marshal(payload)
if err != nil {
return "", "", err
}
return "data: " + string(updatedData), rewritten, nil
}
func (s *SoraGatewayService) rewriteSoraURL(raw string) string { func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
raw = strings.TrimSpace(raw) raw = strings.TrimSpace(raw)
if raw == "" { if raw == "" {

View File

@@ -15,9 +15,15 @@ func SignSoraMediaURL(path string, query string, expires int64, key string) stri
return "" return ""
} }
mac := hmac.New(sha256.New, []byte(key)) mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(buildSoraMediaSignPayload(path, query))) if _, err := mac.Write([]byte(buildSoraMediaSignPayload(path, query))); err != nil {
mac.Write([]byte("|")) return ""
mac.Write([]byte(strconv.FormatInt(expires, 10))) }
if _, err := mac.Write([]byte("|")); err != nil {
return ""
}
if _, err := mac.Write([]byte(strconv.FormatInt(expires, 10))); err != nil {
return ""
}
return hex.EncodeToString(mac.Sum(nil)) return hex.EncodeToString(mac.Sum(nil))
} }

View File

@@ -15,11 +15,9 @@ import (
// 定期检查并刷新即将过期的token // 定期检查并刷新即将过期的token
type TokenRefreshService struct { type TokenRefreshService struct {
accountRepo AccountRepository accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
refreshers []TokenRefresher refreshers []TokenRefresher
cfg *config.TokenRefreshConfig cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator cacheInvalidator TokenCacheInvalidator
soraSyncService *Sora2APISyncService
stopCh chan struct{} stopCh chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
@@ -57,7 +55,6 @@ func NewTokenRefreshService(
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表 // 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
// 需要在 Start() 之前调用 // 需要在 Start() 之前调用
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
s.soraAccountRepo = repo
// 将 soraAccountRepo 注入到 OpenAITokenRefresher // 将 soraAccountRepo 注入到 OpenAITokenRefresher
for _, refresher := range s.refreshers { for _, refresher := range s.refreshers {
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
@@ -69,7 +66,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
// SetSoraSyncService 设置 Sora2API 同步服务 // SetSoraSyncService 设置 Sora2API 同步服务
// 需要在 Start() 之前调用 // 需要在 Start() 之前调用
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) { func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
s.soraSyncService = svc
for _, refresher := range s.refreshers { for _, refresher := range s.refreshers {
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
openaiRefresher.SetSoraSyncService(svc) openaiRefresher.SetSoraSyncService(svc)

View File

@@ -83,10 +83,10 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
// OpenAITokenRefresher 处理 OpenAI OAuth token刷新 // OpenAITokenRefresher 处理 OpenAI OAuth token刷新
type OpenAITokenRefresher struct { type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
soraSyncService *Sora2APISyncService // Sora2API 同步服务 soraSyncService *Sora2APISyncService // Sora2API 同步服务
} }
// NewOpenAITokenRefresher 创建 OpenAI token刷新器 // NewOpenAITokenRefresher 创建 OpenAI token刷新器

View File

@@ -51,9 +51,7 @@ func ProvideTokenRefreshService(
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc.SetSoraAccountRepo(soraAccountRepo) svc.SetSoraAccountRepo(soraAccountRepo)
if soraSyncService != nil { svc.SetSoraSyncService(soraSyncService)
svc.SetSoraSyncService(soraSyncService)
}
svc.Start() svc.Start()
return svc return svc
} }
@@ -242,8 +240,6 @@ var ProviderSet = wire.NewSet(
NewAntigravityTokenProvider, NewAntigravityTokenProvider,
NewOpenAITokenProvider, NewOpenAITokenProvider,
NewClaudeTokenProvider, NewClaudeTokenProvider,
NewSora2APIService,
NewSora2APISyncService,
NewAntigravityGatewayService, NewAntigravityGatewayService,
ProvideRateLimitService, ProvideRateLimitService,
NewAccountUsageService, NewAccountUsageService,