fix(sora): 修复流式重写与计费问题
This commit is contained in:
@@ -928,6 +928,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("sora2api.admin_token_ttl_seconds", 900)
|
viper.SetDefault("sora2api.admin_token_ttl_seconds", 900)
|
||||||
viper.SetDefault("sora2api.admin_timeout_seconds", 10)
|
viper.SetDefault("sora2api.admin_timeout_seconds", 10)
|
||||||
viper.SetDefault("sora2api.token_import_mode", "at")
|
viper.SetDefault("sora2api.token_import_mode", "at")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
@@ -1263,20 +1264,6 @@ func (c *Config) Validate() error {
|
|||||||
if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil {
|
if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil {
|
||||||
return fmt.Errorf("sora2api.base_url invalid: %w", err)
|
return fmt.Errorf("sora2api.base_url invalid: %w", err)
|
||||||
}
|
}
|
||||||
warnIfInsecureURL("sora2api.base_url", c.Sora2API.BaseURL)
|
|
||||||
}
|
|
||||||
if mode := strings.TrimSpace(strings.ToLower(c.Sora2API.TokenImportMode)); mode != "" {
|
|
||||||
switch mode {
|
|
||||||
case "at", "offline":
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("sora2api.token_import_mode must be one of: at/offline")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.Sora2API.AdminTokenTTLSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora2api.admin_token_ttl_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora2API.AdminTimeoutSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora2api.admin_timeout_seconds must be non-negative")
|
|
||||||
}
|
}
|
||||||
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
||||||
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
||||||
|
|||||||
@@ -35,15 +35,15 @@ type CreateGroupRequest struct {
|
|||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
@@ -62,15 +62,15 @@ type UpdateGroupRequest struct {
|
|||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||||
@@ -163,26 +163,26 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Platform: req.Platform,
|
Platform: req.Platform,
|
||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
SoraImagePrice360: req.SoraImagePrice360,
|
SoraImagePrice360: req.SoraImagePrice360,
|
||||||
SoraImagePrice540: req.SoraImagePrice540,
|
SoraImagePrice540: req.SoraImagePrice540,
|
||||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
FallbackGroupID: req.FallbackGroupID,
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
ModelRouting: req.ModelRouting,
|
ModelRouting: req.ModelRouting,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@@ -208,27 +208,27 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Platform: req.Platform,
|
Platform: req.Platform,
|
||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
SoraImagePrice360: req.SoraImagePrice360,
|
SoraImagePrice360: req.SoraImagePrice360,
|
||||||
SoraImagePrice540: req.SoraImagePrice540,
|
SoraImagePrice540: req.SoraImagePrice540,
|
||||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
FallbackGroupID: req.FallbackGroupID,
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
ModelRouting: req.ModelRouting,
|
ModelRouting: req.ModelRouting,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@@ -122,28 +122,28 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
|||||||
|
|
||||||
func groupFromServiceBase(g *service.Group) Group {
|
func groupFromServiceBase(g *service.Group) Group {
|
||||||
return Group{
|
return Group{
|
||||||
ID: g.ID,
|
ID: g.ID,
|
||||||
Name: g.Name,
|
Name: g.Name,
|
||||||
Description: g.Description,
|
Description: g.Description,
|
||||||
Platform: g.Platform,
|
Platform: g.Platform,
|
||||||
RateMultiplier: g.RateMultiplier,
|
RateMultiplier: g.RateMultiplier,
|
||||||
IsExclusive: g.IsExclusive,
|
IsExclusive: g.IsExclusive,
|
||||||
Status: g.Status,
|
Status: g.Status,
|
||||||
SubscriptionType: g.SubscriptionType,
|
SubscriptionType: g.SubscriptionType,
|
||||||
DailyLimitUSD: g.DailyLimitUSD,
|
DailyLimitUSD: g.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||||
ImagePrice1K: g.ImagePrice1K,
|
ImagePrice1K: g.ImagePrice1K,
|
||||||
ImagePrice2K: g.ImagePrice2K,
|
ImagePrice2K: g.ImagePrice2K,
|
||||||
ImagePrice4K: g.ImagePrice4K,
|
ImagePrice4K: g.ImagePrice4K,
|
||||||
SoraImagePrice360: g.SoraImagePrice360,
|
SoraImagePrice360: g.SoraImagePrice360,
|
||||||
SoraImagePrice540: g.SoraImagePrice540,
|
SoraImagePrice540: g.SoraImagePrice540,
|
||||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
|
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
CreatedAt: g.CreatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
UpdatedAt: g.UpdatedAt,
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -62,9 +62,9 @@ type Group struct {
|
|||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
|
|
||||||
// Sora 按次计费配置
|
// Sora 按次计费配置
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||||
|
|
||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ type SoraGatewayHandler struct {
|
|||||||
streamMode string
|
streamMode string
|
||||||
sora2apiBaseURL string
|
sora2apiBaseURL string
|
||||||
soraMediaSigningKey string
|
soraMediaSigningKey string
|
||||||
|
mediaClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
||||||
@@ -61,6 +62,10 @@ func NewSoraGatewayHandler(
|
|||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
|
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
|
||||||
}
|
}
|
||||||
|
mediaTimeout := 180 * time.Second
|
||||||
|
if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 {
|
||||||
|
mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second
|
||||||
|
}
|
||||||
return &SoraGatewayHandler{
|
return &SoraGatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
soraGatewayService: soraGatewayService,
|
soraGatewayService: soraGatewayService,
|
||||||
@@ -70,6 +75,7 @@ func NewSoraGatewayHandler(
|
|||||||
streamMode: strings.ToLower(streamMode),
|
streamMode: strings.ToLower(streamMode),
|
||||||
sora2apiBaseURL: baseURL,
|
sora2apiBaseURL: baseURL,
|
||||||
soraMediaSigningKey: signKey,
|
soraMediaSigningKey: signKey,
|
||||||
|
mediaClient: &http.Client{Timeout: mediaTimeout},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,7 +463,11 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
client := h.mediaClient
|
||||||
|
if client == nil {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Status(http.StatusBadGateway)
|
c.Status(http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1565,7 +1565,7 @@ func itoa(v int) string {
|
|||||||
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
||||||
//
|
//
|
||||||
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
||||||
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value interface{}) ([]service.Account, error) {
|
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||||
accounts, err := r.client.Account.Query().
|
accounts, err := r.client.Account.Query().
|
||||||
Where(
|
Where(
|
||||||
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
||||||
|
|||||||
@@ -410,32 +410,32 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &service.Group{
|
return &service.Group{
|
||||||
ID: g.ID,
|
ID: g.ID,
|
||||||
Name: g.Name,
|
Name: g.Name,
|
||||||
Description: derefString(g.Description),
|
Description: derefString(g.Description),
|
||||||
Platform: g.Platform,
|
Platform: g.Platform,
|
||||||
RateMultiplier: g.RateMultiplier,
|
RateMultiplier: g.RateMultiplier,
|
||||||
IsExclusive: g.IsExclusive,
|
IsExclusive: g.IsExclusive,
|
||||||
Status: g.Status,
|
Status: g.Status,
|
||||||
Hydrated: true,
|
Hydrated: true,
|
||||||
SubscriptionType: g.SubscriptionType,
|
SubscriptionType: g.SubscriptionType,
|
||||||
DailyLimitUSD: g.DailyLimitUsd,
|
DailyLimitUSD: g.DailyLimitUsd,
|
||||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||||
ImagePrice1K: g.ImagePrice1k,
|
ImagePrice1K: g.ImagePrice1k,
|
||||||
ImagePrice2K: g.ImagePrice2k,
|
ImagePrice2K: g.ImagePrice2k,
|
||||||
ImagePrice4K: g.ImagePrice4k,
|
ImagePrice4K: g.ImagePrice4k,
|
||||||
SoraImagePrice360: g.SoraImagePrice360,
|
SoraImagePrice360: g.SoraImagePrice360,
|
||||||
SoraImagePrice540: g.SoraImagePrice540,
|
SoraImagePrice540: g.SoraImagePrice540,
|
||||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
||||||
DefaultValidityDays: g.DefaultValidityDays,
|
DefaultValidityDays: g.DefaultValidityDays,
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
CreatedAt: g.CreatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
UpdatedAt: g.UpdatedAt,
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID in
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
return nil, nil // 记录不存在
|
return nil, nil // 记录不存在
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ type AccountRepository interface {
|
|||||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||||
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
|
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
|
||||||
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
|
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
|
||||||
FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error)
|
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
|
||||||
Update(ctx context.Context, account *Account) error
|
Update(ctx context.Context, account *Account) error
|
||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
|
|||||||
@@ -23,24 +23,24 @@ type APIKeyAuthUserSnapshot struct {
|
|||||||
|
|
||||||
// APIKeyAuthGroupSnapshot 分组快照
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
type APIKeyAuthGroupSnapshot struct {
|
type APIKeyAuthGroupSnapshot struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
SubscriptionType string `json:"subscription_type"`
|
SubscriptionType string `json:"subscription_type"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
||||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
||||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
|
||||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||||
// Only anthropic groups use these fields; others may leave them empty.
|
// Only anthropic groups use these fields; others may leave them empty.
|
||||||
|
|||||||
@@ -223,26 +223,26 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
}
|
}
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||||
ID: apiKey.Group.ID,
|
ID: apiKey.Group.ID,
|
||||||
Name: apiKey.Group.Name,
|
Name: apiKey.Group.Name,
|
||||||
Platform: apiKey.Group.Platform,
|
Platform: apiKey.Group.Platform,
|
||||||
Status: apiKey.Group.Status,
|
Status: apiKey.Group.Status,
|
||||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||||
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
|
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||||
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
|
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||||
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||||
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||||
ModelRouting: apiKey.Group.ModelRouting,
|
ModelRouting: apiKey.Group.ModelRouting,
|
||||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return snapshot
|
return snapshot
|
||||||
@@ -270,27 +270,27 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
}
|
}
|
||||||
if snapshot.Group != nil {
|
if snapshot.Group != nil {
|
||||||
apiKey.Group = &Group{
|
apiKey.Group = &Group{
|
||||||
ID: snapshot.Group.ID,
|
ID: snapshot.Group.ID,
|
||||||
Name: snapshot.Group.Name,
|
Name: snapshot.Group.Name,
|
||||||
Platform: snapshot.Group.Platform,
|
Platform: snapshot.Group.Platform,
|
||||||
Status: snapshot.Group.Status,
|
Status: snapshot.Group.Status,
|
||||||
Hydrated: true,
|
Hydrated: true,
|
||||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||||
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
|
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
|
||||||
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
|
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
|
||||||
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
|
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
|
||||||
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
|
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
|
||||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||||
ModelRouting: snapshot.Group.ModelRouting,
|
ModelRouting: snapshot.Group.ModelRouting,
|
||||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return apiKey
|
return apiKey
|
||||||
|
|||||||
@@ -3465,7 +3465,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
var cost *CostBreakdown
|
var cost *CostBreakdown
|
||||||
|
|
||||||
// 根据请求类型选择计费方式
|
// 根据请求类型选择计费方式
|
||||||
if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" {
|
if result.MediaType == "image" || result.MediaType == "video" {
|
||||||
var soraConfig *SoraPriceConfig
|
var soraConfig *SoraPriceConfig
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
soraConfig = &SoraPriceConfig{
|
soraConfig = &SoraPriceConfig{
|
||||||
@@ -3480,6 +3480,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
} else {
|
} else {
|
||||||
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
|
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
|
||||||
}
|
}
|
||||||
|
} else if result.MediaType == "prompt" {
|
||||||
|
cost = &CostBreakdown{}
|
||||||
} else if result.ImageCount > 0 {
|
} else if result.ImageCount > 0 {
|
||||||
// 图片生成计费
|
// 图片生成计费
|
||||||
var groupConfig *ImagePriceConfig
|
var groupConfig *ImagePriceConfig
|
||||||
|
|||||||
@@ -27,9 +27,9 @@ type Group struct {
|
|||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
|
|
||||||
// Sora 按次计费配置(阶段 1)
|
// Sora 按次计费配置(阶段 1)
|
||||||
SoraImagePrice360 *float64
|
SoraImagePrice360 *float64
|
||||||
SoraImagePrice540 *float64
|
SoraImagePrice540 *float64
|
||||||
SoraVideoPricePerRequest *float64
|
SoraVideoPricePerRequest *float64
|
||||||
SoraVideoPricePerRequestHD *float64
|
SoraVideoPricePerRequestHD *float64
|
||||||
|
|
||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ type Sora2APIService struct {
|
|||||||
adminUsername string
|
adminUsername string
|
||||||
adminPassword string
|
adminPassword string
|
||||||
adminTokenTTL time.Duration
|
adminTokenTTL time.Duration
|
||||||
adminTimeout time.Duration
|
|
||||||
tokenImportMode string
|
tokenImportMode string
|
||||||
|
|
||||||
client *http.Client
|
client *http.Client
|
||||||
@@ -72,9 +71,8 @@ type Sora2APIService struct {
|
|||||||
adminTokenAt time.Time
|
adminTokenAt time.Time
|
||||||
adminMu sync.Mutex
|
adminMu sync.Mutex
|
||||||
|
|
||||||
modelCache []Sora2APIModel
|
modelCache []Sora2APIModel
|
||||||
modelCacheAt time.Time
|
modelMu sync.RWMutex
|
||||||
modelMu sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSora2APIService(cfg *config.Config) *Sora2APIService {
|
func NewSora2APIService(cfg *config.Config) *Sora2APIService {
|
||||||
@@ -96,7 +94,6 @@ func NewSora2APIService(cfg *config.Config) *Sora2APIService {
|
|||||||
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
|
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
|
||||||
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
|
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
|
||||||
adminTokenTTL: adminTTL,
|
adminTokenTTL: adminTTL,
|
||||||
adminTimeout: adminTimeout,
|
|
||||||
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
|
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
|
||||||
client: &http.Client{},
|
client: &http.Client{},
|
||||||
adminClient: &http.Client{Timeout: adminTimeout},
|
adminClient: &http.Client{Timeout: adminTimeout},
|
||||||
@@ -176,7 +173,6 @@ func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, erro
|
|||||||
|
|
||||||
s.modelMu.Lock()
|
s.modelMu.Lock()
|
||||||
s.modelCache = models
|
s.modelCache = models
|
||||||
s.modelCacheAt = time.Now()
|
|
||||||
s.modelMu.Unlock()
|
s.modelMu.Unlock()
|
||||||
|
|
||||||
return models, nil
|
return models, nil
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
|||||||
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
||||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||||
|
|
||||||
|
const soraRewriteBufferLimit = 2048
|
||||||
|
|
||||||
var soraImageSizeMap = map[string]string{
|
var soraImageSizeMap = map[string]string{
|
||||||
"gpt-image": "360",
|
"gpt-image": "360",
|
||||||
"gpt-image-landscape": "540",
|
"gpt-image-landscape": "540",
|
||||||
@@ -30,7 +32,6 @@ var soraImageSizeMap = map[string]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
type soraStreamingResult struct {
|
type soraStreamingResult struct {
|
||||||
content string
|
|
||||||
mediaType string
|
mediaType string
|
||||||
mediaURLs []string
|
mediaURLs []string
|
||||||
imageCount int
|
imageCount int
|
||||||
@@ -307,6 +308,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
contentBuilder := strings.Builder{}
|
contentBuilder := strings.Builder{}
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
var upstreamError error
|
var upstreamError error
|
||||||
|
rewriteBuffer := ""
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
@@ -333,12 +335,29 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
if soraSSEDataRe.MatchString(line) {
|
if soraSSEDataRe.MatchString(line) {
|
||||||
data := soraSSEDataRe.ReplaceAllString(line, "")
|
data := soraSSEDataRe.ReplaceAllString(line, "")
|
||||||
if data == "[DONE]" {
|
if data == "[DONE]" {
|
||||||
|
if rewriteBuffer != "" {
|
||||||
|
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if flushLine != "" {
|
||||||
|
if flushContent != "" {
|
||||||
|
if _, err := contentBuilder.WriteString(flushContent); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := sendLine(flushLine); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rewriteBuffer = ""
|
||||||
|
}
|
||||||
if err := sendLine("data: [DONE]"); err != nil {
|
if err := sendLine("data: [DONE]"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel)
|
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
|
||||||
if errEvent != nil && upstreamError == nil {
|
if errEvent != nil && upstreamError == nil {
|
||||||
upstreamError = errEvent
|
upstreamError = errEvent
|
||||||
}
|
}
|
||||||
@@ -347,7 +366,9 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
ms := int(time.Since(startTime).Milliseconds())
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
firstTokenMs = &ms
|
firstTokenMs = &ms
|
||||||
}
|
}
|
||||||
contentBuilder.WriteString(contentDelta)
|
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := sendLine(updatedLine); err != nil {
|
if err := sendLine(updatedLine); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -417,7 +438,6 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &soraStreamingResult{
|
return &soraStreamingResult{
|
||||||
content: content,
|
|
||||||
mediaType: mediaType,
|
mediaType: mediaType,
|
||||||
mediaURLs: mediaURLs,
|
mediaURLs: mediaURLs,
|
||||||
imageCount: imageCount,
|
imageCount: imageCount,
|
||||||
@@ -426,7 +446,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) {
|
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
|
||||||
if strings.TrimSpace(data) == "" {
|
if strings.TrimSpace(data) == "" {
|
||||||
return "data: ", "", nil
|
return "data: ", "", nil
|
||||||
}
|
}
|
||||||
@@ -448,7 +468,12 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin
|
|||||||
|
|
||||||
contentDelta, updated := extractSoraContent(payload)
|
contentDelta, updated := extractSoraContent(payload)
|
||||||
if updated {
|
if updated {
|
||||||
rewritten := s.rewriteSoraContent(contentDelta)
|
var rewritten string
|
||||||
|
if rewriteBuffer != nil {
|
||||||
|
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
|
||||||
|
} else {
|
||||||
|
rewritten = s.rewriteSoraContent(contentDelta)
|
||||||
|
}
|
||||||
if rewritten != contentDelta {
|
if rewritten != contentDelta {
|
||||||
applySoraContent(payload, rewritten)
|
applySoraContent(payload, rewritten)
|
||||||
contentDelta = rewritten
|
contentDelta = rewritten
|
||||||
@@ -504,6 +529,78 @@ func applySoraContent(payload map[string]any, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
|
||||||
|
if buffer == nil {
|
||||||
|
return s.rewriteSoraContent(contentDelta)
|
||||||
|
}
|
||||||
|
if contentDelta == "" && *buffer == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
combined := *buffer + contentDelta
|
||||||
|
rewritten := s.rewriteSoraContent(combined)
|
||||||
|
bufferStart := s.findSoraRewriteBufferStart(rewritten)
|
||||||
|
if bufferStart < 0 {
|
||||||
|
*buffer = ""
|
||||||
|
return rewritten
|
||||||
|
}
|
||||||
|
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
|
||||||
|
bufferStart = len(rewritten) - soraRewriteBufferLimit
|
||||||
|
}
|
||||||
|
output := rewritten[:bufferStart]
|
||||||
|
*buffer = rewritten[bufferStart:]
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
|
||||||
|
minIndex := -1
|
||||||
|
start := 0
|
||||||
|
for {
|
||||||
|
idx := strings.Index(content[start:], "![")
|
||||||
|
if idx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
idx += start
|
||||||
|
if !hasSoraImageMatchAt(content, idx) {
|
||||||
|
if minIndex == -1 || idx < minIndex {
|
||||||
|
minIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start = idx + 2
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(content)
|
||||||
|
start = 0
|
||||||
|
for {
|
||||||
|
idx := strings.Index(lower[start:], "<video")
|
||||||
|
if idx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
idx += start
|
||||||
|
if !hasSoraVideoMatchAt(content, idx) {
|
||||||
|
if minIndex == -1 || idx < minIndex {
|
||||||
|
minIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start = idx + len("<video")
|
||||||
|
}
|
||||||
|
return minIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSoraImageMatchAt(content string, idx int) bool {
|
||||||
|
if idx < 0 || idx >= len(content) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
|
||||||
|
return loc != nil && loc[0] == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSoraVideoMatchAt(content string, idx int) bool {
|
||||||
|
if idx < 0 || idx >= len(content) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
|
||||||
|
return loc != nil && loc[0] == 0
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
||||||
if content == "" {
|
if content == "" {
|
||||||
return content
|
return content
|
||||||
@@ -533,6 +630,31 @@ func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
|||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
|
||||||
|
if buffer == "" {
|
||||||
|
return "", "", nil
|
||||||
|
}
|
||||||
|
rewritten := s.rewriteSoraContent(buffer)
|
||||||
|
payload := map[string]any{
|
||||||
|
"choices": []any{
|
||||||
|
map[string]any{
|
||||||
|
"delta": map[string]any{
|
||||||
|
"content": rewritten,
|
||||||
|
},
|
||||||
|
"index": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if originalModel != "" {
|
||||||
|
payload["model"] = originalModel
|
||||||
|
}
|
||||||
|
updatedData, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return "data: " + string(updatedData), rewritten, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
||||||
raw = strings.TrimSpace(raw)
|
raw = strings.TrimSpace(raw)
|
||||||
if raw == "" {
|
if raw == "" {
|
||||||
|
|||||||
@@ -15,9 +15,15 @@ func SignSoraMediaURL(path string, query string, expires int64, key string) stri
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
mac := hmac.New(sha256.New, []byte(key))
|
mac := hmac.New(sha256.New, []byte(key))
|
||||||
mac.Write([]byte(buildSoraMediaSignPayload(path, query)))
|
if _, err := mac.Write([]byte(buildSoraMediaSignPayload(path, query))); err != nil {
|
||||||
mac.Write([]byte("|"))
|
return ""
|
||||||
mac.Write([]byte(strconv.FormatInt(expires, 10)))
|
}
|
||||||
|
if _, err := mac.Write([]byte("|")); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if _, err := mac.Write([]byte(strconv.FormatInt(expires, 10))); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return hex.EncodeToString(mac.Sum(nil))
|
return hex.EncodeToString(mac.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,9 @@ import (
|
|||||||
// 定期检查并刷新即将过期的token
|
// 定期检查并刷新即将过期的token
|
||||||
type TokenRefreshService struct {
|
type TokenRefreshService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
|
||||||
refreshers []TokenRefresher
|
refreshers []TokenRefresher
|
||||||
cfg *config.TokenRefreshConfig
|
cfg *config.TokenRefreshConfig
|
||||||
cacheInvalidator TokenCacheInvalidator
|
cacheInvalidator TokenCacheInvalidator
|
||||||
soraSyncService *Sora2APISyncService
|
|
||||||
|
|
||||||
stopCh chan struct{}
|
stopCh chan struct{}
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
@@ -57,7 +55,6 @@ func NewTokenRefreshService(
|
|||||||
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
|
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
|
||||||
// 需要在 Start() 之前调用
|
// 需要在 Start() 之前调用
|
||||||
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||||
s.soraAccountRepo = repo
|
|
||||||
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
|
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
|
||||||
for _, refresher := range s.refreshers {
|
for _, refresher := range s.refreshers {
|
||||||
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
||||||
@@ -69,7 +66,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
|||||||
// SetSoraSyncService 设置 Sora2API 同步服务
|
// SetSoraSyncService 设置 Sora2API 同步服务
|
||||||
// 需要在 Start() 之前调用
|
// 需要在 Start() 之前调用
|
||||||
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
|
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
|
||||||
s.soraSyncService = svc
|
|
||||||
for _, refresher := range s.refreshers {
|
for _, refresher := range s.refreshers {
|
||||||
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
||||||
openaiRefresher.SetSoraSyncService(svc)
|
openaiRefresher.SetSoraSyncService(svc)
|
||||||
|
|||||||
@@ -83,10 +83,10 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
|||||||
|
|
||||||
// OpenAITokenRefresher 处理 OpenAI OAuth token刷新
|
// OpenAITokenRefresher 处理 OpenAI OAuth token刷新
|
||||||
type OpenAITokenRefresher struct {
|
type OpenAITokenRefresher struct {
|
||||||
openaiOAuthService *OpenAIOAuthService
|
openaiOAuthService *OpenAIOAuthService
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
||||||
soraSyncService *Sora2APISyncService // Sora2API 同步服务
|
soraSyncService *Sora2APISyncService // Sora2API 同步服务
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||||
|
|||||||
@@ -51,9 +51,7 @@ func ProvideTokenRefreshService(
|
|||||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
||||||
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
||||||
svc.SetSoraAccountRepo(soraAccountRepo)
|
svc.SetSoraAccountRepo(soraAccountRepo)
|
||||||
if soraSyncService != nil {
|
svc.SetSoraSyncService(soraSyncService)
|
||||||
svc.SetSoraSyncService(soraSyncService)
|
|
||||||
}
|
|
||||||
svc.Start()
|
svc.Start()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
@@ -242,8 +240,6 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAntigravityTokenProvider,
|
NewAntigravityTokenProvider,
|
||||||
NewOpenAITokenProvider,
|
NewOpenAITokenProvider,
|
||||||
NewClaudeTokenProvider,
|
NewClaudeTokenProvider,
|
||||||
NewSora2APIService,
|
|
||||||
NewSora2APISyncService,
|
|
||||||
NewAntigravityGatewayService,
|
NewAntigravityGatewayService,
|
||||||
ProvideRateLimitService,
|
ProvideRateLimitService,
|
||||||
NewAccountUsageService,
|
NewAccountUsageService,
|
||||||
|
|||||||
Reference in New Issue
Block a user