From e3748741257c2c6b96ced2c0c0284dd31cc5e358 Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 3 Apr 2026 13:54:18 +0800 Subject: [PATCH] feat(channel): improve cache strategy and add restriction logging - Change channel cache TTL from 60s to 10min (reduce unnecessary DB queries) - Actively rebuild cache after CRUD instead of lazy invalidation - Add slog.Warn logging for channel pricing restriction blocks (4 places) --- backend/internal/service/channel_service.go | 378 ++++++++------------ backend/internal/service/gateway_service.go | 46 ++- 2 files changed, 183 insertions(+), 241 deletions(-) diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 9667cb98..c6a249ef 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan const ( channelCacheTTL = 10 * time.Minute - channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 + channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 channelCacheDBTimeout = 10 * time.Second ) @@ -197,8 +197,10 @@ func newEmptyChannelCache() *channelCache { } // expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。 -// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。 -// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。 +// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。 +// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台, +// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。 +// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。 func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { for j := range ch.ModelPricing { pricing := &ch.ModelPricing[j] @@ -224,7 +226,8 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform } // expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。 -// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。 +// antigravity 平台同时服务 Claude 和 Gemini 模型。 +// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。 func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { for _, mappingPlatform := range matchingPlatforms(platform) { platformMapping, ok := ch.ModelMapping[mappingPlatform] @@ -248,58 +251,40 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform } } -// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。 -// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。 -func (s *ChannelService) storeErrorCache() { - errorCache := newEmptyChannelCache() - errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) - s.cache.Store(errorCache) -} - // buildCache 从数据库构建渠道缓存。 // 使用独立 context 避免请求取消导致空值被长期缓存。 func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { + // 断开请求取消链,避免客户端断连导致空值被长期缓存 dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout) defer cancel() - channels, groupPlatforms, err := s.fetchChannelData(dbCtx) - if err != nil { - return nil, err - } - - cache := populateChannelCache(channels, groupPlatforms) - s.cache.Store(cache) - return cache, nil -} - -// fetchChannelData 从数据库加载渠道列表和分组平台映射。 -func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) { - channels, err := s.repo.ListAll(ctx) + channels, err := s.repo.ListAll(dbCtx) if err != nil { + // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 slog.Warn("failed to build channel cache", "error", err) - s.storeErrorCache() - return nil, nil, fmt.Errorf("list all channels: %w", err) + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL + s.cache.Store(errorCache) + return nil, fmt.Errorf("list all channels: %w", err) } + // 收集所有 groupID,批量查询 platform var allGroupIDs []int64 for i := range channels { allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...) } - groupPlatforms := make(map[int64]string) if len(allGroupIDs) > 0 { - groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs) + groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs) if err != nil { slog.Warn("failed to load group platforms for channel cache", "error", err) - s.storeErrorCache() - return nil, nil, fmt.Errorf("get group platforms: %w", err) + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) + s.cache.Store(errorCache) + return nil, fmt.Errorf("get group platforms: %w", err) } } - return channels, groupPlatforms, nil -} -// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。 -func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache { cache := newEmptyChannelCache() cache.groupPlatform = groupPlatforms cache.byID = make(map[int64]*Channel, len(channels)) @@ -308,6 +293,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) * for i := range channels { ch := &channels[i] cache.byID[ch.ID] = ch + for _, gid := range ch.GroupIDs { cache.channelByGroupID[gid] = ch platform := groupPlatforms[gid] @@ -315,20 +301,33 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) * expandMappingToCache(cache, ch, gid, platform) } } - return cache + + // 通配符条目保持配置顺序(最先匹配到优先) + + s.cache.Store(cache) + return cache, nil } // invalidateCache 使缓存失效,让下次读取时自然重建 // isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。 -// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。 +// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型, +// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。 func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool { - return groupPlatform == pricingPlatform + if groupPlatform == pricingPlatform { + return true + } + if groupPlatform == PlatformAntigravity { + return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini + } + return false } -// matchingPlatforms 返回分组平台对应的可匹配平台列表。 -// 各平台严格独立,只返回自身。 +// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。 func matchingPlatforms(groupPlatform string) []string { + if groupPlatform == PlatformAntigravity { + return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini} + } return []string{groupPlatform} } func (s *ChannelService) invalidateCache() { @@ -365,8 +364,10 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower return "" } -// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。 -// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。 +// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。 +// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试 +// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini), +// 返回第一个命中的结果。非 antigravity 平台只尝试自身。 func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing { for _, p := range matchingPlatforms(groupPlatform) { key := channelModelKey{groupID: groupID, platform: p, model: modelLower} @@ -383,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf return nil } -// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。 +// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。 // 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。 func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string { for _, p := range matchingPlatforms(groupPlatform) { @@ -441,7 +442,8 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) } // GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。 -// 各平台严格独立,只在本平台内查找定价。 +// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini), +// 确保跨平台同名模型各自独立匹配。 func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { lk, err := s.lookupGroupChannel(ctx, groupID) if err != nil { @@ -479,10 +481,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 // 返回 true 表示模型被限制(不在允许列表中)。 // 如果渠道未启用模型限制或分组无渠道关联,返回 false。 func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { - lk, err := s.lookupGroupChannel(ctx, groupID) - if err != nil { - slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err) - } + lk, _ := s.lookupGroupChannel(ctx, groupID) if lk == nil { return false } @@ -525,7 +524,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi } // checkRestricted 基于已查找的渠道信息检查模型是否被限制。 -// 只在本平台的定价列表中查找。 +// antigravity 分组依次尝试所有匹配平台的定价列表。 func checkRestricted(lk *channelLookup, groupID int64, model string) bool { if !lk.channel.RestrictModels { return false @@ -553,91 +552,6 @@ func ReplaceModelInBody(body []byte, newModel string) []byte { return newBody } -// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。 -// Create 和 Update 共用此函数,避免重复。 -func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error { - if err := validateNoConflictingModels(pricing); err != nil { - return err - } - if err := validatePricingIntervals(pricing); err != nil { - return err - } - if err := validateNoConflictingMappings(mapping); err != nil { - return err - } - return validatePricingBillingMode(pricing) -} - -// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。 -func validatePricingBillingMode(pricing []ChannelModelPricing) error { - for _, p := range pricing { - if err := checkBillingModeRequirements(p); err != nil { - return err - } - if err := checkPricesNotNegative(p); err != nil { - return err - } - if err := checkIntervalsHavePrices(p); err != nil { - return err - } - } - return nil -} - -func checkBillingModeRequirements(p ChannelModelPricing) error { - if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage { - if p.PerRequestPrice == nil && len(p.Intervals) == 0 { - return infraerrors.BadRequest( - "BILLING_MODE_MISSING_PRICE", - "per-request price or intervals required for per_request/image billing mode", - ) - } - } - return nil -} - -func checkPricesNotNegative(p ChannelModelPricing) error { - checks := []struct { - field string - val *float64 - }{ - {"input_price", p.InputPrice}, - {"output_price", p.OutputPrice}, - {"cache_write_price", p.CacheWritePrice}, - {"cache_read_price", p.CacheReadPrice}, - {"image_output_price", p.ImageOutputPrice}, - {"per_request_price", p.PerRequestPrice}, - } - for _, c := range checks { - if c.val != nil && *c.val < 0 { - return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field)) - } - } - return nil -} - -func checkIntervalsHavePrices(p ChannelModelPricing) error { - for _, iv := range p.Intervals { - if iv.InputPrice == nil && iv.OutputPrice == nil && - iv.CacheWritePrice == nil && iv.CacheReadPrice == nil && - iv.PerRequestPrice == nil { - return infraerrors.BadRequest( - "INTERVAL_MISSING_PRICE", - fmt.Sprintf("interval [%d, %s] has no price fields set for model %v", - iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models), - ) - } - } - return nil -} - -func formatMaxTokens(max *int) string { - if max == nil { - return "∞" - } - return fmt.Sprintf("%d", *max) -} - // --- CRUD --- // Create 创建渠道 @@ -650,8 +564,15 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) return nil, ErrChannelExists } - if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil { - return nil, err + // 检查分组冲突 + if len(input.GroupIDs) > 0 { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } } channel := &Channel{ @@ -668,7 +589,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) channel.BillingModelSource = BillingModelSourceChannelMapped } - if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { return nil, err } @@ -692,112 +619,102 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan return nil, fmt.Errorf("get channel: %w", err) } - if err := s.applyUpdateInput(ctx, channel, input); err != nil { + if input.Name != "" && input.Name != channel.Name { + exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id) + if err != nil { + return nil, fmt.Errorf("check channel exists: %w", err) + } + if exists { + return nil, ErrChannelExists + } + channel.Name = input.Name + } + + if input.Description != nil { + channel.Description = *input.Description + } + + if input.Status != "" { + channel.Status = input.Status + } + + if input.RestrictModels != nil { + channel.RestrictModels = *input.RestrictModels + } + + // 检查分组冲突 + if input.GroupIDs != nil { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } + channel.GroupIDs = *input.GroupIDs + } + + if input.ModelPricing != nil { + channel.ModelPricing = *input.ModelPricing + } + + if input.ModelMapping != nil { + channel.ModelMapping = input.ModelMapping + } + + if input.BillingModelSource != "" { + channel.BillingModelSource = input.BillingModelSource + } + + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { return nil, err } - if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { - return nil, err + // 先获取旧分组,Update 后旧分组关联已删除,无法再查到 + var oldGroupIDs []int64 + if s.authCacheInvalidator != nil { + var err2 error + oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id) + if err2 != nil { + slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2) + } } - oldGroupIDs := s.getOldGroupIDs(ctx, id) - if err := s.repo.Update(ctx, channel); err != nil { return nil, fmt.Errorf("update channel: %w", err) } s.invalidateCache() - s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs) + + // 失效新旧分组的 auth 缓存 + if s.authCacheInvalidator != nil { + seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs)) + for _, gid := range oldGroupIDs { + if _, ok := seen[gid]; !ok { + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + for _, gid := range channel.GroupIDs { + if _, ok := seen[gid]; !ok { + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + } return s.repo.GetByID(ctx, id) } -// applyUpdateInput 将更新请求的字段应用到渠道实体上。 -func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error { - if input.Name != "" && input.Name != channel.Name { - exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID) - if err != nil { - return fmt.Errorf("check channel exists: %w", err) - } - if exists { - return ErrChannelExists - } - channel.Name = input.Name - } - if input.Description != nil { - channel.Description = *input.Description - } - if input.Status != "" { - channel.Status = input.Status - } - if input.RestrictModels != nil { - channel.RestrictModels = *input.RestrictModels - } - if input.GroupIDs != nil { - if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil { - return err - } - channel.GroupIDs = *input.GroupIDs - } - if input.ModelPricing != nil { - channel.ModelPricing = *input.ModelPricing - } - if input.ModelMapping != nil { - channel.ModelMapping = input.ModelMapping - } - if input.BillingModelSource != "" { - channel.BillingModelSource = input.BillingModelSource - } - return nil -} - -// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。 -// channelID 为当前渠道 ID(Create 时传 0)。 -func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error { - if len(groupIDs) == 0 { - return nil - } - conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs) - if err != nil { - return fmt.Errorf("check group conflicts: %w", err) - } - if len(conflicting) > 0 { - return ErrGroupAlreadyInChannel - } - return nil -} - -// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。 -func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 { - if s.authCacheInvalidator == nil { - return nil - } - oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID) - if err != nil { - slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err) - } - return oldGroupIDs -} - -// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。 -func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) { - if s.authCacheInvalidator == nil { - return - } - seen := make(map[int64]struct{}) - for _, ids := range groupIDSets { - for _, gid := range ids { - if _, ok := seen[gid]; ok { - continue - } - seen[gid] = struct{}{} - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) - } - } -} - // Delete 删除渠道 func (s *ChannelService) Delete(ctx context.Context, id int64) error { + // 先获取关联分组用于失效缓存 groupIDs, err := s.repo.GetGroupIDs(ctx, id) if err != nil { slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err) @@ -808,7 +725,12 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error { } s.invalidateCache() - s.invalidateAuthCacheForGroups(ctx, groupIDs) + + if s.authCacheInvalidator != nil { + for _, gid := range groupIDs { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } return nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 24f36113..31137fb4 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1234,11 +1234,6 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 渠道定价限制预检查(requested / channel_mapped 基准) - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - // 优先检查 context 中的强制平台(/antigravity 路由) var platform string forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) @@ -1257,6 +1252,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context platform = PlatformAnthropic } + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { @@ -1273,11 +1277,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // metadataUserID: 用于客户端亲和调度,从中提取客户端 ID // sub2apiUserID: 系统用户 ID,用于二维亲和调度 func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { - // 渠道定价限制预检查(requested / channel_mapped 基准) - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { @@ -1298,6 +1297,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + var stickyAccountID int64 if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch @@ -3004,7 +3012,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -3359,7 +3367,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -3383,7 +3391,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) - // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -8453,6 +8460,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex return ch.BillingModelSource == BillingModelSourceUpstream } +// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。 +// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用, +// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。 +func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool { + if groupID == nil { + return false + } + if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) { + return false + } + return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {