refactor: move channel model restriction from handler to scheduling phase
Move the model pricing restriction check from 8 handler entry points to the account scheduling phase (SelectAccountForModelWithExclusions / SelectAccountWithLoadAwareness), aligning restriction with billing: - requested: check original request model against pricing list - channel_mapped: check channel-mapped model against pricing list - upstream: per-account check using account-mapped model Handler layer now only resolves channel mapping (no restriction). Scheduling layer performs pre-check for requested/channel_mapped, and per-account filtering for upstream billing source.
This commit is contained in:
@@ -158,12 +158,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqStream := parsedReq.Stream
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
// 解析渠道级模型映射 + 限制检查
|
||||
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
if restricted {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||
return
|
||||
}
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
|
||||
@@ -81,11 +81,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射 + 限制检查
|
||||
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
if restricted {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||
return
|
||||
}
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// Claude Code only restriction
|
||||
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
||||
|
||||
@@ -81,11 +81,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射 + 限制检查
|
||||
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
if restricted {
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||
return
|
||||
}
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// Claude Code only restriction:
|
||||
// /v1/responses is never a Claude Code endpoint.
|
||||
|
||||
@@ -185,11 +185,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||
|
||||
// 解析渠道级模型映射 + 限制检查
|
||||
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||
if restricted {
|
||||
googleError(c, http.StatusServiceUnavailable, "The requested model is not available for this API key")
|
||||
return
|
||||
}
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||
reqModel := modelName // 保存映射前的原始模型名
|
||||
if channelMapping.Mapped {
|
||||
modelName = channelMapping.MappedModel
|
||||
|
||||
@@ -80,11 +80,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射 + 限制检查
|
||||
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
if restricted {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||
return
|
||||
}
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
|
||||
@@ -185,12 +185,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射 + 限制检查
|
||||
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
if restricted {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||
return
|
||||
}
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||
@@ -562,12 +558,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射 + 限制检查
|
||||
channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
if restricted {
|
||||
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||
return
|
||||
}
|
||||
// 解析渠道级模型映射
|
||||
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
@@ -1128,11 +1120,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
// 解析渠道级模型映射 + 限制检查
|
||||
channelMappingWS, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||
if restricted {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed")
|
||||
return
|
||||
}
|
||||
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
|
||||
@@ -436,8 +436,9 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
|
||||
return checkRestricted(lk, groupID, model)
|
||||
}
|
||||
|
||||
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
|
||||
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
|
||||
// ResolveChannelMappingAndRestrict 解析渠道映射。
|
||||
// 返回映射结果。模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction),
|
||||
// restricted 始终返回 false,保留签名兼容性。
|
||||
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||
if groupID == nil {
|
||||
return ChannelMappingResult{MappedModel: model}, false
|
||||
@@ -446,10 +447,7 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g
|
||||
if lk == nil {
|
||||
return ChannelMappingResult{MappedModel: model}, false
|
||||
}
|
||||
// 先用原始模型检查定价列表限制,再做映射
|
||||
restricted := checkRestricted(lk, *groupID, model)
|
||||
mapping := resolveMapping(lk, *groupID, model)
|
||||
return mapping, restricted
|
||||
return resolveMapping(lk, *groupID, model), false
|
||||
}
|
||||
|
||||
// resolveMapping 基于已查找的渠道信息解析模型映射
|
||||
|
||||
@@ -1068,6 +1068,8 @@ func TestIsModelRestricted_CaseInsensitive(t *testing.T) {
|
||||
}
|
||||
|
||||
// --- 4.5 ResolveChannelMappingAndRestrict ---
|
||||
// 注意:模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction),
|
||||
// ResolveChannelMappingAndRestrict 仅做映射,restricted 始终为 false。
|
||||
|
||||
func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) {
|
||||
repo := &mockChannelRepository{
|
||||
@@ -1083,7 +1085,7 @@ func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) {
|
||||
require.Equal(t, "claude-opus-4", mapping.MappedModel)
|
||||
}
|
||||
|
||||
func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing.T) {
|
||||
func TestResolveChannelMappingAndRestrict_WithMapping(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
@@ -1103,41 +1105,12 @@ func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing.
|
||||
|
||||
gid := int64(10)
|
||||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "claude-sonnet-4")
|
||||
require.False(t, restricted) // model IS in pricing
|
||||
require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段
|
||||
require.True(t, mapping.Mapped)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", mapping.MappedModel)
|
||||
}
|
||||
|
||||
func TestResolveChannelMappingAndRestrict_ModelNotInPricing_WithMapping(t *testing.T) {
|
||||
// CRITICAL: this test verifies that restriction checks the ORIGINAL model
|
||||
// against pricing BEFORE applying mapping. The model "unknown-model" is NOT
|
||||
// in pricing, so even though the wildcard mapping "*" matches it, it should
|
||||
// still be restricted.
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
GroupIDs: []int64{10},
|
||||
RestrictModels: true,
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
|
||||
},
|
||||
ModelMapping: map[string]map[string]string{
|
||||
"anthropic": {
|
||||
"*": "catch-all-target",
|
||||
},
|
||||
},
|
||||
}
|
||||
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
|
||||
svc := newTestChannelService(repo)
|
||||
|
||||
gid := int64(10)
|
||||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model")
|
||||
require.True(t, restricted) // model NOT in pricing, even though mapping exists
|
||||
require.True(t, mapping.Mapped)
|
||||
require.Equal(t, "catch-all-target", mapping.MappedModel)
|
||||
}
|
||||
|
||||
func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing.T) {
|
||||
func TestResolveChannelMappingAndRestrict_NoMapping(t *testing.T) {
|
||||
ch := Channel{
|
||||
ID: 1,
|
||||
Status: StatusActive,
|
||||
@@ -1152,7 +1125,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing
|
||||
|
||||
gid := int64(10)
|
||||
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model")
|
||||
require.True(t, restricted) // model NOT in pricing
|
||||
require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段
|
||||
require.False(t, mapping.Mapped)
|
||||
require.Equal(t, "unknown-model", mapping.MappedModel)
|
||||
}
|
||||
|
||||
@@ -1178,6 +1178,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 渠道定价限制预检查(requested / channel_mapped 基准)
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
@@ -1208,8 +1213,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
|
||||
// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。
|
||||
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
|
||||
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
|
||||
// 渠道定价限制预检查(requested / channel_mapped 基准)
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// 调试日志:记录调度入口参数
|
||||
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
||||
for id := range excludedIDs {
|
||||
@@ -2955,6 +2967,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -2975,6 +2988,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -3207,6 +3223,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -3231,6 +3248,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -8212,6 +8232,67 @@ func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, g
|
||||
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||
}
|
||||
|
||||
// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。
|
||||
// 供调度阶段预检查(requested / channel_mapped)。
|
||||
// upstream 需逐账号检查,此处返回 false。
|
||||
func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
|
||||
if groupID == nil || s.channelService == nil || requestedModel == "" {
|
||||
return false
|
||||
}
|
||||
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
|
||||
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
|
||||
if billingModel == "" {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
|
||||
}
|
||||
|
||||
// billingModelForRestriction 根据计费基准确定限制检查使用的模型。
|
||||
// upstream 返回空(需逐账号检查)。
|
||||
func billingModelForRestriction(source, requestedModel, channelMappedModel string) string {
|
||||
switch source {
|
||||
case BillingModelSourceRequested:
|
||||
return requestedModel
|
||||
case BillingModelSourceUpstream:
|
||||
return ""
|
||||
default: // channel_mapped
|
||||
return channelMappedModel
|
||||
}
|
||||
}
|
||||
|
||||
// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。
|
||||
// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。
|
||||
func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
|
||||
if s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
upstreamModel := resolveAccountUpstreamModel(account, requestedModel)
|
||||
if upstreamModel == "" {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
|
||||
}
|
||||
|
||||
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
|
||||
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return mapAntigravityModel(account, requestedModel)
|
||||
}
|
||||
return account.GetMappedModel(requestedModel)
|
||||
}
|
||||
|
||||
// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。
|
||||
func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil || ch == nil || !ch.RestrictModels {
|
||||
return false
|
||||
}
|
||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||
}
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
|
||||
Reference in New Issue
Block a user