fix: address review findings for channel restriction refactoring
- Fix 7 stale comments still mentioning "限制检查" in handlers/services - Make billingModelForRestriction explicitly list channel_mapped case - Add slog.Warn for error swallowing in ResolveChannelMapping and needsUpstreamChannelRestrictionCheck - Document sticky session upstream check exemption
This commit is contained in:
@@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
// 解析渠道级模型映射 + 限制检查
|
// 解析渠道级模型映射
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
// Claude Code only restriction
|
// Claude Code only restriction
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
// 解析渠道级模型映射 + 限制检查
|
// 解析渠道级模型映射
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
// Claude Code only restriction:
|
// Claude Code only restriction:
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, modelName, stream, body)
|
setOpsRequestContext(c, modelName, stream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||||
|
|
||||||
// 解析渠道级模型映射 + 限制检查
|
// 解析渠道级模型映射
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||||
reqModel := modelName // 保存映射前的原始模型名
|
reqModel := modelName // 保存映射前的原始模型名
|
||||||
if channelMapping.Mapped {
|
if channelMapping.Mapped {
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
// 解析渠道级模型映射 + 限制检查
|
// 解析渠道级模型映射
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
if h.errorPassthroughService != nil {
|
if h.errorPassthroughService != nil {
|
||||||
|
|||||||
@@ -1119,7 +1119,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||||
|
|
||||||
// 解析渠道级模型映射 + 限制检查
|
// 解析渠道级模型映射
|
||||||
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
var currentUserRelease func()
|
var currentUserRelease func()
|
||||||
|
|||||||
@@ -418,7 +418,10 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
|
|||||||
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
|
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
|
||||||
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
|
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
|
||||||
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to load channel cache for mapping", "group_id", groupID, "error", err)
|
||||||
|
}
|
||||||
if lk == nil {
|
if lk == nil {
|
||||||
return ChannelMappingResult{MappedModel: model}
|
return ChannelMappingResult{MappedModel: model}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2967,6 +2967,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||||
|
|
||||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||||
|
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查,
|
||||||
|
// 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。
|
||||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||||
var selected *Account
|
var selected *Account
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
@@ -3223,6 +3225,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||||
|
|
||||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||||
|
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
|
||||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||||
var selected *Account
|
var selected *Account
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
@@ -8223,8 +8226,8 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m
|
|||||||
return s.channelService.IsModelRestricted(ctx, groupID, model)
|
return s.channelService.IsModelRestricted(ctx, groupID, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
|
// ResolveChannelMappingAndRestrict 解析渠道映射。
|
||||||
// 返回映射结果和是否被限制。
|
// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。
|
||||||
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||||
if s.channelService == nil {
|
if s.channelService == nil {
|
||||||
return ChannelMappingResult{MappedModel: model}, false
|
return ChannelMappingResult{MappedModel: model}, false
|
||||||
@@ -8255,7 +8258,9 @@ func billingModelForRestriction(source, requestedModel, channelMappedModel strin
|
|||||||
return requestedModel
|
return requestedModel
|
||||||
case BillingModelSourceUpstream:
|
case BillingModelSourceUpstream:
|
||||||
return ""
|
return ""
|
||||||
default: // channel_mapped
|
case BillingModelSourceChannelMapped:
|
||||||
|
return channelMappedModel
|
||||||
|
default:
|
||||||
return channelMappedModel
|
return channelMappedModel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -8287,7 +8292,11 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||||
if err != nil || ch == nil || !ch.RestrictModels {
|
if err != nil {
|
||||||
|
slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ch == nil || !ch.RestrictModels {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||||
|
|||||||
@@ -414,8 +414,8 @@ func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID in
|
|||||||
return s.channelService.IsModelRestricted(ctx, groupID, model)
|
return s.channelService.IsModelRestricted(ctx, groupID, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
|
// ResolveChannelMappingAndRestrict 解析渠道映射。
|
||||||
// 返回映射结果和是否被限制。
|
// 模型限制检查已移至调度阶段,restricted 始终返回 false。
|
||||||
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||||
if s.channelService == nil {
|
if s.channelService == nil {
|
||||||
return ChannelMappingResult{MappedModel: model}, false
|
return ChannelMappingResult{MappedModel: model}, false
|
||||||
|
|||||||
Reference in New Issue
Block a user