fix: address review findings for channel restriction refactoring

- Fix 7 stale comments still mentioning "限制检查" in handlers/services
- Make billingModelForRestriction explicitly list channel_mapped case
- Add slog.Warn for error swallowing in ResolveChannelMapping and
  needsUpstreamChannelRestrictionCheck
- Document sticky session upstream check exemption
This commit is contained in:
erio
2026-04-02 13:36:58 +08:00
parent 2dce4306b4
commit 160903fce7
6 changed files with 18 additions and 9 deletions

View File

@@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction

View File

@@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction:

View File

@@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// 解析渠道级模型映射 + 限制检查
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {

View File

@@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if h.errorPassthroughService != nil {

View File

@@ -1118,7 +1118,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
// 解析渠道级模型映射 + 限制检查
// 解析渠道级模型映射
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func()

View File

@@ -3143,6 +3143,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查,
// 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -3381,6 +3383,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -8374,8 +8377,8 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制
// 返回映射结果和是否被限制
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 模型限制检查已移至调度阶段checkChannelPricingRestrictionrestricted 始终返回 false
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}, false
@@ -8406,7 +8409,9 @@ func billingModelForRestriction(source, requestedModel, channelMappedModel strin
return requestedModel
case BillingModelSourceUpstream:
return ""
default: // channel_mapped
case BillingModelSourceChannelMapped:
return channelMappedModel
default:
return channelMappedModel
}
}
@@ -8438,7 +8443,11 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil || !ch.RestrictModels {
if err != nil {
slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err)
return false
}
if ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream