diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index fb6f7d02..492b6b8f 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -24,29 +24,29 @@ func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler { // --- Request / Response types --- type createChannelRequest struct { - Name string `json:"name" binding:"required,max=100"` - Description string `json:"description"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingRequest `json:"model_pricing"` + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` } type updateChannelRequest struct { - Name string `json:"name" binding:"omitempty,max=100"` - Description *string `json:"description"` - Status string `json:"status" binding:"omitempty,oneof=active disabled"` - GroupIDs *[]int64 `json:"group_ids"` - ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` } type channelModelPricingRequest struct { - Models []string `json:"models" binding:"required,min=1,max=100"` - BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` - InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` - OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` - CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` - CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` - ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` - Intervals []pricingIntervalRequest `json:"intervals"` + Models []string `json:"models" binding:"required,min=1,max=100"` + BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` + InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` + OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` + CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` + CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` + ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` + Intervals []pricingIntervalRequest `json:"intervals"` } type pricingIntervalRequest struct { @@ -62,26 +62,26 @@ type pricingIntervalRequest struct { } type channelResponse struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Status string `json:"status"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingResponse `json:"model_pricing"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type channelModelPricingResponse struct { - ID int64 `json:"id"` - Models []string `json:"models"` - BillingMode string `json:"billing_mode"` - InputPrice *float64 `json:"input_price"` - OutputPrice *float64 `json:"output_price"` - CacheWritePrice *float64 `json:"cache_write_price"` - CacheReadPrice *float64 `json:"cache_read_price"` - ImageOutputPrice *float64 `json:"image_output_price"` - Intervals []pricingIntervalResponse `json:"intervals"` + ID int64 `json:"id"` + Models []string `json:"models"` + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + Intervals []pricingIntervalResponse `json:"intervals"` } type pricingIntervalResponse struct { @@ -106,7 +106,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { Name: ch.Name, Description: ch.Description, Status: ch.Status, - GroupIDs: ch.GroupIDs, + GroupIDs: ch.GroupIDs, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 69c8d1d5..7dc062df 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -161,6 +161,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // digestStore nil, // settingService nil, // tlsFPProfileService + nil, // channelService ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index fe035b6f..78e2d24b 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2224,7 +2224,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index c790a36c..18e6e929 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -465,6 +465,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, // digestStore nil, // settingService nil, // tlsFPProfileService + nil, // channelService ) soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index aa8696ab..9259edd6 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -186,7 +186,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati if err != nil { return nil, nil, fmt.Errorf("query channels: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var channels []service.Channel var channelIDs []int64 @@ -240,7 +240,7 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err if err != nil { return nil, fmt.Errorf("query all channels: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var channels []service.Channel var channelIDs []int64 @@ -292,7 +292,7 @@ func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs [] if err != nil { return nil, fmt.Errorf("batch load group ids: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() groupMap := make(map[int64][]int64, len(channelIDs)) for rows.Next() { @@ -333,7 +333,7 @@ func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([ if err != nil { return nil, fmt.Errorf("get group ids: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var ids []int64 for rows.Next() { @@ -375,7 +375,7 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe if err != nil { return nil, fmt.Errorf("get groups in other channels: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var conflicting []int64 for rows.Next() { diff --git a/backend/internal/repository/channel_repo_pricing.go b/backend/internal/repository/channel_repo_pricing.go index 2e7ec6a3..87c856f8 100644 --- a/backend/internal/repository/channel_repo_pricing.go +++ b/backend/internal/repository/channel_repo_pricing.go @@ -21,7 +21,7 @@ func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int6 if err != nil { return nil, fmt.Errorf("list model pricing: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() result, pricingIDs, err := scanModelPricingRows(rows) if err != nil { @@ -97,7 +97,7 @@ func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelID if err != nil { return nil, fmt.Errorf("batch load model pricing: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() allPricing, allPricingIDs, err := scanModelPricingRows(rows) if err != nil { @@ -139,7 +139,7 @@ func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs [ if err != nil { return nil, fmt.Errorf("batch load intervals: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs)) for rows.Next() { diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 58c86f36..7deb1cf9 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -413,12 +413,12 @@ func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageToke type CostInput struct { Ctx context.Context Model string - GroupID *int64 // 用于渠道定价查找 + GroupID *int64 // 用于渠道定价查找 Tokens UsageTokens - RequestCount int // 按次计费时使用 - SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等) + RequestCount int // 按次计费时使用 + SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等) RateMultiplier float64 - ServiceTier string // "priority","flex","" 等 + ServiceTier string // "priority","flex","" 等 Resolver *ModelPricingResolver // 定价解析器 } diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index e3556edd..f408f246 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -10,8 +10,8 @@ type BillingMode string const ( BillingModeToken BillingMode = "token" // 按 token 区间计费 - BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层) - BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费) + BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层) + BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费) ) // IsValid 检查 BillingMode 是否为合法值 @@ -42,13 +42,13 @@ type Channel struct { type ChannelModelPricing struct { ID int64 ChannelID int64 - Models []string // 绑定的模型列表 - BillingMode BillingMode // 计费模式 - InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 - OutputPrice *float64 // 每 token 输出价格(USD) - CacheWritePrice *float64 // 缓存写入价格 - CacheReadPrice *float64 // 缓存读取价格 - ImageOutputPrice *float64 // 图片输出价格(向后兼容) + Models []string // 绑定的模型列表 + BillingMode BillingMode // 计费模式 + InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 + OutputPrice *float64 // 每 token 输出价格(USD) + CacheWritePrice *float64 // 缓存写入价格 + CacheReadPrice *float64 // 缓存读取价格 + ImageOutputPrice *float64 // 图片输出价格(向后兼容) Intervals []PricingInterval // 区间定价列表 CreatedAt time.Time UpdatedAt time.Time