refactor: move channel model restriction from handler to scheduling phase

Move the model pricing restriction check from 8 handler entry points
to the account scheduling phase (SelectAccountForModelWithExclusions /
SelectAccountWithLoadAwareness), aligning restriction with billing:

- requested: check original request model against pricing list
- channel_mapped: check channel-mapped model against pricing list
- upstream: per-account check using account-mapped model

Handler layer now only resolves channel mapping (no restriction).
Scheduling layer performs pre-check for requested/channel_mapped,
and per-account filtering for upstream billing source.
This commit is contained in:
erio
2026-04-02 13:24:30 +08:00
parent b4a42a640d
commit ce41afb756
9 changed files with 104 additions and 84 deletions

View File

@@ -185,12 +185,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
@@ -562,12 +558,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
// 解析渠道级模型映射
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
@@ -1128,11 +1120,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
// 解析渠道级模型映射 + 限制检查
channelMappingWS, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
if restricted {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed")
return
}
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func()
var currentAccountRelease func()