refactor: split buildCache into sub-functions, reduce nesting 5→2
- Extract newEmptyChannelCache() factory to deduplicate map init - Extract expandPricingToCache() for model pricing expansion - Extract expandMappingToCache() for model mapping expansion - buildCache reduced from 110 to 50 lines
This commit is contained in:
@@ -183,6 +183,67 @@ func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) {
|
|||||||
return cache, nil
|
return cache, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化)
|
||||||
|
func newEmptyChannelCache() *channelCache {
|
||||||
|
return &channelCache{
|
||||||
|
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
||||||
|
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
||||||
|
mappingByGroupModel: make(map[channelModelKey]string),
|
||||||
|
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
||||||
|
channelByGroupID: make(map[int64]*Channel),
|
||||||
|
groupPlatform: make(map[int64]string),
|
||||||
|
byID: make(map[int64]*Channel),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
||||||
|
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
||||||
|
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||||
|
for j := range ch.ModelPricing {
|
||||||
|
pricing := &ch.ModelPricing[j]
|
||||||
|
if !isPlatformPricingMatch(platform, pricing.Platform) {
|
||||||
|
continue // 跳过非本平台的定价
|
||||||
|
}
|
||||||
|
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
||||||
|
for _, model := range pricing.Models {
|
||||||
|
if strings.HasSuffix(model, "*") {
|
||||||
|
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
||||||
|
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
|
||||||
|
prefix: prefix,
|
||||||
|
pricing: pricing,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
||||||
|
cache.pricingByGroupModel[key] = pricing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
||||||
|
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
||||||
|
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||||
|
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||||
|
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
||||||
|
for src, dst := range platformMapping {
|
||||||
|
if strings.HasSuffix(src, "*") {
|
||||||
|
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
|
||||||
|
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
|
||||||
|
prefix: prefix,
|
||||||
|
target: dst,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
|
||||||
|
cache.mappingByGroupModel[key] = dst
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// buildCache 从数据库构建渠道缓存。
|
// buildCache 从数据库构建渠道缓存。
|
||||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||||
@@ -194,16 +255,8 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||||
slog.Warn("failed to build channel cache", "error", err)
|
slog.Warn("failed to build channel cache", "error", err)
|
||||||
errorCache := &channelCache{
|
errorCache := newEmptyChannelCache()
|
||||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
|
||||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
|
||||||
mappingByGroupModel: make(map[channelModelKey]string),
|
|
||||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
|
||||||
channelByGroupID: make(map[int64]*Channel),
|
|
||||||
groupPlatform: make(map[int64]string),
|
|
||||||
byID: make(map[int64]*Channel),
|
|
||||||
loadedAt: time.Now().Add(-(channelCacheTTL - channelErrorTTL)), // 使剩余 TTL = errorTTL
|
|
||||||
}
|
|
||||||
s.cache.Store(errorCache)
|
s.cache.Store(errorCache)
|
||||||
return nil, fmt.Errorf("list all channels: %w", err)
|
return nil, fmt.Errorf("list all channels: %w", err)
|
||||||
}
|
}
|
||||||
@@ -222,71 +275,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cache := &channelCache{
|
cache := newEmptyChannelCache()
|
||||||
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
|
cache.groupPlatform = groupPlatforms
|
||||||
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
|
cache.byID = make(map[int64]*Channel, len(channels))
|
||||||
mappingByGroupModel: make(map[channelModelKey]string),
|
cache.loadedAt = time.Now()
|
||||||
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
|
|
||||||
channelByGroupID: make(map[int64]*Channel),
|
|
||||||
groupPlatform: groupPlatforms,
|
|
||||||
byID: make(map[int64]*Channel, len(channels)),
|
|
||||||
loadedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
ch := &channels[i]
|
ch := &channels[i]
|
||||||
cache.byID[ch.ID] = ch
|
cache.byID[ch.ID] = ch
|
||||||
|
|
||||||
// 展开到分组维度
|
|
||||||
for _, gid := range ch.GroupIDs {
|
for _, gid := range ch.GroupIDs {
|
||||||
cache.channelByGroupID[gid] = ch
|
cache.channelByGroupID[gid] = ch
|
||||||
platform := groupPlatforms[gid] // e.g. "anthropic"
|
platform := groupPlatforms[gid]
|
||||||
|
expandPricingToCache(cache, ch, gid, platform)
|
||||||
// 只展开该平台的模型定价到 (groupID, platform, model) → *ChannelModelPricing
|
expandMappingToCache(cache, ch, gid, platform)
|
||||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目
|
|
||||||
for j := range ch.ModelPricing {
|
|
||||||
pricing := &ch.ModelPricing[j]
|
|
||||||
if !isPlatformPricingMatch(platform, pricing.Platform) {
|
|
||||||
continue // 跳过非本平台的定价
|
|
||||||
}
|
|
||||||
for _, model := range pricing.Models {
|
|
||||||
if strings.HasSuffix(model, "*") {
|
|
||||||
// 通配符模型 → 存入 wildcardByGroupPlatform
|
|
||||||
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
|
||||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
|
||||||
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
|
|
||||||
prefix: prefix,
|
|
||||||
pricing: pricing,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
|
||||||
cache.pricingByGroupModel[key] = pricing
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 只展开该平台的模型映射到 (groupID, platform, model) → target
|
|
||||||
// antigravity 平台同时服务 Claude 和 Gemini 模型
|
|
||||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
|
||||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for src, dst := range platformMapping {
|
|
||||||
if strings.HasSuffix(src, "*") {
|
|
||||||
// 通配符映射 → 存入 wildcardMappingByGP
|
|
||||||
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
|
|
||||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
|
||||||
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
|
|
||||||
prefix: prefix,
|
|
||||||
target: dst,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
|
|
||||||
cache.mappingByGroupModel[key] = dst
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,26 +364,48 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
|
|||||||
return ch.Clone(), nil
|
return ch.Clone(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// channelLookup 热路径公共查找结果
|
||||||
|
type channelLookup struct {
|
||||||
|
cache *channelCache
|
||||||
|
channel *Channel
|
||||||
|
platform string
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。
|
||||||
|
// 返回 nil 且 err==nil 表示分组无活跃渠道;err!=nil 表示缓存加载失败。
|
||||||
|
func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) {
|
||||||
|
cache, err := s.loadCache(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ch, ok := cache.channelByGroupID[groupID]
|
||||||
|
if !ok || !ch.IsActive() {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return &channelLookup{
|
||||||
|
cache: cache,
|
||||||
|
channel: ch,
|
||||||
|
platform: cache.groupPlatform[groupID],
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))
|
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))
|
||||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||||
cache, err := s.loadCache(ctx)
|
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
|
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if lk == nil {
|
||||||
// 检查渠道是否启用
|
|
||||||
ch, ok := cache.channelByGroupID[groupID]
|
|
||||||
if !ok || !ch.IsActive() {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
platform := cache.groupPlatform[groupID]
|
modelLower := strings.ToLower(model)
|
||||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
||||||
pricing, ok := cache.pricingByGroupModel[key]
|
pricing, ok := lk.cache.pricingByGroupModel[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
// 精确查找失败,尝试通配符匹配
|
// 精确查找失败,尝试通配符匹配
|
||||||
pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model))
|
pricing = lk.cache.matchWildcard(groupID, lk.platform, modelLower)
|
||||||
if pricing == nil {
|
if pricing == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -394,31 +418,57 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
|
|||||||
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
|
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
|
||||||
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
|
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
|
||||||
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
|
||||||
cache, err := s.loadCache(ctx)
|
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
||||||
if err != nil {
|
if lk == nil {
|
||||||
return ChannelMappingResult{MappedModel: model}
|
return ChannelMappingResult{MappedModel: model}
|
||||||
}
|
}
|
||||||
|
return resolveMapping(lk, groupID, model)
|
||||||
|
}
|
||||||
|
|
||||||
ch, ok := cache.channelByGroupID[groupID]
|
// IsModelRestricted 检查模型是否被渠道限制。
|
||||||
if !ok || !ch.IsActive() {
|
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||||
return ChannelMappingResult{MappedModel: model}
|
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||||
|
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||||
|
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
||||||
|
if lk == nil {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
return checkRestricted(lk, groupID, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
|
||||||
|
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
|
||||||
|
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||||
|
if groupID == nil {
|
||||||
|
return ChannelMappingResult{MappedModel: model}, false
|
||||||
|
}
|
||||||
|
lk, _ := s.lookupGroupChannel(ctx, *groupID)
|
||||||
|
if lk == nil {
|
||||||
|
return ChannelMappingResult{MappedModel: model}, false
|
||||||
|
}
|
||||||
|
// 先用原始模型检查定价列表限制,再做映射
|
||||||
|
restricted := checkRestricted(lk, *groupID, model)
|
||||||
|
mapping := resolveMapping(lk, *groupID, model)
|
||||||
|
return mapping, restricted
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveMapping 基于已查找的渠道信息解析模型映射
|
||||||
|
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
|
||||||
result := ChannelMappingResult{
|
result := ChannelMappingResult{
|
||||||
MappedModel: model,
|
MappedModel: model,
|
||||||
ChannelID: ch.ID,
|
ChannelID: lk.channel.ID,
|
||||||
BillingModelSource: ch.BillingModelSource,
|
BillingModelSource: lk.channel.BillingModelSource,
|
||||||
}
|
}
|
||||||
if result.BillingModelSource == "" {
|
if result.BillingModelSource == "" {
|
||||||
result.BillingModelSource = BillingModelSourceChannelMapped
|
result.BillingModelSource = BillingModelSourceChannelMapped
|
||||||
}
|
}
|
||||||
|
|
||||||
platform := cache.groupPlatform[groupID]
|
modelLower := strings.ToLower(model)
|
||||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
||||||
if mapped, ok := cache.mappingByGroupModel[key]; ok {
|
if mapped, ok := lk.cache.mappingByGroupModel[key]; ok {
|
||||||
result.MappedModel = mapped
|
result.MappedModel = mapped
|
||||||
result.Mapped = true
|
result.Mapped = true
|
||||||
} else if mapped := cache.matchWildcardMapping(groupID, platform, strings.ToLower(model)); mapped != "" {
|
} else if mapped := lk.cache.matchWildcardMapping(groupID, lk.platform, modelLower); mapped != "" {
|
||||||
result.MappedModel = mapped
|
result.MappedModel = mapped
|
||||||
result.Mapped = true
|
result.Mapped = true
|
||||||
}
|
}
|
||||||
@@ -426,48 +476,24 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsModelRestricted 检查模型是否被渠道限制。
|
// checkRestricted 基于已查找的渠道信息检查模型是否被限制
|
||||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
||||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
if !lk.channel.RestrictModels {
|
||||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
|
||||||
cache, err := s.loadCache(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false // 缓存加载失败时不限制
|
|
||||||
}
|
|
||||||
|
|
||||||
ch, ok := cache.channelByGroupID[groupID]
|
|
||||||
if !ok || !ch.IsActive() || !ch.RestrictModels {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查模型是否在定价列表中
|
// 检查模型是否在定价列表中
|
||||||
platform := cache.groupPlatform[groupID]
|
modelLower := strings.ToLower(model)
|
||||||
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
|
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
||||||
_, exists := cache.pricingByGroupModel[key]
|
if _, exists := lk.cache.pricingByGroupModel[key]; exists {
|
||||||
if exists {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// 精确查找失败,尝试通配符匹配
|
// 精确查找失败,尝试通配符匹配
|
||||||
if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil {
|
if lk.cache.matchWildcard(groupID, lk.platform, modelLower) != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
|
|
||||||
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
|
|
||||||
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
|
||||||
var mapping ChannelMappingResult
|
|
||||||
mapping.MappedModel = model
|
|
||||||
if groupID == nil {
|
|
||||||
return mapping, false
|
|
||||||
}
|
|
||||||
// 先用原始模型检查定价列表限制,再做映射
|
|
||||||
restricted := s.IsModelRestricted(ctx, *groupID, model)
|
|
||||||
mapping = s.ResolveChannelMapping(ctx, *groupID, model)
|
|
||||||
return mapping, restricted
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
|
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
|
||||||
func ReplaceModelInBody(body []byte, newModel string) []byte {
|
func ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
|
|||||||
Reference in New Issue
Block a user