feat(channel): improve cache strategy and add restriction logging
- Change channel cache TTL from 60s to 10min (reduce unnecessary DB queries) - Actively rebuild cache after CRUD instead of lazy invalidation - Add slog.Warn logging for channel pricing restriction blocks (4 places)
This commit is contained in:
@@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan
|
||||
|
||||
const (
|
||||
channelCacheTTL = 10 * time.Minute
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheDBTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
@@ -197,8 +197,10 @@ func newEmptyChannelCache() *channelCache {
|
||||
}
|
||||
|
||||
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
||||
// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
|
||||
// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
||||
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
|
||||
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
|
||||
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
|
||||
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for j := range ch.ModelPricing {
|
||||
pricing := &ch.ModelPricing[j]
|
||||
@@ -224,7 +226,8 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
||||
}
|
||||
|
||||
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
||||
// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
||||
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
|
||||
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||
@@ -248,58 +251,40 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
||||
}
|
||||
}
|
||||
|
||||
// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
|
||||
// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
|
||||
func (s *ChannelService) storeErrorCache() {
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||
s.cache.Store(errorCache)
|
||||
}
|
||||
|
||||
// buildCache 从数据库构建渠道缓存。
|
||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||
// 断开请求取消链,避免客户端断连导致空值被长期缓存
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache := populateChannelCache(channels, groupPlatforms)
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
|
||||
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
|
||||
channels, err := s.repo.ListAll(ctx)
|
||||
channels, err := s.repo.ListAll(dbCtx)
|
||||
if err != nil {
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
s.storeErrorCache()
|
||||
return nil, nil, fmt.Errorf("list all channels: %w", err)
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
|
||||
// 收集所有 groupID,批量查询 platform
|
||||
var allGroupIDs []int64
|
||||
for i := range channels {
|
||||
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
||||
}
|
||||
|
||||
groupPlatforms := make(map[int64]string)
|
||||
if len(allGroupIDs) > 0 {
|
||||
groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
|
||||
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
||||
s.storeErrorCache()
|
||||
return nil, nil, fmt.Errorf("get group platforms: %w", err)
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
||||
}
|
||||
}
|
||||
return channels, groupPlatforms, nil
|
||||
}
|
||||
|
||||
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
|
||||
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
|
||||
cache := newEmptyChannelCache()
|
||||
cache.groupPlatform = groupPlatforms
|
||||
cache.byID = make(map[int64]*Channel, len(channels))
|
||||
@@ -308,6 +293,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
cache.byID[ch.ID] = ch
|
||||
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.channelByGroupID[gid] = ch
|
||||
platform := groupPlatforms[gid]
|
||||
@@ -315,20 +301,33 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
|
||||
expandMappingToCache(cache, ch, gid, platform)
|
||||
}
|
||||
}
|
||||
return cache
|
||||
|
||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// invalidateCache 使缓存失效,让下次读取时自然重建
|
||||
|
||||
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
|
||||
// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。
|
||||
// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
|
||||
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
|
||||
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
|
||||
return groupPlatform == pricingPlatform
|
||||
if groupPlatform == pricingPlatform {
|
||||
return true
|
||||
}
|
||||
if groupPlatform == PlatformAntigravity {
|
||||
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchingPlatforms 返回分组平台对应的可匹配平台列表。
|
||||
// 各平台严格独立,只返回自身。
|
||||
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
|
||||
func matchingPlatforms(groupPlatform string) []string {
|
||||
if groupPlatform == PlatformAntigravity {
|
||||
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
|
||||
}
|
||||
return []string{groupPlatform}
|
||||
}
|
||||
func (s *ChannelService) invalidateCache() {
|
||||
@@ -365,8 +364,10 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
|
||||
return ""
|
||||
}
|
||||
|
||||
// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
|
||||
// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
|
||||
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
|
||||
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
|
||||
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
|
||||
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
|
||||
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
||||
@@ -383,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
|
||||
return nil
|
||||
}
|
||||
|
||||
// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。
|
||||
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
|
||||
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
|
||||
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
@@ -441,7 +442,8 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
|
||||
}
|
||||
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
|
||||
// 各平台严格独立,只在本平台内查找定价。
|
||||
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
|
||||
// 确保跨平台同名模型各自独立匹配。
|
||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
@@ -479,10 +481,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
|
||||
}
|
||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
||||
if lk == nil {
|
||||
return false
|
||||
}
|
||||
@@ -525,7 +524,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
|
||||
}
|
||||
|
||||
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
|
||||
// 只在本平台的定价列表中查找。
|
||||
// antigravity 分组依次尝试所有匹配平台的定价列表。
|
||||
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
||||
if !lk.channel.RestrictModels {
|
||||
return false
|
||||
@@ -553,91 +552,6 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||
return newBody
|
||||
}
|
||||
|
||||
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
|
||||
// Create 和 Update 共用此函数,避免重复。
|
||||
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
|
||||
if err := validateNoConflictingModels(pricing); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePricingIntervals(pricing); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateNoConflictingMappings(mapping); err != nil {
|
||||
return err
|
||||
}
|
||||
return validatePricingBillingMode(pricing)
|
||||
}
|
||||
|
||||
// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
|
||||
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
|
||||
for _, p := range pricing {
|
||||
if err := checkBillingModeRequirements(p); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkPricesNotNegative(p); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkIntervalsHavePrices(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkBillingModeRequirements(p ChannelModelPricing) error {
|
||||
if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
|
||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||
return infraerrors.BadRequest(
|
||||
"BILLING_MODE_MISSING_PRICE",
|
||||
"per-request price or intervals required for per_request/image billing mode",
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkPricesNotNegative(p ChannelModelPricing) error {
|
||||
checks := []struct {
|
||||
field string
|
||||
val *float64
|
||||
}{
|
||||
{"input_price", p.InputPrice},
|
||||
{"output_price", p.OutputPrice},
|
||||
{"cache_write_price", p.CacheWritePrice},
|
||||
{"cache_read_price", p.CacheReadPrice},
|
||||
{"image_output_price", p.ImageOutputPrice},
|
||||
{"per_request_price", p.PerRequestPrice},
|
||||
}
|
||||
for _, c := range checks {
|
||||
if c.val != nil && *c.val < 0 {
|
||||
return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkIntervalsHavePrices(p ChannelModelPricing) error {
|
||||
for _, iv := range p.Intervals {
|
||||
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
||||
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
||||
iv.PerRequestPrice == nil {
|
||||
return infraerrors.BadRequest(
|
||||
"INTERVAL_MISSING_PRICE",
|
||||
fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
|
||||
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatMaxTokens(max *int) string {
|
||||
if max == nil {
|
||||
return "∞"
|
||||
}
|
||||
return fmt.Sprintf("%d", *max)
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
// Create 创建渠道
|
||||
@@ -650,8 +564,15 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
|
||||
if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
// 检查分组冲突
|
||||
if len(input.GroupIDs) > 0 {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
@@ -668,7 +589,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
|
||||
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -692,112 +619,102 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
|
||||
if input.Name != "" && input.Name != channel.Name {
|
||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
channel.Name = input.Name
|
||||
}
|
||||
|
||||
if input.Description != nil {
|
||||
channel.Description = *input.Description
|
||||
}
|
||||
|
||||
if input.Status != "" {
|
||||
channel.Status = input.Status
|
||||
}
|
||||
|
||||
if input.RestrictModels != nil {
|
||||
channel.RestrictModels = *input.RestrictModels
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if input.GroupIDs != nil {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
channel.GroupIDs = *input.GroupIDs
|
||||
}
|
||||
|
||||
if input.ModelPricing != nil {
|
||||
channel.ModelPricing = *input.ModelPricing
|
||||
}
|
||||
|
||||
if input.ModelMapping != nil {
|
||||
channel.ModelMapping = input.ModelMapping
|
||||
}
|
||||
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
|
||||
var oldGroupIDs []int64
|
||||
if s.authCacheInvalidator != nil {
|
||||
var err2 error
|
||||
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
|
||||
if err2 != nil {
|
||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
|
||||
}
|
||||
}
|
||||
|
||||
oldGroupIDs := s.getOldGroupIDs(ctx, id)
|
||||
|
||||
if err := s.repo.Update(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
|
||||
|
||||
// 失效新旧分组的 auth 缓存
|
||||
if s.authCacheInvalidator != nil {
|
||||
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
|
||||
for _, gid := range oldGroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
for _, gid := range channel.GroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
|
||||
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
|
||||
if input.Name != "" && input.Name != channel.Name {
|
||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return ErrChannelExists
|
||||
}
|
||||
channel.Name = input.Name
|
||||
}
|
||||
if input.Description != nil {
|
||||
channel.Description = *input.Description
|
||||
}
|
||||
if input.Status != "" {
|
||||
channel.Status = input.Status
|
||||
}
|
||||
if input.RestrictModels != nil {
|
||||
channel.RestrictModels = *input.RestrictModels
|
||||
}
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
channel.GroupIDs = *input.GroupIDs
|
||||
}
|
||||
if input.ModelPricing != nil {
|
||||
channel.ModelPricing = *input.ModelPricing
|
||||
}
|
||||
if input.ModelMapping != nil {
|
||||
channel.ModelMapping = input.ModelMapping
|
||||
}
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
|
||||
// channelID 为当前渠道 ID(Create 时传 0)。
|
||||
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return ErrGroupAlreadyInChannel
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
|
||||
func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
|
||||
if s.authCacheInvalidator == nil {
|
||||
return nil
|
||||
}
|
||||
oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
|
||||
}
|
||||
return oldGroupIDs
|
||||
}
|
||||
|
||||
// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
|
||||
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
|
||||
if s.authCacheInvalidator == nil {
|
||||
return
|
||||
}
|
||||
seen := make(map[int64]struct{})
|
||||
for _, ids := range groupIDSets {
|
||||
for _, gid := range ids {
|
||||
if _, ok := seen[gid]; ok {
|
||||
continue
|
||||
}
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete 删除渠道
|
||||
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
// 先获取关联分组用于失效缓存
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
||||
@@ -808,7 +725,12 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
s.invalidateAuthCacheForGroups(ctx, groupIDs)
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1234,11 +1234,6 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 渠道定价限制预检查(requested / channel_mapped 基准)
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
@@ -1257,6 +1252,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
platform = PlatformAnthropic
|
||||
}
|
||||
|
||||
// Claude Code 限制可能已将 groupID 解析为 fallback group,
|
||||
// 渠道限制预检查必须使用解析后的分组。
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
slog.Warn("channel pricing restriction blocked request",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel)
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||
@@ -1273,11 +1277,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
|
||||
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
|
||||
// 渠道定价限制预检查(requested / channel_mapped 基准)
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
// 调试日志:记录调度入口参数
|
||||
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
||||
for id := range excludedIDs {
|
||||
@@ -1298,6 +1297,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
|
||||
// Claude Code 限制可能已将 groupID 解析为 fallback group,
|
||||
// 渠道限制预检查必须使用解析后的分组。
|
||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||
slog.Warn("channel pricing restriction blocked request",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel)
|
||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||
}
|
||||
|
||||
var stickyAccountID int64
|
||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
|
||||
stickyAccountID = prefetch
|
||||
@@ -3004,7 +3012,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -3359,7 +3367,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
return account, nil
|
||||
}
|
||||
@@ -3383,7 +3391,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
@@ -8453,6 +8460,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
|
||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||
}
|
||||
|
||||
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
|
||||
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
|
||||
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
|
||||
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
|
||||
if groupID == nil {
|
||||
return false
|
||||
}
|
||||
if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
|
||||
return false
|
||||
}
|
||||
return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
|
||||
}
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
|
||||
Reference in New Issue
Block a user