refactor: move channel model restriction from handler to scheduling phase
Move the model pricing restriction check from the 8 handler entry points to the account scheduling phase (SelectAccountForModelWithExclusions / SelectAccountWithLoadAwareness), aligning the restriction with the billing source:
- requested: check the original request model against the pricing list
- channel_mapped: check the channel-mapped model against the pricing list
- upstream: per-account check using the account-mapped model

The handler layer now only resolves the channel mapping (no restriction). The scheduling layer performs a pre-check for requested/channel_mapped, and per-account filtering for the upstream billing source.
This commit is contained in:
@@ -1178,6 +1178,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 渠道定价限制预检查(requested / channel_mapped 基准)
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
@@ -1208,8 +1213,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
|
||||
// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。
|
||||
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
|
||||
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
|
||||
// 渠道定价限制预检查(requested / channel_mapped 基准)
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// 调试日志:记录调度入口参数
|
||||
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
||||
for id := range excludedIDs {
|
||||
@@ -2955,6 +2967,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -2975,6 +2988,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -3207,6 +3223,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -3231,6 +3248,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -8212,6 +8232,67 @@ func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, g
|
||||
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||
}
|
||||
|
||||
// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。
|
||||
// 供调度阶段预检查(requested / channel_mapped)。
|
||||
// upstream 需逐账号检查,此处返回 false。
|
||||
func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
|
||||
if groupID == nil || s.channelService == nil || requestedModel == "" {
|
||||
return false
|
||||
}
|
||||
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
|
||||
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
|
||||
if billingModel == "" {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
|
||||
}
|
||||
|
||||
// billingModelForRestriction 根据计费基准确定限制检查使用的模型。
|
||||
// upstream 返回空(需逐账号检查)。
|
||||
func billingModelForRestriction(source, requestedModel, channelMappedModel string) string {
|
||||
switch source {
|
||||
case BillingModelSourceRequested:
|
||||
return requestedModel
|
||||
case BillingModelSourceUpstream:
|
||||
return ""
|
||||
default: // channel_mapped
|
||||
return channelMappedModel
|
||||
}
|
||||
}
|
||||
|
||||
// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。
|
||||
// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。
|
||||
func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
|
||||
if s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
upstreamModel := resolveAccountUpstreamModel(account, requestedModel)
|
||||
if upstreamModel == "" {
|
||||
return false
|
||||
}
|
||||
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
|
||||
}
|
||||
|
||||
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
|
||||
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return mapAntigravityModel(account, requestedModel)
|
||||
}
|
||||
return account.GetMappedModel(requestedModel)
|
||||
}
|
||||
|
||||
// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。
|
||||
func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil || ch == nil || !ch.RestrictModels {
|
||||
return false
|
||||
}
|
||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||
}
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
|
||||
Reference in New Issue
Block a user