diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index 32c39e25..667a6ed5 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -53,6 +53,9 @@ func (UsageLog) Fields() []ent.Field { MaxLen(100). Optional(). Nillable(), + field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"), + field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"), + field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"), field.Int64("group_id"). Optional(). Nillable(), diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index f47d6791..80c707e6 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -24,31 +24,36 @@ func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler { // --- Request / Response types --- type createChannelRequest struct { - Name string `json:"name" binding:"required,max=100"` - Description string `json:"description"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingRequest `json:"model_pricing"` - ModelMapping map[string]string `json:"model_mapping"` + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` + RestrictModels bool `json:"restrict_models"` } type updateChannelRequest struct { - Name string `json:"name" binding:"omitempty,max=100"` - Description *string `json:"description"` - Status string `json:"status" binding:"omitempty,oneof=active disabled"` - GroupIDs *[]int64 `json:"group_ids"` - ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` - ModelMapping map[string]string `json:"model_mapping"` + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` + RestrictModels *bool `json:"restrict_models"` } type channelModelPricingRequest struct { - Models []string `json:"models" binding:"required,min=1,max=100"` - BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` - InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` - OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` - CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` - CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` - ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` - Intervals []pricingIntervalRequest `json:"intervals"` + Models []string `json:"models" binding:"required,min=1,max=100"` + BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` + InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` + OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` + CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` + CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` + ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` + PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"` + Intervals []pricingIntervalRequest `json:"intervals"` } type pricingIntervalRequest struct { @@ -64,27 +69,30 @@ type pricingIntervalRequest struct { } type channelResponse struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Status string `json:"status"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingResponse `json:"model_pricing"` - ModelMapping map[string]string `json:"model_mapping"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + BillingModelSource string `json:"billing_model_source"` + RestrictModels bool `json:"restrict_models"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + ModelMapping map[string]string `json:"model_mapping"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type channelModelPricingResponse struct { - ID int64 `json:"id"` - Models []string `json:"models"` - BillingMode string `json:"billing_mode"` - InputPrice *float64 `json:"input_price"` - OutputPrice *float64 `json:"output_price"` - CacheWritePrice *float64 `json:"cache_write_price"` - CacheReadPrice *float64 `json:"cache_read_price"` - ImageOutputPrice *float64 `json:"image_output_price"` - Intervals []pricingIntervalResponse `json:"intervals"` + ID int64 `json:"id"` + Models []string `json:"models"` + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + PerRequestPrice *float64 `json:"per_request_price"` + Intervals []pricingIntervalResponse `json:"intervals"` } type pricingIntervalResponse struct { @@ -109,11 +117,16 @@ func channelToResponse(ch *service.Channel) *channelResponse { Name: ch.Name, Description: ch.Description, Status: ch.Status, + RestrictModels: ch.RestrictModels, GroupIDs: ch.GroupIDs, ModelMapping: ch.ModelMapping, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), } + resp.BillingModelSource = ch.BillingModelSource + if resp.BillingModelSource == "" { + resp.BillingModelSource = "requested" + } if resp.GroupIDs == nil { resp.GroupIDs = []int64{} } @@ -155,6 +168,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { CacheWritePrice: p.CacheWritePrice, CacheReadPrice: p.CacheReadPrice, ImageOutputPrice: p.ImageOutputPrice, + PerRequestPrice: p.PerRequestPrice, Intervals: intervals, }) } @@ -190,6 +204,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe CacheWritePrice: r.CacheWritePrice, CacheReadPrice: r.CacheReadPrice, ImageOutputPrice: r.ImageOutputPrice, + PerRequestPrice: r.PerRequestPrice, Intervals: intervals, }) } @@ -249,11 +264,13 @@ func (h *ChannelHandler) Create(c *gin.Context) { } channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ - Name: req.Name, - Description: req.Description, - GroupIDs: req.GroupIDs, - ModelPricing: pricingRequestToService(req.ModelPricing), - ModelMapping: req.ModelMapping, + Name: req.Name, + Description: req.Description, + GroupIDs: req.GroupIDs, + ModelPricing: pricingRequestToService(req.ModelPricing), + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, }) if err != nil { response.ErrorFrom(c, err) @@ -279,11 +296,13 @@ func (h *ChannelHandler) Update(c *gin.Context) { } input := &service.UpdateChannelInput{ - Name: req.Name, - Description: req.Description, - Status: req.Status, - GroupIDs: req.GroupIDs, - ModelMapping: req.ModelMapping, + Name: req.Name, + Description: req.Description, + Status: req.Status, + GroupIDs: req.GroupIDs, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index a8da92c0..cef3a5f8 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -604,6 +604,9 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { return &AdminUsageLog{ UsageLog: usageLogFromServiceUser(l), UpstreamModel: l.UpstreamModel, + ChannelID: l.ChannelID, + ModelMappingChain: l.ModelMappingChain, + BillingTier: l.BillingTier, AccountRateMultiplier: l.AccountRateMultiplier, IPAddress: l.IPAddress, Account: AccountSummaryFromService(l.Account), diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 46984044..392abe2d 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -406,6 +406,13 @@ type AdminUsageLog struct { // Omitted when no mapping was applied (requested model was used as-is). UpstreamModel *string `json:"upstream_model,omitempty"` + // ChannelID 渠道 ID + ChannelID *int64 `json:"channel_id,omitempty"` + // ModelMappingChain 模型映射链,如 "a→b→c" + ModelMappingChain *string `json:"model_mapping_chain,omitempty"` + // BillingTier 计费层级标签(per_request/image 模式) + BillingTier *string `json:"billing_tier,omitempty"` + // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) AccountRateMultiplier *float64 `json:"account_rate_multiplier"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a0d8b2e9..2ad3bb76 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -158,6 +158,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqStream := parsedReq.Stream reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + // 解析渠道级模型映射 + var channelMapping service.ChannelMappingResult + if apiKey.GroupID != nil { + channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) + } + + // 渠道模型限制检查 + if apiKey.GroupID != nil { + checkModel := reqModel + if channelMapping.Mapped { + checkModel = channelMapping.MappedModel + } + if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, checkModel) { + h.errorResponse(c, http.StatusForbidden, "invalid_request_error", "Model not available in current channel: "+reqModel) + return + } + } + // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { @@ -478,6 +496,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { RequestPayloadHash: requestPayloadHash, ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelID: channelMapping.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMapping.BillingModelSource, + ModelMappingChain: func() string { + if !channelMapping.Mapped { + if result.UpstreamModel != "" && result.UpstreamModel != result.Model { + return reqModel + "→" + result.UpstreamModel + } + return "" + } + if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel { + return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel + } + return reqModel + "→" + channelMapping.MappedModel + }(), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -660,6 +693,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { parsedReq.OnUpstreamAccepted = queueRelease // ===== 用户消息串行队列 END ===== + // 应用渠道模型映射到请求 + if channelMapping.Mapped { + parsedReq.Model = channelMapping.MappedModel + parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel) + body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() @@ -810,6 +850,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { RequestPayloadHash: requestPayloadHash, ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelID: channelMapping.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMapping.BillingModelSource, + ModelMappingChain: func() string { + if !channelMapping.Mapped { + if result.UpstreamModel != "" && result.UpstreamModel != result.Model { + return reqModel + "→" + result.UpstreamModel + } + return "" + } + if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel { + return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel + } + return reqModel + "→" + channelMapping.MappedModel + }(), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index eaf25668..6d4008c8 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -42,9 +42,9 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel return err } err = tx.QueryRowContext(ctx, - `INSERT INTO channels (name, description, status, model_mapping) VALUES ($1, $2, $3, $4) + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, created_at, updated_at`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) if err != nil { if isUniqueViolation(err) { @@ -75,9 +75,9 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha ch := &service.Channel{} var modelMappingJSON []byte err := r.db.QueryRowContext(ctx, - `SELECT id, name, description, status, model_mapping, created_at, updated_at + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels WHERE id = $1`, id, - ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt) + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt) if err == sql.ErrNoRows { return nil, service.ErrChannelNotFound } @@ -108,9 +108,9 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel return err } result, err := tx.ExecContext(ctx, - `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, updated_at = NOW() - WHERE id = $5`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.ID, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW() + WHERE id = $7`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID, ) if err != nil { if isUniqueViolation(err) { @@ -187,7 +187,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( - `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.created_at, c.updated_at + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`, whereClause, argIdx, argIdx+1, ) @@ -204,7 +204,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati for rows.Next() { var ch service.Channel var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) @@ -248,7 +248,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, name, description, status, model_mapping, created_at, updated_at FROM channels ORDER BY id`, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`, ) if err != nil { return nil, fmt.Errorf("query all channels: %w", err) @@ -260,7 +260,7 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err for rows.Next() { var ch service.Channel var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) diff --git a/backend/internal/repository/channel_repo_pricing.go b/backend/internal/repository/channel_repo_pricing.go index 87c856f8..ad903ead 100644 --- a/backend/internal/repository/channel_repo_pricing.go +++ b/backend/internal/repository/channel_repo_pricing.go @@ -15,7 +15,7 @@ import ( func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at + `SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID, ) if err != nil { @@ -56,10 +56,10 @@ func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *ser } result, err := r.db.ExecContext(ctx, `UPDATE channel_model_pricing - SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, updated_at = NOW() - WHERE id = $8`, + SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, updated_at = NOW() + WHERE id = $9`, modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, - pricing.ImageOutputPrice, pricing.ID, + pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.ID, ) if err != nil { return fmt.Errorf("update model pricing: %w", err) @@ -90,7 +90,7 @@ func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID i // batchLoadModelPricing 批量加载多个渠道的模型定价(含区间) func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at + `SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`, pq.Array(channelIDs), ) @@ -171,7 +171,7 @@ func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int6 if err := rows.Scan( &p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode, &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice, - &p.ImageOutputPrice, &p.CreatedAt, &p.UpdatedAt, + &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt, ); err != nil { return nil, nil, fmt.Errorf("scan model pricing: %w", err) } @@ -224,11 +224,11 @@ func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.C billingMode = service.BillingModeToken } err = exec.QueryRowContext(ctx, - `INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at`, + `INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`, pricing.ChannelID, modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, - pricing.ImageOutputPrice, + pricing.ImageOutputPrice, pricing.PerRequestPrice, ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt) if err != nil { return fmt.Errorf("insert model pricing: %w", err) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index e4da825b..c40beffd 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, created_at" // usageLogInsertArgTypes must stay in the same order as: // 1. prepareUsageLogInsert().args @@ -77,6 +77,9 @@ var usageLogInsertArgTypes = [...]string{ "text", // inbound_endpoint "text", // upstream_endpoint "boolean", // cache_ttl_overridden + "bigint", // channel_id + "text", // model_mapping_chain + "text", // billing_tier "timestamptz", // created_at } @@ -350,6 +353,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, @@ -357,7 +363,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, - $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 + $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -782,10 +788,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*39) + args := make([]any, 0, len(keys)*44) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -853,6 +862,9 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, created_at ) SELECT @@ -895,6 +907,9 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -977,10 +992,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*40) + args := make([]any, 0, len(preparedList)*43) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -1045,6 +1063,9 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, created_at ) SELECT @@ -1087,6 +1108,9 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -1137,6 +1161,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, @@ -1144,7 +1171,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, - $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 + $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1176,6 +1203,9 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + channelID := nullInt64(log.ChannelID) + modelMappingChain := nullString(log.ModelMappingChain) + billingTier := nullString(log.BillingTier) requestedModel := strings.TrimSpace(log.RequestedModel) if requestedModel == "" { requestedModel = strings.TrimSpace(log.Model) @@ -1232,6 +1262,9 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { inboundEndpoint, upstreamEndpoint, log.CacheTTLOverridden, + channelID, + modelMappingChain, + billingTier, createdAt, }, } @@ -3959,6 +3992,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e inboundEndpoint sql.NullString upstreamEndpoint sql.NullString cacheTTLOverridden bool + channelID sql.NullInt64 + modelMappingChain sql.NullString + billingTier sql.NullString createdAt time.Time ) @@ -4003,6 +4039,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &inboundEndpoint, &upstreamEndpoint, &cacheTTLOverridden, + &channelID, + &modelMappingChain, + &billingTier, &createdAt, ); err != nil { return nil, err @@ -4087,6 +4126,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if upstreamModel.Valid { log.UpstreamModel = &upstreamModel.String } + if channelID.Valid { + value := channelID.Int64 + log.ChannelID = &value + } + if modelMappingChain.Valid { + log.ModelMappingChain = &modelMappingChain.String + } + if billingTier.Valid { + log.BillingTier = &billingTier.String + } return log, nil } diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 7b43b18b..ccb5473f 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -23,14 +23,21 @@ func (m BillingMode) IsValid() bool { return false } +const ( + BillingModelSourceRequested = "requested" + BillingModelSourceUpstream = "upstream" +) + // Channel 渠道实体 type Channel struct { - ID int64 - Name string - Description string - Status string - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Name string + Description string + Status string + BillingModelSource string // "requested" or "upstream" + RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) + CreatedAt time.Time + UpdatedAt time.Time // 关联的分组 ID 列表 GroupIDs []int64 @@ -44,13 +51,14 @@ type Channel struct { type ChannelModelPricing struct { ID int64 ChannelID int64 - Models []string // 绑定的模型列表 - BillingMode BillingMode // 计费模式 - InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 - OutputPrice *float64 // 每 token 输出价格(USD) - CacheWritePrice *float64 // 缓存写入价格 - CacheReadPrice *float64 // 缓存读取价格 - ImageOutputPrice *float64 // 图片输出价格(向后兼容) + Models []string // 绑定的模型列表 + BillingMode BillingMode // 计费模式 + InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 + OutputPrice *float64 // 每 token 输出价格(USD) + CacheWritePrice *float64 // 缓存写入价格 + CacheReadPrice *float64 // 缓存读取价格 + ImageOutputPrice *float64 // 图片输出价格(向后兼容) + PerRequestPrice *float64 // 默认按次计费价格(USD) Intervals []PricingInterval // 区间定价列表 CreatedAt time.Time UpdatedAt time.Time @@ -106,12 +114,10 @@ func (c *Channel) IsActive() bool { } // GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。 -// 优先精确匹配,然后通配符匹配(如 claude-opus-*)。大小写不敏感。 -// 返回值拷贝,不污染缓存。 +// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。 func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { modelLower := strings.ToLower(model) - // 第一轮:精确匹配 for i := range c.ModelPricing { for _, m := range c.ModelPricing[i].Models { if strings.ToLower(m) == modelLower { @@ -121,20 +127,6 @@ func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { } } - // 第二轮:通配符匹配(仅支持末尾 *) - for i := range c.ModelPricing { - for _, m := range c.ModelPricing[i].Models { - mLower := strings.ToLower(m) - if strings.HasSuffix(mLower, "*") { - prefix := strings.TrimSuffix(mLower, "*") - if strings.HasPrefix(modelLower, prefix) { - cp := c.ModelPricing[i].Clone() - return &cp - } - } - } - } - return nil } diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index adf1a64f..91b69ad2 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -47,13 +47,30 @@ type ChannelRepository interface { ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error } -// channelCache 渠道缓存快照 +// channelModelKey 渠道缓存复合键 +type channelModelKey struct { + groupID int64 + model string // lowercase +} + +// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找) type channelCache struct { - // byID: channelID -> *Channel(含 ModelPricing) - byID map[int64]*Channel - // byGroupID: groupID -> channelID - byGroupID map[int64]int64 - loadedAt time.Time + // 热路径查找 + pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, model) → 定价 + mappingByGroupModel map[channelModelKey]string // (groupID, model) → 映射目标 + channelByGroupID map[int64]*Channel // groupID → 渠道 + + // 冷路径(CRUD 操作) + byID map[int64]*Channel + loadedAt time.Time +} + +// ChannelMappingResult 渠道映射查找结果 +type ChannelMappingResult struct { + MappedModel string // 映射后的模型名(无映射时等于原始模型名) + ChannelID int64 // 渠道 ID(0 = 无渠道关联) + Mapped bool // 是否发生了映射 + BillingModelSource string // 计费模型来源("requested" / "upstream") } const ( @@ -115,25 +132,46 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 slog.Warn("failed to build channel cache", "error", err) errorCache := &channelCache{ - byID: make(map[int64]*Channel), - byGroupID: make(map[int64]int64), - loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + mappingByGroupModel: make(map[channelModelKey]string), + channelByGroupID: make(map[int64]*Channel), + byID: make(map[int64]*Channel), + loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL } s.cache.Store(errorCache) return nil, fmt.Errorf("list all channels: %w", err) } cache := &channelCache{ - byID: make(map[int64]*Channel, len(channels)), - byGroupID: make(map[int64]int64), - loadedAt: time.Now(), + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + mappingByGroupModel: make(map[channelModelKey]string), + channelByGroupID: make(map[int64]*Channel), + byID: make(map[int64]*Channel, len(channels)), + loadedAt: time.Now(), } for i := range channels { ch := &channels[i] cache.byID[ch.ID] = ch + + // 展开到分组维度 for _, gid := range ch.GroupIDs { - cache.byGroupID[gid] = ch.ID + cache.channelByGroupID[gid] = ch + + // 展开模型定价到 (groupID, model) → *ChannelModelPricing + for j := range ch.ModelPricing { + pricing := &ch.ModelPricing[j] + for _, model := range pricing.Models { + key := channelModelKey{groupID: gid, model: strings.ToLower(model)} + cache.pricingByGroupModel[key] = pricing + } + } + + // 展开模型映射到 (groupID, model) → target + for src, dst := range ch.ModelMapping { + key := channelModelKey{groupID: gid, model: strings.ToLower(src)} + cache.mappingByGroupModel[key] = dst + } } } @@ -147,42 +185,94 @@ func (s *ChannelService) invalidateCache() { s.cacheSF.Forget("channel_cache") } -// GetChannelForGroup 获取分组关联的渠道(热路径,从缓存读取) -// 返回深拷贝,不污染缓存。 +// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)) func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { cache, err := s.loadCache(ctx) if err != nil { return nil, err } - channelID, ok := cache.byGroupID[groupID] - if !ok { - return nil, nil - } - - ch, ok := cache.byID[channelID] - if !ok { - return nil, nil - } - - if !ch.IsActive() { + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() { return nil, nil } return ch.Clone(), nil } -// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径) +// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1)) func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { - ch, err := s.GetChannelForGroup(ctx, groupID) + cache, err := s.loadCache(ctx) if err != nil { - slog.Warn("failed to get channel for group", "group_id", groupID, "error", err) + slog.Warn("failed to load channel cache", "group_id", groupID, "error", err) return nil } - if ch == nil { + + // 检查渠道是否启用 + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() { return nil } - return ch.GetModelPricing(model) + + key := channelModelKey{groupID: groupID, model: strings.ToLower(model)} + pricing, ok := cache.pricingByGroupModel[key] + if !ok { + return nil + } + + cp := pricing.Clone() + return &cp +} + +// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1)) +// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。 +func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + cache, err := s.loadCache(ctx) + if err != nil { + return ChannelMappingResult{MappedModel: model} + } + + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() { + return ChannelMappingResult{MappedModel: model} + } + + result := ChannelMappingResult{ + MappedModel: model, + ChannelID: ch.ID, + BillingModelSource: ch.BillingModelSource, + } + if result.BillingModelSource == "" { + result.BillingModelSource = BillingModelSourceRequested + } + + key := channelModelKey{groupID: groupID, model: strings.ToLower(model)} + if mapped, ok := cache.mappingByGroupModel[key]; ok { + result.MappedModel = mapped + result.Mapped = true + } + + return result +} + +// IsModelRestricted 检查模型是否被渠道限制。 +// 返回 true 表示模型被限制(不在允许列表中)。 +// 如果渠道未启用模型限制或分组无渠道关联,返回 false。 +func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + cache, err := s.loadCache(ctx) + if err != nil { + return false // 缓存加载失败时不限制 + } + + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() || !ch.RestrictModels { + return false + } + + // 检查模型是否在定价列表中 + key := channelModelKey{groupID: groupID, model: strings.ToLower(model)} + _, exists := cache.pricingByGroupModel[key] + return !exists } // --- CRUD --- @@ -209,12 +299,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) } channel := &Channel{ - Name: input.Name, - Description: input.Description, - Status: StatusActive, - GroupIDs: input.GroupIDs, - ModelPricing: input.ModelPricing, - ModelMapping: input.ModelMapping, + Name: input.Name, + Description: input.Description, + Status: StatusActive, + BillingModelSource: input.BillingModelSource, + RestrictModels: input.RestrictModels, + GroupIDs: input.GroupIDs, + ModelPricing: input.ModelPricing, + ModelMapping: input.ModelMapping, + } + if channel.BillingModelSource == "" { + channel.BillingModelSource = BillingModelSourceRequested } if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { @@ -260,6 +355,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan channel.Status = input.Status } + if input.RestrictModels != nil { + channel.RestrictModels = *input.RestrictModels + } + // 检查分组冲突 if input.GroupIDs != nil { conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) @@ -280,6 +379,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan channel.ModelMapping = input.ModelMapping } + if input.BillingModelSource != "" { + channel.BillingModelSource = input.BillingModelSource + } + if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { return nil, err } @@ -351,19 +454,23 @@ func validateNoDuplicateModels(pricingList []ChannelModelPricing) error { // CreateChannelInput 创建渠道输入 type CreateChannelInput struct { - Name string - Description string - GroupIDs []int64 - ModelPricing []ChannelModelPricing - ModelMapping map[string]string + Name string + Description string + GroupIDs []int64 + ModelPricing []ChannelModelPricing + ModelMapping map[string]string + BillingModelSource string + RestrictModels bool } // UpdateChannelInput 更新渠道输入 type UpdateChannelInput struct { - Name string - Description *string - Status string - GroupIDs *[]int64 - ModelPricing *[]ChannelModelPricing - ModelMapping map[string]string + Name string + Description *string + Status string + GroupIDs *[]int64 + ModelPricing *[]ChannelModelPricing + ModelMapping map[string]string + BillingModelSource string + RestrictModels *bool } diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go index 0c055ce4..46cf193a 100644 --- a/backend/internal/service/channel_test.go +++ b/backend/internal/service/channel_test.go @@ -15,7 +15,6 @@ func TestGetModelPricing(t *testing.T) { ch := &Channel{ ModelPricing: []ChannelModelPricing{ {ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)}, - {ID: 2, Models: []string{"claude-*"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(5e-6)}, {ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest}, }, } @@ -28,9 +27,8 @@ func TestGetModelPricing(t *testing.T) { }{ {"exact match", "claude-sonnet-4", 1, false}, {"case insensitive", "Claude-Sonnet-4", 1, false}, - {"wildcard match", "claude-opus-4-20250514", 2, false}, - {"exact takes priority over wildcard", "claude-sonnet-4", 1, false}, {"not found", "gemini-3.1-pro", 0, true}, + {"wildcard pattern not matched", "claude-opus-4-20250514", 0, true}, {"per_request model", "gpt-5.1", 3, false}, } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 69d218f8..89bcb75c 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7413,6 +7413,12 @@ type RecordUsageInput struct { RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + + // 渠道映射信息(由 handler 在 Forward 前解析) + ChannelID int64 // 渠道 ID(0 = 无渠道) + OriginalModel string // 用户原始请求模型(渠道映射前) + BillingModelSource string // 计费模型来源:"requested" / "upstream" + ModelMappingChain string // 映射链描述,如 "a→b→c" } // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage @@ -7732,7 +7738,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } var cost *CostBreakdown + // 确定计费模型 billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { + billingModel = input.OriginalModel + } + + // 确定 RequestedModel(渠道映射前的原始模型) + requestedModel := result.Model + if input.OriginalModel != "" { + requestedModel = input.OriginalModel + } // 根据请求类型选择计费方式 if result.MediaType == "image" || result.MediaType == "video" { @@ -7815,7 +7831,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu AccountID: account.ID, RequestID: requestID, Model: result.Model, - RequestedModel: result.Model, + RequestedModel: requestedModel, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -7842,6 +7858,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ImageSize: imageSize, MediaType: mediaType, CacheTTLOverridden: cacheTTLOverridden, + ChannelID: optionalInt64Ptr(input.ChannelID), + ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), CreatedAt: time.Now(), } @@ -7909,6 +7927,12 @@ type RecordUsageLongContextInput struct { LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) + + // 渠道映射信息(由 handler 在 Forward 前解析) + ChannelID int64 // 渠道 ID(0 = 无渠道) + OriginalModel string // 用户原始请求模型(渠道映射前) + BillingModelSource string // 计费模型来源:"requested" / "upstream" + ModelMappingChain string // 映射链描述,如 "a→b→c" } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) @@ -7946,7 +7970,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } var cost *CostBreakdown + // 确定计费模型 billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { + billingModel = input.OriginalModel + } + + // 确定 RequestedModel(渠道映射前的原始模型) + requestedModel := result.Model + if input.OriginalModel != "" { + requestedModel = input.OriginalModel + } // 根据请求类型选择计费方式 if result.ImageCount > 0 { @@ -8008,7 +8042,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * AccountID: account.ID, RequestID: requestID, Model: result.Model, - RequestedModel: result.Model, + RequestedModel: requestedModel, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -8034,6 +8068,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ImageCount: result.ImageCount, ImageSize: imageSize, CacheTTLOverridden: cacheTTLOverridden, + ChannelID: optionalInt64Ptr(input.ChannelID), + ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), CreatedAt: time.Now(), } @@ -8085,6 +8121,27 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * return nil } +// ResolveChannelMapping 委托渠道服务解析模型映射 +func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model} + } + return s.channelService.ResolveChannelMapping(ctx, groupID, model) +} + +// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用) +func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { + return s.replaceModelInBody(body, newModel) +} + +// IsModelRestricted 检查模型是否被渠道限制 +func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + if s.channelService == nil { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, model) +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 576841fa..da5773f5 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -104,6 +104,12 @@ type UsageLog struct { // UpstreamModel is the actual model sent to the upstream provider after mapping. // Nil means no mapping was applied (requested model was used as-is). UpstreamModel *string + // ChannelID 渠道 ID + ChannelID *int64 + // ModelMappingChain 模型映射链,如 "a→b→c" + ModelMappingChain *string + // BillingTier 计费层级标签(per_request/image 模式) + BillingTier *string // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". ServiceTier *string // ReasoningEffort is the request's reasoning effort level. diff --git a/backend/internal/service/usage_log_helpers.go b/backend/internal/service/usage_log_helpers.go index a7bcae99..7cc8a713 100644 --- a/backend/internal/service/usage_log_helpers.go +++ b/backend/internal/service/usage_log_helpers.go @@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string { } return strings.TrimSpace(upstreamModel) } + +func optionalInt64Ptr(v int64) *int64 { + if v == 0 { + return nil + } + return &v +} diff --git a/backend/migrations/084_channel_billing_model_source.sql b/backend/migrations/084_channel_billing_model_source.sql new file mode 100644 index 00000000..bd615bac --- /dev/null +++ b/backend/migrations/084_channel_billing_model_source.sql @@ -0,0 +1,7 @@ +-- Add billing_model_source to channels (controls whether billing uses requested or upstream model) +ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested'; + +-- Add channel tracking fields to usage_logs +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT; +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500); +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50); diff --git a/backend/migrations/085_channel_restrict_and_per_request_price.sql b/backend/migrations/085_channel_restrict_and_per_request_price.sql new file mode 100644 index 00000000..2f494c63 --- /dev/null +++ b/backend/migrations/085_channel_restrict_and_per_request_price.sql @@ -0,0 +1,5 @@ +-- Add model restriction switch to channels +ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false; + +-- Add default per_request_price to channel_model_pricing (fallback when no tier matches) +ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10); diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts index 0b86fcaa..23244a4f 100644 --- a/frontend/src/api/admin/channels.ts +++ b/frontend/src/api/admin/channels.ts @@ -29,6 +29,7 @@ export interface ChannelModelPricing { cache_write_price: number | null cache_read_price: number | null image_output_price: number | null + per_request_price: number | null intervals: PricingInterval[] } @@ -37,8 +38,11 @@ export interface Channel { name: string description: string status: string + billing_model_source: string // "requested" | "upstream" + restrict_models: boolean group_ids: number[] model_pricing: ChannelModelPricing[] + model_mapping: Record created_at: string updated_at: string } @@ -48,6 +52,9 @@ export interface CreateChannelRequest { description?: string group_ids?: number[] model_pricing?: ChannelModelPricing[] + model_mapping?: Record + billing_model_source?: string + restrict_models?: boolean } export interface UpdateChannelRequest { @@ -56,6 +63,9 @@ export interface UpdateChannelRequest { status?: string group_ids?: number[] model_pricing?: ChannelModelPricing[] + model_mapping?: Record + billing_model_source?: string + restrict_models?: boolean } interface PaginatedResponse { diff --git a/frontend/src/components/admin/channel/ModelTagInput.vue b/frontend/src/components/admin/channel/ModelTagInput.vue index a9ccf630..d9002748 100644 --- a/frontend/src/components/admin/channel/ModelTagInput.vue +++ b/frontend/src/components/admin/channel/ModelTagInput.vue @@ -29,7 +29,7 @@ />

- {{ t('admin.channels.form.modelInputHint', 'Press Enter to add. Supports wildcard *.') }} + {{ t('admin.channels.form.modelInputHint', 'Press Enter to add, supports paste for batch import.') }}

diff --git a/frontend/src/components/admin/channel/PricingEntryCard.vue b/frontend/src/components/admin/channel/PricingEntryCard.vue index b0238a19..6e97676a 100644 --- a/frontend/src/components/admin/channel/PricingEntryCard.vue +++ b/frontend/src/components/admin/channel/PricingEntryCard.vue @@ -70,7 +70,7 @@
+ + +
+ +
+ +
- +
+ + +
+ +
+ +
-
-
-
- - -
-
-
diff --git a/frontend/src/components/admin/channel/types.ts b/frontend/src/components/admin/channel/types.ts index 550a3cbd..3ddc94e7 100644 --- a/frontend/src/components/admin/channel/types.ts +++ b/frontend/src/components/admin/channel/types.ts @@ -20,6 +20,7 @@ export interface PricingFormEntry { cache_write_price: number | string | null cache_read_price: number | string | null image_output_price: number | string | null + per_request_price: number | string | null intervals: IntervalFormEntry[] } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 679f8290..18cd812d 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1770,8 +1770,8 @@ export default { inOtherChannel: 'In "{name}"', modelPricing: 'Model Pricing', models: 'Models', - modelsPlaceholder: 'Type model name and press Enter. Supports wildcard *', - modelInputHint: 'Press Enter to add. Supports paste and wildcard *.', + modelsPlaceholder: 'Type full model name and press Enter', + modelInputHint: 'Press Enter to add, supports paste for batch import.', billingMode: 'Billing Mode', defaultPrices: 'Default prices (fallback when no interval matches)', inputPrice: 'Input', @@ -1790,7 +1790,23 @@ export default { noPricingRules: 'No pricing rules yet. Click "Add" to create one.', perRequestPrice: 'Price per Request', tierLabel: 'Tier', - resolution: 'Resolution' + resolution: 'Resolution', + modelMapping: 'Model Mapping', + modelMappingHint: 'Map request model names to actual model names. Runs before account-level mapping.', + noMappingRules: 'No mapping rules. Click "Add" to create one.', + mappingSource: 'Source model', + mappingTarget: 'Target model', + billingModelSource: 'Billing Model', + billingModelSourceRequested: 'Bill by requested model', + billingModelSourceUpstream: 'Bill by final upstream model', + billingModelSourceHint: 'Controls which model name is used for pricing lookup', + selectedCount: '{count} selected', + searchGroups: 'Search groups...', + noGroupsMatch: 'No groups match your search', + restrictModels: 'Restrict Models', + restrictModelsHint: 'When enabled, only models in the pricing list are allowed. Others will be rejected.', + defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)', + defaultImagePrice: 'Default image price (fallback when no tier matches)' } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index dfb859a5..f410fa4b 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1850,8 +1850,8 @@ export default { inOtherChannel: '已属于「{name}」', modelPricing: '模型定价', models: '模型列表', - modelsPlaceholder: '输入模型名后按回车添加,支持通配符 *', - modelInputHint: '按回车添加,支持粘贴批量导入,支持通配符 *', + modelsPlaceholder: '输入完整模型名后按回车添加', + modelInputHint: '按回车添加,支持粘贴批量导入', billingMode: '计费模式', defaultPrices: '默认价格(未命中区间时使用)', inputPrice: '输入', @@ -1870,7 +1870,23 @@ export default { noPricingRules: '暂无定价规则,点击"添加"创建', perRequestPrice: '单次价格', tierLabel: '层级', - resolution: '分辨率' + resolution: '分辨率', + modelMapping: '模型映射', + modelMappingHint: '将请求中的模型名映射为实际模型名。在账号级别映射之前执行。', + noMappingRules: '暂无映射规则,点击"添加"创建', + mappingSource: '源模型', + mappingTarget: '目标模型', + billingModelSource: '计费模型', + billingModelSourceRequested: '以请求模型计费', + billingModelSourceUpstream: '以最终模型计费', + billingModelSourceHint: '控制使用哪个模型名称进行定价查找', + selectedCount: '已选 {count} 个', + searchGroups: '搜索分组...', + noGroupsMatch: '没有匹配的分组', + restrictModels: '限制模型', + restrictModelsHint: '开启后,仅允许模型定价列表中的模型。不在列表中的模型请求将被拒绝。', + defaultPerRequestPrice: '默认单次价格(未命中层级时使用)', + defaultImagePrice: '默认图片价格(未命中层级时使用)' } }, diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 9af183fa..579d7578 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -146,7 +146,7 @@
- +
+ +
+ +

+ {{ t('admin.channels.form.restrictModelsHint', 'When enabled, only models in the pricing list are allowed. Others will be rejected.') }} +

+
+
- +
{{ t('common.loading', 'Loading...') }} @@ -198,38 +218,61 @@
{{ t('admin.channels.form.noGroupsAvailable', 'No groups available') }}
-
+
{{ t('admin.channels.form.noGroupsMatch', 'No groups match your search') }}
-
- +
+ + +
+ + + + + +
+
+