From 78d0ca3775da8fb2162dca8d7358508c2e9f3fdb Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 31 Jan 2026 21:46:28 +0800 Subject: [PATCH] =?UTF-8?q?fix(sora):=20=E4=BF=AE=E5=A4=8D=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E9=87=8D=E5=86=99=E4=B8=8E=E8=AE=A1=E8=B4=B9=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/config/config.go | 15 +- .../internal/handler/admin/group_handler.go | 110 +++++++------- backend/internal/handler/dto/mappers.go | 42 +++--- backend/internal/handler/dto/types.go | 6 +- .../internal/handler/sora_gateway_handler.go | 12 +- backend/internal/repository/account_repo.go | 2 +- backend/internal/repository/api_key_repo.go | 50 +++---- .../internal/repository/sora_account_repo.go | 2 +- backend/internal/service/account_service.go | 2 +- .../internal/service/api_key_auth_cache.go | 34 ++--- .../service/api_key_auth_cache_impl.go | 78 +++++----- backend/internal/service/gateway_service.go | 4 +- backend/internal/service/group.go | 6 +- backend/internal/service/sora2api_service.go | 8 +- .../internal/service/sora_gateway_service.go | 134 +++++++++++++++++- backend/internal/service/sora_media_sign.go | 12 +- .../internal/service/token_refresh_service.go | 4 - backend/internal/service/token_refresher.go | 8 +- backend/internal/service/wire.go | 6 +- 19 files changed, 325 insertions(+), 210 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 5dd2b415..f3dec213 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -928,6 +928,7 @@ func setDefaults() { viper.SetDefault("sora2api.admin_token_ttl_seconds", 900) viper.SetDefault("sora2api.admin_timeout_seconds", 10) viper.SetDefault("sora2api.token_import_mode", "at") + } func (c *Config) Validate() error { @@ -1263,20 +1264,6 @@ func (c *Config) Validate() error { if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil { return fmt.Errorf("sora2api.base_url invalid: %w", err) } - warnIfInsecureURL("sora2api.base_url", c.Sora2API.BaseURL) - } - if mode := strings.TrimSpace(strings.ToLower(c.Sora2API.TokenImportMode)); mode != "" { - switch mode { - case "at", "offline": - default: - return fmt.Errorf("sora2api.token_import_mode must be one of: at/offline") - } - } - if c.Sora2API.AdminTokenTTLSeconds < 0 { - return fmt.Errorf("sora2api.admin_token_ttl_seconds must be non-negative") - } - if c.Sora2API.AdminTimeoutSeconds < 0 { - return fmt.Errorf("sora2api.admin_timeout_seconds must be non-negative") } if c.Ops.MetricsCollectorCache.TTL < 0 { return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 1af570d9..328c8fce 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -35,15 +35,15 @@ type CreateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` @@ -62,15 +62,15 @@ type UpdateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` - ClaudeCodeOnly *bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ClaudeCodeOnly *bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled *bool `json:"model_routing_enabled"` @@ -163,26 +163,26 @@ func (h *GroupHandler) Create(c *gin.Context) { } group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, }) if err != nil { response.ErrorFrom(c, err) @@ -208,27 +208,27 @@ func (h *GroupHandler) Update(c *gin.Context) { } group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: req.Status, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: req.Status, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index b44c3225..58a4ad86 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -122,28 +122,28 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { func groupFromServiceBase(g *service.Group) Group { return Group{ - ID: g.ID, - Name: g.Name, - Description: g.Description, - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUSD, - WeeklyLimitUSD: g.WeeklyLimitUSD, - MonthlyLimitUSD: g.MonthlyLimitUSD, - ImagePrice1K: g.ImagePrice1K, - ImagePrice2K: g.ImagePrice2K, - ImagePrice4K: g.ImagePrice4K, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + ID: g.ID, + Name: g.Name, + Description: g.Description, + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUSD, + WeeklyLimitUSD: g.WeeklyLimitUSD, + MonthlyLimitUSD: g.MonthlyLimitUSD, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 3ae899ee..505f9dd4 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -62,9 +62,9 @@ type Group struct { ImagePrice4K *float64 `json:"image_price_4k"` // Sora 按次计费配置 - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` // Claude Code 客户端限制 diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 94f712df..05833144 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -33,6 +33,7 @@ type SoraGatewayHandler struct { streamMode string sora2apiBaseURL string soraMediaSigningKey string + mediaClient *http.Client } // NewSoraGatewayHandler creates a new SoraGatewayHandler @@ -61,6 +62,10 @@ func NewSoraGatewayHandler( if cfg != nil { baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/") } + mediaTimeout := 180 * time.Second + if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 { + mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second + } return &SoraGatewayHandler{ gatewayService: gatewayService, soraGatewayService: soraGatewayService, @@ -70,6 +75,7 @@ func NewSoraGatewayHandler( streamMode: strings.ToLower(streamMode), sora2apiBaseURL: baseURL, soraMediaSigningKey: signKey, + mediaClient: &http.Client{Timeout: mediaTimeout}, } } @@ -457,7 +463,11 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo } } - resp, err := http.DefaultClient.Do(req) + client := h.mediaClient + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) if err != nil { c.Status(http.StatusBadGateway) return diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 5edc4f6d..170e5de9 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1565,7 +1565,7 @@ func itoa(v int) string { // Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). // // Use case: Finding Sora accounts linked via linked_openai_account_id. -func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value interface{}) ([]service.Account, error) { +func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { accounts, err := r.client.Account.Query(). Where( dbaccount.PlatformEQ("sora"), // 限定平台为 sora diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 9308326b..a020ee2b 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -410,32 +410,32 @@ func groupEntityToService(g *dbent.Group) *service.Group { return nil } return &service.Group{ - ID: g.ID, - Name: g.Name, - Description: derefString(g.Description), - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - Hydrated: true, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUsd, - WeeklyLimitUSD: g.WeeklyLimitUsd, - MonthlyLimitUSD: g.MonthlyLimitUsd, - ImagePrice1K: g.ImagePrice1k, - ImagePrice2K: g.ImagePrice2k, - ImagePrice4K: g.ImagePrice4k, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + ID: g.ID, + Name: g.Name, + Description: derefString(g.Description), + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + Hydrated: true, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUsd, + WeeklyLimitUSD: g.WeeklyLimitUsd, + MonthlyLimitUSD: g.MonthlyLimitUsd, + ImagePrice1K: g.ImagePrice1k, + ImagePrice2K: g.ImagePrice2k, + ImagePrice4K: g.ImagePrice4k, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, - DefaultValidityDays: g.DefaultValidityDays, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + DefaultValidityDays: g.DefaultValidityDays, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go index e0ec6073..ad2ae638 100644 --- a/backend/internal/repository/sora_account_repo.go +++ b/backend/internal/repository/sora_account_repo.go @@ -76,7 +76,7 @@ func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID in if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() if !rows.Next() { return nil, nil // 记录不存在 diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 4befc996..a261fb21 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -27,7 +27,7 @@ type AccountRepository interface { GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 - FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) + FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) Update(ctx context.Context, account *Account) error Delete(ctx context.Context, id int64) error diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 6247da00..9d8f87f2 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -23,24 +23,24 @@ type APIKeyAuthUserSnapshot struct { // APIKeyAuthGroupSnapshot 分组快照 type APIKeyAuthGroupSnapshot struct { - ID int64 `json:"id"` - Name string `json:"name"` - Platform string `json:"platform"` - Status string `json:"status"` - SubscriptionType string `json:"subscription_type"` - RateMultiplier float64 `json:"rate_multiplier"` - DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` - WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` - MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` - ImagePrice1K *float64 `json:"image_price_1k,omitempty"` - ImagePrice2K *float64 `json:"image_price_2k,omitempty"` - ImagePrice4K *float64 `json:"image_price_4k,omitempty"` - SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` - SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Only anthropic groups use these fields; others may leave them empty. diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 5569a503..19ba4e79 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -223,26 +223,26 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ - ID: apiKey.Group.ID, - Name: apiKey.Group.Name, - Platform: apiKey.Group.Platform, - Status: apiKey.Group.Status, - SubscriptionType: apiKey.Group.SubscriptionType, - RateMultiplier: apiKey.Group.RateMultiplier, - DailyLimitUSD: apiKey.Group.DailyLimitUSD, - WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, - MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, - ImagePrice1K: apiKey.Group.ImagePrice1K, - ImagePrice2K: apiKey.Group.ImagePrice2K, - ImagePrice4K: apiKey.Group.ImagePrice4K, - SoraImagePrice360: apiKey.Group.SoraImagePrice360, - SoraImagePrice540: apiKey.Group.SoraImagePrice540, - SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + SoraImagePrice360: apiKey.Group.SoraImagePrice360, + SoraImagePrice540: apiKey.Group.SoraImagePrice540, + SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, - FallbackGroupID: apiKey.Group.FallbackGroupID, - ModelRouting: apiKey.Group.ModelRouting, - ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + ModelRouting: apiKey.Group.ModelRouting, + ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, } } return snapshot @@ -270,27 +270,27 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho } if snapshot.Group != nil { apiKey.Group = &Group{ - ID: snapshot.Group.ID, - Name: snapshot.Group.Name, - Platform: snapshot.Group.Platform, - Status: snapshot.Group.Status, - Hydrated: true, - SubscriptionType: snapshot.Group.SubscriptionType, - RateMultiplier: snapshot.Group.RateMultiplier, - DailyLimitUSD: snapshot.Group.DailyLimitUSD, - WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, - MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, - ImagePrice1K: snapshot.Group.ImagePrice1K, - ImagePrice2K: snapshot.Group.ImagePrice2K, - ImagePrice4K: snapshot.Group.ImagePrice4K, - SoraImagePrice360: snapshot.Group.SoraImagePrice360, - SoraImagePrice540: snapshot.Group.SoraImagePrice540, - SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + SoraImagePrice360: snapshot.Group.SoraImagePrice360, + SoraImagePrice540: snapshot.Group.SoraImagePrice540, + SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, - FallbackGroupID: snapshot.Group.FallbackGroupID, - ModelRouting: snapshot.Group.ModelRouting, - ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + ModelRouting: snapshot.Group.ModelRouting, + ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, } } return apiKey diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index f0933ae3..6925801d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3465,7 +3465,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu var cost *CostBreakdown // 根据请求类型选择计费方式 - if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" { + if result.MediaType == "image" || result.MediaType == "video" { var soraConfig *SoraPriceConfig if apiKey.Group != nil { soraConfig = &SoraPriceConfig{ @@ -3480,6 +3480,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } else { cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) } + } else if result.MediaType == "prompt" { + cost = &CostBreakdown{} } else if result.ImageCount > 0 { // 图片生成计费 var groupConfig *ImagePriceConfig diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index bc97e062..e8bf03d4 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -27,9 +27,9 @@ type Group struct { ImagePrice4K *float64 // Sora 按次计费配置(阶段 1) - SoraImagePrice360 *float64 - SoraImagePrice540 *float64 - SoraVideoPricePerRequest *float64 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 SoraVideoPricePerRequestHD *float64 // Claude Code 客户端限制 diff --git a/backend/internal/service/sora2api_service.go b/backend/internal/service/sora2api_service.go index d4bf9ba4..c047cd40 100644 --- a/backend/internal/service/sora2api_service.go +++ b/backend/internal/service/sora2api_service.go @@ -62,7 +62,6 @@ type Sora2APIService struct { adminUsername string adminPassword string adminTokenTTL time.Duration - adminTimeout time.Duration tokenImportMode string client *http.Client @@ -72,9 +71,8 @@ type Sora2APIService struct { adminTokenAt time.Time adminMu sync.Mutex - modelCache []Sora2APIModel - modelCacheAt time.Time - modelMu sync.RWMutex + modelCache []Sora2APIModel + modelMu sync.RWMutex } func NewSora2APIService(cfg *config.Config) *Sora2APIService { @@ -96,7 +94,6 @@ func NewSora2APIService(cfg *config.Config) *Sora2APIService { adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername), adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword), adminTokenTTL: adminTTL, - adminTimeout: adminTimeout, tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)), client: &http.Client{}, adminClient: &http.Client{Timeout: adminTimeout}, @@ -176,7 +173,6 @@ func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, erro s.modelMu.Lock() s.modelCache = models - s.modelCacheAt = time.Now() s.modelMu.Unlock() return models, nil diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 82f4eaaa..2909a76f 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -23,6 +23,8 @@ var soraSSEDataRe = regexp.MustCompile(`^data:\s*`) var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`) var soraVideoHTMLRe = regexp.MustCompile(`(?i)]+src=['"]([^'"]+)['"]`) +const soraRewriteBufferLimit = 2048 + var soraImageSizeMap = map[string]string{ "gpt-image": "360", "gpt-image-landscape": "540", @@ -30,7 +32,6 @@ var soraImageSizeMap = map[string]string{ } type soraStreamingResult struct { - content string mediaType string mediaURLs []string imageCount int @@ -307,6 +308,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * contentBuilder := strings.Builder{} var firstTokenMs *int var upstreamError error + rewriteBuffer := "" scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -333,12 +335,29 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * if soraSSEDataRe.MatchString(line) { data := soraSSEDataRe.ReplaceAllString(line, "") if data == "[DONE]" { + if rewriteBuffer != "" { + flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel) + if err != nil { + return nil, err + } + if flushLine != "" { + if flushContent != "" { + if _, err := contentBuilder.WriteString(flushContent); err != nil { + return nil, err + } + } + if err := sendLine(flushLine); err != nil { + return nil, err + } + } + rewriteBuffer = "" + } if err := sendLine("data: [DONE]"); err != nil { return nil, err } break } - updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel) + updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer) if errEvent != nil && upstreamError == nil { upstreamError = errEvent } @@ -347,7 +366,9 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - contentBuilder.WriteString(contentDelta) + if _, err := contentBuilder.WriteString(contentDelta); err != nil { + return nil, err + } } if err := sendLine(updatedLine); err != nil { return nil, err @@ -417,7 +438,6 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * } return &soraStreamingResult{ - content: content, mediaType: mediaType, mediaURLs: mediaURLs, imageCount: imageCount, @@ -426,7 +446,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * }, nil } -func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) { +func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) { if strings.TrimSpace(data) == "" { return "data: ", "", nil } @@ -448,7 +468,12 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin contentDelta, updated := extractSoraContent(payload) if updated { - rewritten := s.rewriteSoraContent(contentDelta) + var rewritten string + if rewriteBuffer != nil { + rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer) + } else { + rewritten = s.rewriteSoraContent(contentDelta) + } if rewritten != contentDelta { applySoraContent(payload, rewritten) contentDelta = rewritten @@ -504,6 +529,78 @@ func applySoraContent(payload map[string]any, content string) { } } +func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string { + if buffer == nil { + return s.rewriteSoraContent(contentDelta) + } + if contentDelta == "" && *buffer == "" { + return "" + } + combined := *buffer + contentDelta + rewritten := s.rewriteSoraContent(combined) + bufferStart := s.findSoraRewriteBufferStart(rewritten) + if bufferStart < 0 { + *buffer = "" + return rewritten + } + if len(rewritten)-bufferStart > soraRewriteBufferLimit { + bufferStart = len(rewritten) - soraRewriteBufferLimit + } + output := rewritten[:bufferStart] + *buffer = rewritten[bufferStart:] + return output +} + +func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int { + minIndex := -1 + start := 0 + for { + idx := strings.Index(content[start:], "![") + if idx < 0 { + break + } + idx += start + if !hasSoraImageMatchAt(content, idx) { + if minIndex == -1 || idx < minIndex { + minIndex = idx + } + } + start = idx + 2 + } + lower := strings.ToLower(content) + start = 0 + for { + idx := strings.Index(lower[start:], "= len(content) { + return false + } + loc := soraImageMarkdownRe.FindStringIndex(content[idx:]) + return loc != nil && loc[0] == 0 +} + +func hasSoraVideoMatchAt(content string, idx int) bool { + if idx < 0 || idx >= len(content) { + return false + } + loc := soraVideoHTMLRe.FindStringIndex(content[idx:]) + return loc != nil && loc[0] == 0 +} + func (s *SoraGatewayService) rewriteSoraContent(content string) string { if content == "" { return content @@ -533,6 +630,31 @@ func (s *SoraGatewayService) rewriteSoraContent(content string) string { return content } +func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) { + if buffer == "" { + return "", "", nil + } + rewritten := s.rewriteSoraContent(buffer) + payload := map[string]any{ + "choices": []any{ + map[string]any{ + "delta": map[string]any{ + "content": rewritten, + }, + "index": 0, + }, + }, + } + if originalModel != "" { + payload["model"] = originalModel + } + updatedData, err := json.Marshal(payload) + if err != nil { + return "", "", err + } + return "data: " + string(updatedData), rewritten, nil +} + func (s *SoraGatewayService) rewriteSoraURL(raw string) string { raw = strings.TrimSpace(raw) if raw == "" { diff --git a/backend/internal/service/sora_media_sign.go b/backend/internal/service/sora_media_sign.go index 5d4a8d88..26bf8923 100644 --- a/backend/internal/service/sora_media_sign.go +++ b/backend/internal/service/sora_media_sign.go @@ -15,9 +15,15 @@ func SignSoraMediaURL(path string, query string, expires int64, key string) stri return "" } mac := hmac.New(sha256.New, []byte(key)) - mac.Write([]byte(buildSoraMediaSignPayload(path, query))) - mac.Write([]byte("|")) - mac.Write([]byte(strconv.FormatInt(expires, 10))) + if _, err := mac.Write([]byte(buildSoraMediaSignPayload(path, query))); err != nil { + return "" + } + if _, err := mac.Write([]byte("|")); err != nil { + return "" + } + if _, err := mac.Write([]byte(strconv.FormatInt(expires, 10))); err != nil { + return "" + } return hex.EncodeToString(mac.Sum(nil)) } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 167d2b54..435056ab 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -15,11 +15,9 @@ import ( // 定期检查并刷新即将过期的token type TokenRefreshService struct { accountRepo AccountRepository - soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 refreshers []TokenRefresher cfg *config.TokenRefreshConfig cacheInvalidator TokenCacheInvalidator - soraSyncService *Sora2APISyncService stopCh chan struct{} wg sync.WaitGroup @@ -57,7 +55,6 @@ func NewTokenRefreshService( // 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表 // 需要在 Start() 之前调用 func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { - s.soraAccountRepo = repo // 将 soraAccountRepo 注入到 OpenAITokenRefresher for _, refresher := range s.refreshers { if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { @@ -69,7 +66,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { // SetSoraSyncService 设置 Sora2API 同步服务 // 需要在 Start() 之前调用 func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) { - s.soraSyncService = svc for _, refresher := range s.refreshers { if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { openaiRefresher.SetSoraSyncService(svc) diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 9699092d..7e084bd5 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -83,10 +83,10 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m // OpenAITokenRefresher 处理 OpenAI OAuth token刷新 type OpenAITokenRefresher struct { - openaiOAuthService *OpenAIOAuthService - accountRepo AccountRepository - soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 - soraSyncService *Sora2APISyncService // Sora2API 同步服务 + openaiOAuthService *OpenAIOAuthService + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 + soraSyncService *Sora2APISyncService // Sora2API 同步服务 } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 689fa5d7..fb0946d2 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -51,9 +51,7 @@ func ProvideTokenRefreshService( svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) - if soraSyncService != nil { - svc.SetSoraSyncService(soraSyncService) - } + svc.SetSoraSyncService(soraSyncService) svc.Start() return svc } @@ -242,8 +240,6 @@ var ProviderSet = wire.NewSet( NewAntigravityTokenProvider, NewOpenAITokenProvider, NewClaudeTokenProvider, - NewSora2APIService, - NewSora2APISyncService, NewAntigravityGatewayService, ProvideRateLimitService, NewAccountUsageService,