refactor: move channel model restriction from handler to scheduling phase
Move the model pricing restriction check from 8 handler entry points to the account scheduling phase (SelectAccountForModelWithExclusions / SelectAccountWithLoadAwareness), aligning restriction with billing:

- requested: check the original request model against the pricing list
- channel_mapped: check the channel-mapped model against the pricing list
- upstream: per-account check using the account-mapped model

The handler layer now only resolves channel mapping (no restriction). The scheduling layer performs a pre-check for the requested/channel_mapped billing sources, and per-account filtering for the upstream billing source.
This commit is contained in:
@@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射 + 限制检查
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
// Claude Code only restriction
|
// Claude Code only restriction
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射 + 限制检查
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
// Claude Code only restriction:
|
// Claude Code only restriction:
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
|||||||
googleError(c, http.StatusBadGateway, err.Error())
|
googleError(c, http.StatusBadGateway, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if shouldFallbackGeminiModel(modelName, res) {
|
if shouldFallbackGeminiModels(res) {
|
||||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, modelName, stream, body)
|
setOpsRequestContext(c, modelName, stream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||||
|
|
||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射 + 限制检查
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||||
reqModel := modelName // 保存映射前的原始模型名
|
reqModel := modelName // 保存映射前的原始模型名
|
||||||
if channelMapping.Mapped {
|
if channelMapping.Mapped {
|
||||||
@@ -682,16 +682,6 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool {
|
|
||||||
if shouldFallbackGeminiModels(res) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if res == nil || res.StatusCode != http.StatusNotFound {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return gemini.HasFallbackModel(modelName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射 + 限制检查
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
if h.errorPassthroughService != nil {
|
if h.errorPassthroughService != nil {
|
||||||
|
|||||||
@@ -47,13 +47,6 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
|
|||||||
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
|
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
|
|
||||||
if apiKey == nil || apiKey.Group == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||||
func NewOpenAIGatewayHandler(
|
func NewOpenAIGatewayHandler(
|
||||||
gatewayService *service.OpenAIGatewayService,
|
gatewayService *service.OpenAIGatewayService,
|
||||||
@@ -557,8 +550,6 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqModel := modelResult.String()
|
reqModel := modelResult.String()
|
||||||
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
|
|
||||||
preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel)
|
|
||||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||||
|
|
||||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||||
@@ -617,20 +608,17 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
sameAccountRetryCount := make(map[int64]int)
|
sameAccountRetryCount := make(map[int64]int)
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
effectiveMappedModel := preferredMappedModel
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
currentRoutingModel := routingModel
|
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
|
||||||
if effectiveMappedModel != "" {
|
c.Set("openai_messages_fallback_model", "")
|
||||||
currentRoutingModel = effectiveMappedModel
|
|
||||||
}
|
|
||||||
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
apiKey.GroupID,
|
apiKey.GroupID,
|
||||||
"", // no previous_response_id
|
"", // no previous_response_id
|
||||||
sessionHash,
|
sessionHash,
|
||||||
currentRoutingModel,
|
reqModel,
|
||||||
failedAccountIDs,
|
failedAccountIDs,
|
||||||
service.OpenAIUpstreamTransportAny,
|
service.OpenAIUpstreamTransportAny,
|
||||||
)
|
)
|
||||||
@@ -639,7 +627,29 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||||
)
|
)
|
||||||
|
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(failedAccountIDs) == 0 {
|
||||||
|
defaultModel := ""
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
defaultModel = apiKey.Group.DefaultMappedModel
|
||||||
|
}
|
||||||
|
if defaultModel != "" && defaultModel != reqModel {
|
||||||
|
reqLog.Info("openai_messages.fallback_to_default_model",
|
||||||
|
zap.String("default_mapped_model", defaultModel),
|
||||||
|
)
|
||||||
|
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
"",
|
||||||
|
sessionHash,
|
||||||
|
defaultModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
)
|
||||||
|
if err == nil && selection != nil {
|
||||||
|
c.Set("openai_messages_fallback_model", defaultModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||||
return
|
return
|
||||||
@@ -671,7 +681,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
forwardStart := time.Now()
|
forwardStart := time.Now()
|
||||||
|
|
||||||
defaultMappedModel := strings.TrimSpace(effectiveMappedModel)
|
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
|
||||||
|
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
|
||||||
|
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
|
||||||
// 应用渠道模型映射到请求体
|
// 应用渠道模型映射到请求体
|
||||||
forwardBody := body
|
forwardBody := body
|
||||||
if channelMappingMsg.Mapped {
|
if channelMappingMsg.Mapped {
|
||||||
@@ -1106,7 +1118,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||||
|
|
||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射 + 限制检查
|
||||||
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
var currentUserRelease func()
|
var currentUserRelease func()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user