diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 651936c1..c6daa2b1 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -158,12 +158,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqStream := parsedReq.Stream reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) - // 解析渠道级模型映射 + 限制检查 - channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) - if restricted { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") - return - } + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index b70582f6..dda4bed7 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -81,11 +81,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 解析渠道级模型映射 + 限制检查 - channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) - if restricted { - h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") - return - } + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // Claude Code only restriction if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index d4ee905a..8f264551 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -81,11 +81,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 解析渠道级模型映射 + 限制检查 - channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) - if restricted { - h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") - return - } + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // Claude Code only restriction: // /v1/responses is never a Claude Code endpoint. diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 55556764..2c9a38f8 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -185,11 +185,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) // 解析渠道级模型映射 + 限制检查 - channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) - if restricted { - googleError(c, http.StatusServiceUnavailable, "The requested model is not available for this API key") - return - } + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) reqModel := modelName // 保存映射前的原始模型名 if channelMapping.Mapped { modelName = channelMapping.MappedModel diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 20695e0e..ada401c9 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -80,11 +80,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 解析渠道级模型映射 + 限制检查 - channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) - if restricted { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") - return - } + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index d1fc9b51..0063a1c2 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -185,12 +185,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + 限制检查 - channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) - if restricted { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") - return - } + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 if !h.validateFunctionCallOutputRequest(c, body, reqLog) { @@ -562,12 +558,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + 限制检查 - channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) - if restricted { - h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key") - return - } + // 解析渠道级模型映射 + channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { @@ -1128,11 +1120,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) // 解析渠道级模型映射 + 限制检查 - channelMappingWS, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) - if restricted { - closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed") - return - } + channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) var currentUserRelease func() var currentAccountRelease func() diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index d4d4d377..e22ebc81 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -436,8 +436,9 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m return checkRestricted(lk, groupID, model) } -// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。 -// 返回映射结果和是否被限制。groupID 为 nil 时跳过。 +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 返回映射结果。模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction), +// restricted 始终返回 false,保留签名兼容性。 func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { if groupID == nil { return ChannelMappingResult{MappedModel: model}, false @@ -446,10 +447,7 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g if lk == nil { return ChannelMappingResult{MappedModel: model}, false } - // 先用原始模型检查定价列表限制,再做映射 - restricted := checkRestricted(lk, *groupID, model) - mapping := resolveMapping(lk, *groupID, model) - return mapping, restricted + return resolveMapping(lk, *groupID, model), false } // resolveMapping 基于已查找的渠道信息解析模型映射 diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go index 0cb16b2b..93cafa67 100644 --- a/backend/internal/service/channel_service_test.go +++ b/backend/internal/service/channel_service_test.go @@ -1068,6 +1068,8 @@ func TestIsModelRestricted_CaseInsensitive(t *testing.T) { } // --- 4.5 ResolveChannelMappingAndRestrict --- +// 注意:模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction), +// ResolveChannelMappingAndRestrict 仅做映射,restricted 始终为 false。 func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) { repo := &mockChannelRepository{ @@ -1083,7 +1085,7 @@ func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) { require.Equal(t, "claude-opus-4", mapping.MappedModel) } -func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing.T) { +func TestResolveChannelMappingAndRestrict_WithMapping(t *testing.T) { ch := Channel{ ID: 1, Status: StatusActive, @@ -1103,41 +1105,12 @@ func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing. gid := int64(10) mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "claude-sonnet-4") - require.False(t, restricted) // model IS in pricing + require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段 require.True(t, mapping.Mapped) require.Equal(t, "claude-sonnet-4-20250514", mapping.MappedModel) } -func TestResolveChannelMappingAndRestrict_ModelNotInPricing_WithMapping(t *testing.T) { - // CRITICAL: this test verifies that restriction checks the ORIGINAL model - // against pricing BEFORE applying mapping. The model "unknown-model" is NOT - // in pricing, so even though the wildcard mapping "*" matches it, it should - // still be restricted. - ch := Channel{ - ID: 1, - Status: StatusActive, - GroupIDs: []int64{10}, - RestrictModels: true, - ModelPricing: []ChannelModelPricing{ - {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, - }, - ModelMapping: map[string]map[string]string{ - "anthropic": { - "*": "catch-all-target", - }, - }, - } - repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) - svc := newTestChannelService(repo) - - gid := int64(10) - mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model") - require.True(t, restricted) // model NOT in pricing, even though mapping exists - require.True(t, mapping.Mapped) - require.Equal(t, "catch-all-target", mapping.MappedModel) -} - -func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing.T) { +func TestResolveChannelMappingAndRestrict_NoMapping(t *testing.T) { ch := Channel{ ID: 1, Status: StatusActive, @@ -1152,7 +1125,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing gid := int64(10) mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model") - require.True(t, restricted) // model NOT in pricing + require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段 require.False(t, mapping.Mapped) require.Equal(t, "unknown-model", mapping.MappedModel) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e42f3702..8879f3d2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1178,6 +1178,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + // 渠道定价限制预检查(requested / channel_mapped 基准) + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 优先检查 context 中的强制平台(/antigravity 路由) var platform string forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) @@ -1208,8 +1213,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash -func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { +// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。 +// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID +// sub2apiUserID: 系统用户 ID,用于二维亲和调度 +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { + // 渠道定价限制预检查(requested / channel_mapped 基准) + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { @@ -2955,6 +2967,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持) + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { acc := &accounts[i] @@ -2975,6 +2988,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } @@ -3207,6 +3223,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { acc := &accounts[i] @@ -3231,6 +3248,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } @@ -8212,6 +8232,67 @@ func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, g return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) } +// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。 +// 供调度阶段预检查(requested / channel_mapped)。 +// upstream 需逐账号检查,此处返回 false。 +func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool { + if groupID == nil || s.channelService == nil || requestedModel == "" { + return false + } + mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel) + billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel) + if billingModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) +} + +// billingModelForRestriction 根据计费基准确定限制检查使用的模型。 +// upstream 返回空(需逐账号检查)。 +func billingModelForRestriction(source, requestedModel, channelMappedModel string) string { + switch source { + case BillingModelSourceRequested: + return requestedModel + case BillingModelSourceUpstream: + return "" + default: // channel_mapped + return channelMappedModel + } +} + +// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。 +// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。 +func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { + if s.channelService == nil { + return false + } + upstreamModel := resolveAccountUpstreamModel(account, requestedModel) + if upstreamModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel) +} + +// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。 +func resolveAccountUpstreamModel(account *Account, requestedModel string) string { + if account.Platform == PlatformAntigravity { + return mapAntigravityModel(account, requestedModel) + } + return account.GetMappedModel(requestedModel) +} + +// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。 +func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool { + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil || ch == nil || !ch.RestrictModels { + return false + } + return ch.BillingModelSource == BillingModelSourceUpstream +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {