refactor: split buildCache into sub-functions, reduce nesting 5→2
- Extract newEmptyChannelCache() factory to deduplicate map init - Extract expandPricingToCache() for model pricing expansion - Extract expandMappingToCache() for model mapping expansion - buildCache reduced from 110 to 50 lines
This commit is contained in:
@@ -183,6 +183,67 @@ func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) {
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化)
|
||||
func newEmptyChannelCache() *channelCache {
|
||||
return &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
}
|
||||
}
|
||||
|
||||
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
||||
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for j := range ch.ModelPricing {
|
||||
pricing := &ch.ModelPricing[j]
|
||||
if !isPlatformPricingMatch(platform, pricing.Platform) {
|
||||
continue // 跳过非本平台的定价
|
||||
}
|
||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
||||
for _, model := range pricing.Models {
|
||||
if strings.HasSuffix(model, "*") {
|
||||
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
||||
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
|
||||
prefix: prefix,
|
||||
pricing: pricing,
|
||||
})
|
||||
} else {
|
||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
||||
cache.pricingByGroupModel[key] = pricing
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
||||
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
||||
for src, dst := range platformMapping {
|
||||
if strings.HasSuffix(src, "*") {
|
||||
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
|
||||
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
|
||||
prefix: prefix,
|
||||
target: dst,
|
||||
})
|
||||
} else {
|
||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
|
||||
cache.mappingByGroupModel[key] = dst
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildCache 从数据库构建渠道缓存。
|
||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||
@@ -194,16 +255,8 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
if err != nil {
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: make(map[int64]string),
|
||||
byID: make(map[int64]*Channel),
|
||||
loadedAt: time.Now().Add(-(channelCacheTTL - channelErrorTTL)), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
@@ -222,71 +275,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||
mappingByGroupModel: make(map[channelModelKey]string),
|
||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||
channelByGroupID: make(map[int64]*Channel),
|
||||
groupPlatform: groupPlatforms,
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
cache := newEmptyChannelCache()
|
||||
cache.groupPlatform = groupPlatforms
|
||||
cache.byID = make(map[int64]*Channel, len(channels))
|
||||
cache.loadedAt = time.Now()
|
||||
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
cache.byID[ch.ID] = ch
|
||||
|
||||
// 展开到分组维度
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.channelByGroupID[gid] = ch
|
||||
platform := groupPlatforms[gid] // e.g. "anthropic"
|
||||
|
||||
// 只展开该平台的模型定价到 (groupID, platform, model) → *ChannelModelPricing
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目
|
||||
for j := range ch.ModelPricing {
|
||||
pricing := &ch.ModelPricing[j]
|
||||
if !isPlatformPricingMatch(platform, pricing.Platform) {
|
||||
continue // 跳过非本平台的定价
|
||||
}
|
||||
for _, model := range pricing.Models {
|
||||
if strings.HasSuffix(model, "*") {
|
||||
// 通配符模型 → 存入 wildcardByGroupPlatform
|
||||
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
||||
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
|
||||
prefix: prefix,
|
||||
pricing: pricing,
|
||||
})
|
||||
} else {
|
||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
||||
cache.pricingByGroupModel[key] = pricing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 只展开该平台的模型映射到 (groupID, platform, model) → target
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型
|
||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for src, dst := range platformMapping {
|
||||
if strings.HasSuffix(src, "*") {
|
||||
// 通配符映射 → 存入 wildcardMappingByGP
|
||||
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
|
||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
||||
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
|
||||
prefix: prefix,
|
||||
target: dst,
|
||||
})
|
||||
} else {
|
||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
|
||||
cache.mappingByGroupModel[key] = dst
|
||||
}
|
||||
}
|
||||
}
|
||||
platform := groupPlatforms[gid]
|
||||
expandPricingToCache(cache, ch, gid, platform)
|
||||
expandMappingToCache(cache, ch, gid, platform)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,26 +364,48 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
|
||||
return ch.Clone(), nil
|
||||
}
|
||||
|
||||
// channelLookup 热路径公共查找结果
|
||||
type channelLookup struct {
|
||||
cache *channelCache
|
||||
channel *Channel
|
||||
platform string
|
||||
}
|
||||
|
||||
// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。
|
||||
// 返回 nil 且 err==nil 表示分组无活跃渠道;err!=nil 表示缓存加载失败。
|
||||
func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() {
|
||||
return nil, nil
|
||||
}
|
||||
return &channelLookup{
|
||||
cache: cache,
|
||||
channel: ch,
|
||||
platform: cache.groupPlatform[groupID],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))
|
||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||
cache, err := s.loadCache(ctx)
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查渠道是否启用
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() {
|
||||
if lk == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
platform := cache.groupPlatform[groupID]
|
||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||
pricing, ok := cache.pricingByGroupModel[key]
|
||||
modelLower := strings.ToLower(model)
|
||||
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
||||
pricing, ok := lk.cache.pricingByGroupModel[key]
|
||||
if !ok {
|
||||
// 精确查找失败,尝试通配符匹配
|
||||
pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model))
|
||||
pricing = lk.cache.matchWildcard(groupID, lk.platform, modelLower)
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -394,31 +418,57 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
|
||||
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
|
||||
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
|
||||
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
||||
if lk == nil {
|
||||
return ChannelMappingResult{MappedModel: model}
|
||||
}
|
||||
return resolveMapping(lk, groupID, model)
|
||||
}
|
||||
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() {
|
||||
return ChannelMappingResult{MappedModel: model}
|
||||
// IsModelRestricted 检查模型是否被渠道限制。
|
||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
||||
if lk == nil {
|
||||
return false
|
||||
}
|
||||
return checkRestricted(lk, groupID, model)
|
||||
}
|
||||
|
||||
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
|
||||
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
|
||||
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||
if groupID == nil {
|
||||
return ChannelMappingResult{MappedModel: model}, false
|
||||
}
|
||||
lk, _ := s.lookupGroupChannel(ctx, *groupID)
|
||||
if lk == nil {
|
||||
return ChannelMappingResult{MappedModel: model}, false
|
||||
}
|
||||
// 先用原始模型检查定价列表限制,再做映射
|
||||
restricted := checkRestricted(lk, *groupID, model)
|
||||
mapping := resolveMapping(lk, *groupID, model)
|
||||
return mapping, restricted
|
||||
}
|
||||
|
||||
// resolveMapping 基于已查找的渠道信息解析模型映射
|
||||
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
|
||||
result := ChannelMappingResult{
|
||||
MappedModel: model,
|
||||
ChannelID: ch.ID,
|
||||
BillingModelSource: ch.BillingModelSource,
|
||||
ChannelID: lk.channel.ID,
|
||||
BillingModelSource: lk.channel.BillingModelSource,
|
||||
}
|
||||
if result.BillingModelSource == "" {
|
||||
result.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
|
||||
platform := cache.groupPlatform[groupID]
|
||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||
if mapped, ok := cache.mappingByGroupModel[key]; ok {
|
||||
modelLower := strings.ToLower(model)
|
||||
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
||||
if mapped, ok := lk.cache.mappingByGroupModel[key]; ok {
|
||||
result.MappedModel = mapped
|
||||
result.Mapped = true
|
||||
} else if mapped := cache.matchWildcardMapping(groupID, platform, strings.ToLower(model)); mapped != "" {
|
||||
} else if mapped := lk.cache.matchWildcardMapping(groupID, lk.platform, modelLower); mapped != "" {
|
||||
result.MappedModel = mapped
|
||||
result.Mapped = true
|
||||
}
|
||||
@@ -426,48 +476,24 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
||||
return result
|
||||
}
|
||||
|
||||
// IsModelRestricted 检查模型是否被渠道限制。
|
||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return false // 缓存加载失败时不限制
|
||||
}
|
||||
|
||||
ch, ok := cache.channelByGroupID[groupID]
|
||||
if !ok || !ch.IsActive() || !ch.RestrictModels {
|
||||
// checkRestricted 基于已查找的渠道信息检查模型是否被限制
|
||||
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
||||
if !lk.channel.RestrictModels {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查模型是否在定价列表中
|
||||
platform := cache.groupPlatform[groupID]
|
||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
||||
_, exists := cache.pricingByGroupModel[key]
|
||||
if exists {
|
||||
modelLower := strings.ToLower(model)
|
||||
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
||||
if _, exists := lk.cache.pricingByGroupModel[key]; exists {
|
||||
return false
|
||||
}
|
||||
// 精确查找失败,尝试通配符匹配
|
||||
if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil {
|
||||
if lk.cache.matchWildcard(groupID, lk.platform, modelLower) != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
|
||||
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
|
||||
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||
var mapping ChannelMappingResult
|
||||
mapping.MappedModel = model
|
||||
if groupID == nil {
|
||||
return mapping, false
|
||||
}
|
||||
// 先用原始模型检查定价列表限制,再做映射
|
||||
restricted := s.IsModelRestricted(ctx, *groupID, model)
|
||||
mapping = s.ResolveChannelMapping(ctx, *groupID, model)
|
||||
return mapping, restricted
|
||||
}
|
||||
|
||||
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
|
||||
func ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||
if len(body) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user