From eb385457b2e878a63457d0739e54bb8cfecc70d3 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 31 Mar 2026 15:26:20 +0800 Subject: [PATCH] =?UTF-8?q?fix(channel):=20=E5=85=A8=E5=B9=B3=E5=8F=B0?= =?UTF-8?q?=E6=B8=A0=E9=81=93=E6=98=A0=E5=B0=84=E8=A6=86=E7=9B=96=20+=20?= =?UTF-8?q?=E5=85=AC=E5=85=B1=E5=87=BD=E6=95=B0=E6=8A=BD=E5=8F=96=20+=20?= =?UTF-8?q?=E6=AD=BB=E4=BB=A3=E7=A0=81=E6=B8=85=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 4个缺失handler入口添加渠道映射+限制检查(ChatCompletions/Responses/Gemini) - 模型限制错误信息优化,区分"模型不可用"和"无账号" - OpenAI RecordUsage RequestedModel 改用 OriginalModel - ResolveChannelMappingAndRestrict/ReplaceModelInBody 抽取到 ChannelService 消除跨service重复 - validateNoDuplicateModels 按 platform:model 去重 - 删除 Channel.ResolveMappedModel 死代码和 CalculateCostWithChannel Deprecated方法 - 移除冗余nil检查,抽取 validatePricingBillingMode 公共校验 --- .../internal/handler/admin/channel_handler.go | 33 ++++++++------- backend/internal/handler/gateway_handler.go | 2 +- .../gateway_handler_chat_completions.go | 17 +++++++- .../handler/gateway_handler_responses.go | 17 +++++++- .../internal/handler/gemini_v1beta_handler.go | 15 +++++++ .../handler/openai_chat_completions.go | 37 ++++++++++++----- .../handler/openai_gateway_handler.go | 4 +- backend/internal/service/billing_service.go | 6 --- backend/internal/service/channel.go | 32 --------------- backend/internal/service/channel_service.go | 40 ++++++++++++++++--- backend/internal/service/gateway_service.go | 31 ++++---------- .../service/openai_gateway_service.go | 31 ++++++-------- 12 files changed, 149 insertions(+), 116 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index e5a3eac5..6c460d8e 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -1,6 +1,7 @@ package admin import ( + "errors" "strconv" "strings" @@ -224,6 +225,18 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe return result } +// validatePricingBillingMode 校验按次/图片计费模式必须配置 PerRequestPrice 或 Intervals +func validatePricingBillingMode(pricing []service.ChannelModelPricing) error { + for _, p := range pricing { + if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { + if p.PerRequestPrice == nil && len(p.Intervals) == 0 { + return errors.New("Per-request price or intervals required for per_request/image billing mode") + } + } + } + return nil +} + // --- Handlers --- // List handles listing channels with pagination @@ -277,13 +290,9 @@ func (h *ChannelHandler) Create(c *gin.Context) { } pricing := pricingRequestToService(req.ModelPricing) - for _, p := range pricing { - if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { - if p.PerRequestPrice == nil && len(p.Intervals) == 0 { - response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode") - return - } - } + if err := validatePricingBillingMode(pricing); err != nil { + response.BadRequest(c, err.Error()) + return } channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ @@ -329,13 +338,9 @@ func (h *ChannelHandler) Update(c *gin.Context) { } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) - for _, p := range pricing { - if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { - if p.PerRequestPrice == nil && len(p.Intervals) == 0 { - response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode") - return - } - } + if err := validatePricingBillingMode(pricing); err != nil { + response.BadRequest(c, err.Error()) + return } input.ModelPricing = &pricing } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 89a791fd..8f66ad03 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -161,7 +161,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 解析渠道级模型映射 + 限制检查 channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) if restricted { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") return } diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index da376036..f0f16131 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -80,6 +80,13 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + 限制检查 + channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + if restricted { + h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") + return + } + // Claude Code only restriction if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error", @@ -203,7 +210,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { // 5. Forward request writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) if accountReleaseFunc != nil { accountReleaseFunc() @@ -255,6 +266,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelID: channelMapping.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMapping.BillingModelSource, + ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("gateway.cc.record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index d146d724..1e9cdc02 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -80,6 +80,13 @@ func (h *GatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + 限制检查 + channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + if restricted { + h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") + return + } + // Claude Code only restriction: // /v1/responses is never a Claude Code endpoint. // When claude_code_only is enabled, this endpoint is rejected. @@ -208,7 +215,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { // 5. Forward request writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq) if accountReleaseFunc != nil { accountReleaseFunc() @@ -261,6 +272,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelID: channelMapping.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMapping.BillingModelSource, + ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("gateway.responses.record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 524c6b6d..7c1386b8 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -184,6 +184,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) + // 解析渠道级模型映射 + 限制检查 + channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) + if restricted { + googleError(c, http.StatusServiceUnavailable, "The requested model is not available for this API key") + return + } + reqModel := modelName // 保存映射前的原始模型名 + if channelMapping.Mapped { + modelName = channelMapping.MappedModel + } + // Get subscription (may be nil) subscription, _ := middleware.GetSubscriptionFromContext(c) @@ -523,6 +534,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { LongContextMultiplier: 2.0, // 超出部分双倍计费 ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelID: channelMapping.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMapping.BillingModelSource, + ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gemini_v1beta.models"), diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 0c94aa21..a117c3be 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -79,6 +79,13 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + 限制检查 + channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + if restricted { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") + return + } + if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) } @@ -183,7 +190,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { forwardStart := time.Now() defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -257,16 +268,20 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + ChannelID: channelMapping.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMapping.BillingModelSource, + ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.chat_completions"), diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 17f2fe82..70198a53 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -188,7 +188,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 解析渠道级模型映射 + 限制检查 channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) if restricted { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") return } @@ -568,7 +568,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { // 解析渠道级模型映射 + 限制检查 channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) if restricted { - h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") return } diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 458788fd..93cefd9a 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -402,12 +402,6 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing return pricing, nil } -// CalculateCostWithChannel 使用渠道定价计算费用 -// Deprecated: 使用 CalculateCostUnified 代替 -func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageTokens, rateMultiplier float64, channelPricing *ChannelModelPricing) (*CostBreakdown, error) { - return s.calculateCostInternal(model, tokens, rateMultiplier, "", channelPricing) -} - // --- 统一计费入口 --- // CostInput 统一计费输入 diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index bc13d642..40a137c1 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -82,38 +82,6 @@ type PricingInterval struct { UpdatedAt time.Time } -// ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。 -// platform 指定查找哪个平台的映射规则。 -// 支持通配符(如 "claude-*" → "claude-sonnet-4")。 -// 如果没有匹配的映射规则,返回原始模型名。 -func (c *Channel) ResolveMappedModel(platform, requestedModel string) string { - if len(c.ModelMapping) == 0 { - return requestedModel - } - platformMapping, ok := c.ModelMapping[platform] - if !ok || len(platformMapping) == 0 { - return requestedModel - } - lower := strings.ToLower(requestedModel) - // 精确匹配优先 - for src, dst := range platformMapping { - if strings.ToLower(src) == lower { - return dst - } - } - // 通配符匹配 - for src, dst := range platformMapping { - srcLower := strings.ToLower(src) - if strings.HasSuffix(srcLower, "*") { - prefix := strings.TrimSuffix(srcLower, "*") - if strings.HasPrefix(lower, prefix) { - return dst - } - } - } - return requestedModel -} - // IsActive 判断渠道是否启用 func (c *Channel) IsActive() bool { return c.Status == StatusActive diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index fb75bafc..6dbc2624 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -11,6 +11,8 @@ import ( infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" "golang.org/x/sync/singleflight" ) @@ -379,6 +381,34 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m return true } +// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。 +// 返回映射结果和是否被限制。groupID 为 nil 时跳过。 +func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + var mapping ChannelMappingResult + mapping.MappedModel = model + if groupID == nil { + return mapping, false + } + mapping = s.ResolveChannelMapping(ctx, *groupID, model) + restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel) + return mapping, restricted +} + +// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。 +func ReplaceModelInBody(body []byte, newModel string) []byte { + if len(body) == 0 { + return body + } + if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { + return body + } + newBody, err := sjson.SetBytes(body, "model", newModel) + if err != nil { + return body + } + return newBody +} + // --- CRUD --- // Create 创建渠道 @@ -539,16 +569,16 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP return s.repo.List(ctx, params, status, search) } -// validateNoDuplicateModels 检查定价列表中是否有重复模型 +// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复) func validateNoDuplicateModels(pricingList []ChannelModelPricing) error { seen := make(map[string]bool) for _, p := range pricingList { for _, model := range p.Models { - lower := strings.ToLower(model) - if seen[lower] { - return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries", model)) + key := p.Platform + ":" + strings.ToLower(model) + if seen[key] { + return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform)) } - seen[lower] = true + seen[key] = true } } return nil diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 140e7202..5a866d63 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -872,17 +872,7 @@ type anthropicMetadataPayload struct { // replaceModelInBody 替换请求体中的model字段 // 优先使用定点修改,尽量保持客户端原始字段顺序。 func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { - if len(body) == 0 { - return body - } - if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { - return body - } - newBody, err := sjson.SetBytes(body, "model", newModel) - if err != nil { - return body - } - return newBody + return ReplaceModelInBody(body, newModel) } type claudeOAuthNormalizeOptions struct { @@ -7794,11 +7784,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } var err error if s.resolver != nil && apiKey.Group != nil { - var groupID *int64 - if apiKey.Group != nil { - gid := apiKey.Group.ID - groupID = &gid - } + gid := apiKey.Group.ID + groupID := &gid cost, err = s.billingService.CalculateCostUnified(CostInput{ Ctx: ctx, Model: billingModel, @@ -8184,7 +8171,7 @@ func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int6 // ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用) func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { - return s.replaceModelInBody(body, newModel) + return ReplaceModelInBody(body, newModel) } // IsModelRestricted 检查模型是否被渠道限制 @@ -8198,14 +8185,10 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m // ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 // 返回映射结果和是否被限制。 func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { - var mapping ChannelMappingResult - mapping.MappedModel = model - if groupID == nil { - return mapping, false + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model}, false } - mapping = s.ResolveChannelMapping(ctx, *groupID, model) - restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel) - return mapping, restricted + return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) } // ForwardCountTokens 转发 count_tokens 请求到上游 API diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 66f492a5..f68562f8 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -416,29 +416,15 @@ func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID in // ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 // 返回映射结果和是否被限制。 func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { - var mapping ChannelMappingResult - mapping.MappedModel = model - if groupID == nil { - return mapping, false + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model}, false } - mapping = s.ResolveChannelMapping(ctx, *groupID, model) - restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel) - return mapping, restricted + return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) } // ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。 func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { - if len(body) == 0 { - return body - } - if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { - return body - } - newBody, err := sjson.SetBytes(body, "model", newModel) - if err != nil { - return body - } - return newBody + return ReplaceModelInBody(body, newModel) } func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { @@ -4249,13 +4235,20 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() requestID := resolveUsageBillingRequestID(ctx, result.RequestID) + + // 确定 RequestedModel(渠道映射前的原始模型) + requestedModel := result.Model + if input.OriginalModel != "" { + requestedModel = input.OriginalModel + } + usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: requestID, Model: result.Model, - RequestedModel: result.Model, + RequestedModel: requestedModel, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort,