diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index dff922d1..89a791fd 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -158,18 +158,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqStream := parsedReq.Stream reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) - // 解析渠道级模型映射 - var channelMapping service.ChannelMappingResult - if apiKey.GroupID != nil { - channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) - } - - // 渠道模型限制检查:先映射再判断,映射后的模型在定价列表中即放行 - if apiKey.GroupID != nil { - if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, channelMapping.MappedModel) { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") - return - } + // 解析渠道级模型映射 + 限制检查 + channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + if restricted { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return } // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 @@ -495,18 +488,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ChannelID: channelMapping.ChannelID, OriginalModel: reqModel, BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: func() string { - if !channelMapping.Mapped { - if result.UpstreamModel != "" && result.UpstreamModel != result.Model { - return reqModel + "→" + result.UpstreamModel - } - return "" - } - if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel { - return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel - } - return reqModel + "→" + channelMapping.MappedModel - }(), + ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -849,18 +831,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ChannelID: channelMapping.ChannelID, OriginalModel: reqModel, BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: func() string { - if !channelMapping.Mapped { - if result.UpstreamModel != "" && result.UpstreamModel != result.Model { - return reqModel + "→" + result.UpstreamModel - } - return "" - } - if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel { - return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel - } - return reqModel + "→" + channelMapping.MappedModel - }(), + ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index c8b90e14..17f2fe82 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -185,17 +185,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 - var channelMapping service.ChannelMappingResult - if apiKey.GroupID != nil { - channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) - } - - // 渠道模型限制检查:先映射再判断 - if apiKey.GroupID != nil { - if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, channelMapping.MappedModel) { - return - } + // 解析渠道级模型映射 + 限制检查 + channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + if restricted { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return } // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 @@ -297,7 +291,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Forward request service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + // 应用渠道模型映射到请求体 + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { accountReleaseFunc() @@ -395,18 +394,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ChannelID: channelMapping.ChannelID, OriginalModel: reqModel, BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: func() string { - if !channelMapping.Mapped { - if result.UpstreamModel != "" && result.UpstreamModel != result.Model { - return reqModel + "→" + result.UpstreamModel - } - return "" - } - if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel { - return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel - } - return reqModel + "→" + channelMapping.MappedModel - }(), + ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -577,17 +565,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 - var channelMappingMsg service.ChannelMappingResult - if apiKey.GroupID != nil { - channelMappingMsg = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) - } - - // 渠道模型限制检查:先映射再判断 - if apiKey.GroupID != nil { - if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, channelMappingMsg.MappedModel) { - return - } + // 解析渠道级模型映射 + 限制检查 + channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + if restricted { + h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return } // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 @@ -714,7 +696,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) - result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + // 应用渠道模型映射到请求体 + forwardBody := body + if channelMappingMsg.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel) + } + result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -803,18 +790,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ChannelID: channelMappingMsg.ChannelID, OriginalModel: reqModel, BillingModelSource: channelMappingMsg.BillingModelSource, - ModelMappingChain: func() string { - if !channelMappingMsg.Mapped { - if result.UpstreamModel != "" && result.UpstreamModel != result.Model { - return reqModel + "→" + result.UpstreamModel - } - return "" - } - if result.UpstreamModel != "" && result.UpstreamModel != channelMappingMsg.MappedModel { - return reqModel + "→" + channelMappingMsg.MappedModel + "→" + result.UpstreamModel - } - return reqModel + "→" + channelMappingMsg.MappedModel - }(), + ModelMappingChain: channelMappingMsg.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.messages"), @@ -1157,18 +1133,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) - // 解析渠道级模型映射 - var channelMappingWS service.ChannelMappingResult - if apiKey.GroupID != nil { - channelMappingWS = h.gatewayService.ResolveChannelMapping(ctx, *apiKey.GroupID, reqModel) - } - - // 渠道模型限制检查:先映射再判断 - if apiKey.GroupID != nil { - if h.gatewayService.IsModelRestricted(ctx, *apiKey.GroupID, channelMappingWS.MappedModel) { - closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed") - return - } + // 解析渠道级模型映射 + 限制检查 + channelMappingWS, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) + if restricted { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed") + return } var currentUserRelease func() @@ -1332,18 +1301,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ChannelID: channelMappingWS.ChannelID, OriginalModel: reqModel, BillingModelSource: channelMappingWS.BillingModelSource, - ModelMappingChain: func() string { - if !channelMappingWS.Mapped { - if result.UpstreamModel != "" && result.UpstreamModel != result.Model { - return reqModel + "→" + result.UpstreamModel - } - return "" - } - if result.UpstreamModel != "" && result.UpstreamModel != channelMappingWS.MappedModel { - return reqModel + "→" + channelMappingWS.MappedModel + "→" + result.UpstreamModel - } - return reqModel + "→" + channelMappingWS.MappedModel - }(), + ModelMappingChain: channelMappingWS.BuildModelMappingChain(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", zap.Int64("account_id", account.ID), @@ -1355,7 +1313,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { }, } - if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { + // 应用渠道模型映射到 WebSocket 首条消息 + wsFirstMessage := firstMessage + if channelMappingWS.Mapped { + wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel) + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) closeStatus, closeReason := summarizeWSCloseErrorForLog(err) reqLog.Warn("openai.websocket_proxy_failed", diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 29f4b615..fb75bafc 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -92,6 +92,23 @@ type ChannelMappingResult struct { BillingModelSource string // 计费模型来源("requested" / "upstream") } +// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。 +// reqModel: 客户端请求的原始模型名。 +// upstreamModel: 上游实际使用的模型名(ForwardResult.UpstreamModel)。 +// 返回空字符串表示无映射。 +func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel string) string { + if !r.Mapped { + if upstreamModel != "" && upstreamModel != reqModel { + return reqModel + "→" + upstreamModel + } + return "" + } + if upstreamModel != "" && upstreamModel != r.MappedModel { + return reqModel + "→" + r.MappedModel + "→" + upstreamModel + } + return reqModel + "→" + r.MappedModel +} + const ( channelCacheTTL = 60 * time.Second channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index c7d8403e..140e7202 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -8195,6 +8195,19 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m return s.channelService.IsModelRestricted(ctx, groupID, model) } +// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 +// 返回映射结果和是否被限制。 +func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + var mapping ChannelMappingResult + mapping.MappedModel = model + if groupID == nil { + return mapping, false + } + mapping = s.ResolveChannelMapping(ctx, *groupID, model) + restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel) + return mapping, restricted +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 3818af02..66f492a5 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -413,6 +413,34 @@ func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID in return s.channelService.IsModelRestricted(ctx, groupID, model) } +// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 +// 返回映射结果和是否被限制。 +func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + var mapping ChannelMappingResult + mapping.MappedModel = model + if groupID == nil { + return mapping, false + } + mapping = s.ResolveChannelMapping(ctx, *groupID, model) + restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel) + return mapping, restricted +} + +// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。 +func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { + if len(body) == 0 { + return body + } + if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { + return body + } + newBody, err := sjson.SetBytes(body, "model", newModel) + if err != nil { + return body + } + return newBody +} + func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { if s != nil && s.codexSnapshotThrottle != nil { return s.codexSnapshotThrottle diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 6c57b269..cdfcf291 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -683,7 +683,6 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[ // Model pricing with platform tag for (const entry of section.model_pricing) { - console.log('[formToAPI] entry:', JSON.stringify({ models: entry.models, billing_mode: entry.billing_mode, per_request_price: entry.per_request_price })) if (entry.models.length === 0) continue model_pricing.push({ platform: section.platform, @@ -700,7 +699,6 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[ } } - console.log('[formToAPI] result:', JSON.stringify({ group_ids, model_pricing_count: model_pricing.length, model_mapping_keys: Object.keys(model_mapping), platforms_count: form.platforms.length, pricing_entries: form.platforms.map(s => s.model_pricing.length) })) return { group_ids, model_pricing, model_mapping } } @@ -883,7 +881,6 @@ async function handleSubmit() { } const { group_ids, model_pricing, model_mapping } = formToAPI() - console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing)) submitting.value = true try {