diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 38d4f751..d4d4d377 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -183,6 +183,67 @@ func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) { return cache, nil } +// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化) +func newEmptyChannelCache() *channelCache { + return &channelCache{ + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), + mappingByGroupModel: make(map[channelModelKey]string), + wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), + channelByGroupID: make(map[int64]*Channel), + groupPlatform: make(map[int64]string), + byID: make(map[int64]*Channel), + } +} + +// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。 +// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。 +func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { + for j := range ch.ModelPricing { + pricing := &ch.ModelPricing[j] + if !isPlatformPricingMatch(platform, pricing.Platform) { + continue // 跳过非本平台的定价 + } + gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} + for _, model := range pricing.Models { + if strings.HasSuffix(model, "*") { + prefix := strings.ToLower(strings.TrimSuffix(model, "*")) + cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{ + prefix: prefix, + pricing: pricing, + }) + } else { + key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)} + cache.pricingByGroupModel[key] = pricing + } + } + } +} + +// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。 +// antigravity 平台同时服务 Claude 和 Gemini 模型。 +func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { + for _, mappingPlatform := range matchingPlatforms(platform) { + platformMapping, ok := ch.ModelMapping[mappingPlatform] + if !ok { + continue + } + gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} + for src, dst := range platformMapping { + if strings.HasSuffix(src, "*") { + prefix := strings.ToLower(strings.TrimSuffix(src, "*")) + cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{ + prefix: prefix, + target: dst, + }) + } else { + key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)} + cache.mappingByGroupModel[key] = dst + } + } + } +} + // buildCache 从数据库构建渠道缓存。 // 使用独立 context 避免请求取消导致空值被长期缓存。 func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { @@ -194,16 +255,8 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) if err != nil { // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 slog.Warn("failed to build channel cache", "error", err) - errorCache := &channelCache{ - pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), - wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), - mappingByGroupModel: make(map[channelModelKey]string), - wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), - channelByGroupID: make(map[int64]*Channel), - groupPlatform: make(map[int64]string), - byID: make(map[int64]*Channel), - loadedAt: time.Now().Add(-(channelCacheTTL - channelErrorTTL)), // 使剩余 TTL = errorTTL - } + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL s.cache.Store(errorCache) return nil, fmt.Errorf("list all channels: %w", err) } @@ -222,71 +275,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) } } - cache := &channelCache{ - pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), - wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), - mappingByGroupModel: make(map[channelModelKey]string), - wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), - channelByGroupID: make(map[int64]*Channel), - groupPlatform: groupPlatforms, - byID: make(map[int64]*Channel, len(channels)), - loadedAt: time.Now(), - } + cache := newEmptyChannelCache() + cache.groupPlatform = groupPlatforms + cache.byID = make(map[int64]*Channel, len(channels)) + cache.loadedAt = time.Now() for i := range channels { ch := &channels[i] cache.byID[ch.ID] = ch - // 展开到分组维度 for _, gid := range ch.GroupIDs { cache.channelByGroupID[gid] = ch - platform := groupPlatforms[gid] // e.g. "anthropic" - - // 只展开该平台的模型定价到 (groupID, platform, model) → *ChannelModelPricing - // antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目 - for j := range ch.ModelPricing { - pricing := &ch.ModelPricing[j] - if !isPlatformPricingMatch(platform, pricing.Platform) { - continue // 跳过非本平台的定价 - } - for _, model := range pricing.Models { - if strings.HasSuffix(model, "*") { - // 通配符模型 → 存入 wildcardByGroupPlatform - prefix := strings.ToLower(strings.TrimSuffix(model, "*")) - gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} - cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{ - prefix: prefix, - pricing: pricing, - }) - } else { - key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)} - cache.pricingByGroupModel[key] = pricing - } - } - } - - // 只展开该平台的模型映射到 (groupID, platform, model) → target - // antigravity 平台同时服务 Claude 和 Gemini 模型 - for _, mappingPlatform := range matchingPlatforms(platform) { - platformMapping, ok := ch.ModelMapping[mappingPlatform] - if !ok { - continue - } - for src, dst := range platformMapping { - if strings.HasSuffix(src, "*") { - // 通配符映射 → 存入 wildcardMappingByGP - prefix := strings.ToLower(strings.TrimSuffix(src, "*")) - gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} - cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{ - prefix: prefix, - target: dst, - }) - } else { - key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)} - cache.mappingByGroupModel[key] = dst - } - } - } + platform := groupPlatforms[gid] + expandPricingToCache(cache, ch, gid, platform) + expandMappingToCache(cache, ch, gid, platform) } } @@ -362,26 +364,48 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) return ch.Clone(), nil } +// channelLookup 热路径公共查找结果 +type channelLookup struct { + cache *channelCache + channel *Channel + platform string +} + +// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。 +// 返回 nil 且 err==nil 表示分组无活跃渠道;err!=nil 表示缓存加载失败。 +func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) { + cache, err := s.loadCache(ctx) + if err != nil { + return nil, err + } + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() { + return nil, nil + } + return &channelLookup{ + cache: cache, + channel: ch, + platform: cache.groupPlatform[groupID], + }, nil +} + // GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1)) func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { - cache, err := s.loadCache(ctx) + lk, err := s.lookupGroupChannel(ctx, groupID) if err != nil { slog.Warn("failed to load channel cache", "group_id", groupID, "error", err) return nil } - - // 检查渠道是否启用 - ch, ok := cache.channelByGroupID[groupID] - if !ok || !ch.IsActive() { + if lk == nil { return nil } - platform := cache.groupPlatform[groupID] - key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} - pricing, ok := cache.pricingByGroupModel[key] + modelLower := strings.ToLower(model) + key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} + pricing, ok := lk.cache.pricingByGroupModel[key] if !ok { // 精确查找失败,尝试通配符匹配 - pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model)) + pricing = lk.cache.matchWildcard(groupID, lk.platform, modelLower) if pricing == nil { return nil } @@ -394,31 +418,57 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int // ResolveChannelMapping 解析渠道级模型映射(热路径 O(1)) // 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。 func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { - cache, err := s.loadCache(ctx) - if err != nil { + lk, _ := s.lookupGroupChannel(ctx, groupID) + if lk == nil { return ChannelMappingResult{MappedModel: model} } + return resolveMapping(lk, groupID, model) +} - ch, ok := cache.channelByGroupID[groupID] - if !ok || !ch.IsActive() { - return ChannelMappingResult{MappedModel: model} +// IsModelRestricted 检查模型是否被渠道限制。 +// 返回 true 表示模型被限制(不在允许列表中)。 +// 如果渠道未启用模型限制或分组无渠道关联,返回 false。 +func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + lk, _ := s.lookupGroupChannel(ctx, groupID) + if lk == nil { + return false } + return checkRestricted(lk, groupID, model) +} +// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。 +// 返回映射结果和是否被限制。groupID 为 nil 时跳过。 +func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + if groupID == nil { + return ChannelMappingResult{MappedModel: model}, false + } + lk, _ := s.lookupGroupChannel(ctx, *groupID) + if lk == nil { + return ChannelMappingResult{MappedModel: model}, false + } + // 先用原始模型检查定价列表限制,再做映射 + restricted := checkRestricted(lk, *groupID, model) + mapping := resolveMapping(lk, *groupID, model) + return mapping, restricted +} + +// resolveMapping 基于已查找的渠道信息解析模型映射 +func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult { result := ChannelMappingResult{ MappedModel: model, - ChannelID: ch.ID, - BillingModelSource: ch.BillingModelSource, + ChannelID: lk.channel.ID, + BillingModelSource: lk.channel.BillingModelSource, } if result.BillingModelSource == "" { result.BillingModelSource = BillingModelSourceChannelMapped } - platform := cache.groupPlatform[groupID] - key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} - if mapped, ok := cache.mappingByGroupModel[key]; ok { + modelLower := strings.ToLower(model) + key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} + if mapped, ok := lk.cache.mappingByGroupModel[key]; ok { result.MappedModel = mapped result.Mapped = true - } else if mapped := cache.matchWildcardMapping(groupID, platform, strings.ToLower(model)); mapped != "" { + } else if mapped := lk.cache.matchWildcardMapping(groupID, lk.platform, modelLower); mapped != "" { result.MappedModel = mapped result.Mapped = true } @@ -426,48 +476,24 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 return result } -// IsModelRestricted 检查模型是否被渠道限制。 -// 返回 true 表示模型被限制(不在允许列表中)。 -// 如果渠道未启用模型限制或分组无渠道关联,返回 false。 -func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { - cache, err := s.loadCache(ctx) - if err != nil { - return false // 缓存加载失败时不限制 - } - - ch, ok := cache.channelByGroupID[groupID] - if !ok || !ch.IsActive() || !ch.RestrictModels { +// checkRestricted 基于已查找的渠道信息检查模型是否被限制 +func checkRestricted(lk *channelLookup, groupID int64, model string) bool { + if !lk.channel.RestrictModels { return false } - // 检查模型是否在定价列表中 - platform := cache.groupPlatform[groupID] - key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} - _, exists := cache.pricingByGroupModel[key] - if exists { + modelLower := strings.ToLower(model) + key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} + if _, exists := lk.cache.pricingByGroupModel[key]; exists { return false } // 精确查找失败,尝试通配符匹配 - if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil { + if lk.cache.matchWildcard(groupID, lk.platform, modelLower) != nil { return false } return true } -// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。 -// 返回映射结果和是否被限制。groupID 为 nil 时跳过。 -func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { - var mapping ChannelMappingResult - mapping.MappedModel = model - if groupID == nil { - return mapping, false - } - // 先用原始模型检查定价列表限制,再做映射 - restricted := s.IsModelRestricted(ctx, *groupID, model) - mapping = s.ResolveChannelMapping(ctx, *groupID, model) - return mapping, restricted -} - // ReplaceModelInBody 替换请求体 JSON 中的 model 字段。 func ReplaceModelInBody(body []byte, newModel string) []byte { if len(body) == 0 {