refactor: move channel model restriction from handler to scheduling phase

Move the model pricing restriction check from 8 handler entry points
to the account scheduling phase (SelectAccountForModelWithExclusions /
SelectAccountWithLoadAwareness), aligning restriction with billing:

- requested: check original request model against pricing list
- channel_mapped: check channel-mapped model against pricing list
- upstream: per-account check using account-mapped model

Handler layer now only resolves channel mapping (no restriction).
Scheduling layer performs pre-check for requested/channel_mapped,
and per-account filtering for upstream billing source.
This commit is contained in:
erio
2026-04-02 13:24:30 +08:00
parent b4a42a640d
commit ce41afb756
9 changed files with 104 additions and 84 deletions

View File

@@ -1178,6 +1178,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 渠道定价限制预检查requested / channel_mapped 基准)
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
@@ -1208,8 +1213,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
// 渠道定价限制预检查requested / channel_mapped 基准)
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
@@ -2955,6 +2967,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
acc := &accounts[i]
@@ -2975,6 +2988,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
@@ -3207,6 +3223,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
acc := &accounts[i]
@@ -3231,6 +3248,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
@@ -8212,6 +8232,67 @@ func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, g
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
}
// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。
// 供调度阶段预检查requested / channel_mapped
// upstream 需逐账号检查,此处返回 false。
func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
if groupID == nil || s.channelService == nil || requestedModel == "" {
return false
}
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
if billingModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
}
// billingModelForRestriction 根据计费基准确定限制检查使用的模型。
// upstream 返回空(需逐账号检查)。
func billingModelForRestriction(source, requestedModel, channelMappedModel string) string {
switch source {
case BillingModelSourceRequested:
return requestedModel
case BillingModelSourceUpstream:
return ""
default: // channel_mapped
return channelMappedModel
}
}
// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。
// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。
func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
if s.channelService == nil {
return false
}
upstreamModel := resolveAccountUpstreamModel(account, requestedModel)
if upstreamModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
}
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
if account.Platform == PlatformAntigravity {
return mapAntigravityModel(account, requestedModel)
}
return account.GetMappedModel(requestedModel)
}
// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。
func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {