fix: resolve 5 audit findings in channel/credits/scheduling

P0-1: Credits degraded response retry + fail-open
- Add isAntigravityDegradedResponse() to detect transient API failures
- Retry up to 3 times with exponential backoff (500ms/1s/2s)
- Invalidate singleflight cache between retries
- Fail-open after exhausting retries instead of 5h circuit break

P1-1: Fix channel restriction pre-check timing conflict
- Swap checkClaudeCodeRestriction before checkChannelPricingRestriction
- Ensures channel restriction is checked against final fallback groupID

P1-2: Add interval pricing validation (frontend + backend)
- Backend: ValidateIntervals() with boundary, price, overlap checks
- Frontend: validateIntervals() with Chinese error messages
- Rules: MinTokens>=0, MaxTokens>MinTokens, prices>=0, no overlap

P2: Fix cross-platform same-model pricing/mapping override
- Store cache keys using original platform instead of group platform
- Lookup across matching platforms (antigravity→anthropic→gemini)
- Prevents anthropic/gemini same-name models from overwriting each other
This commit is contained in:
erio
2026-04-02 20:28:04 +08:00
parent 6d3ea64a35
commit 71f61bbc47
12 changed files with 1028 additions and 48 deletions

View File

@@ -198,13 +198,18 @@ func newEmptyChannelCache() *channelCache {
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
// 缓存 key 使用定价条目的原始平台pricing.Platform而非分组平台
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
if !isPlatformPricingMatch(platform, pricing.Platform) {
continue // 跳过非本平台的定价
}
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
// 使用定价条目的原始平台作为缓存 key防止跨平台同名模型冲突
pricingPlatform := pricing.Platform
gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform}
for _, model := range pricing.Models {
if strings.HasSuffix(model, "*") {
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
@@ -213,7 +218,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
pricing: pricing,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
@@ -222,13 +227,15 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型。
// 缓存 key 使用映射条目的原始平台mappingPlatform避免跨平台同名映射覆盖。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
if !ok {
continue
}
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
// 使用映射条目的原始平台作为缓存 key防止跨平台同名映射冲突
gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform}
for src, dst := range platformMapping {
if strings.HasSuffix(src, "*") {
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
@@ -237,7 +244,7 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
target: dst,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
}
}
@@ -349,6 +356,43 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
return ""
}
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
// matchingPlatforms() 返回的所有平台antigravity → anthropic → gemini
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
if pricing, ok := cache.pricingByGroupModel[key]; ok {
return pricing
}
}
// 精确查找全部失败,依次尝试通配符匹配
for _, p := range matchingPlatforms(groupPlatform) {
if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil {
return pricing
}
}
return nil
}
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
if mapped, ok := cache.mappingByGroupModel[key]; ok {
return mapped
}
}
for _, p := range matchingPlatforms(groupPlatform) {
if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" {
return mapped
}
}
return ""
}
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx)
@@ -389,7 +433,9 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
}, nil
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1)
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1)
// antigravity 分组依次尝试所有匹配平台antigravity → anthropic → gemini
// 确保跨平台同名模型各自独立匹配。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
@@ -401,14 +447,9 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
}
modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
pricing, ok := lk.cache.pricingByGroupModel[key]
if !ok {
// 精确查找失败,尝试通配符匹配
pricing = lk.cache.matchWildcard(groupID, lk.platform, modelLower)
if pricing == nil {
return nil
}
pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower)
if pricing == nil {
return nil
}
cp := pricing.Clone()
@@ -453,7 +494,8 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g
return resolveMapping(lk, *groupID, model), false
}
// resolveMapping 基于已查找的渠道信息解析模型映射
// resolveMapping 基于已查找的渠道信息解析模型映射
// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
result := ChannelMappingResult{
MappedModel: model,
@@ -465,11 +507,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
}
modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
if mapped, ok := lk.cache.mappingByGroupModel[key]; ok {
result.MappedModel = mapped
result.Mapped = true
} else if mapped := lk.cache.matchWildcardMapping(groupID, lk.platform, modelLower); mapped != "" {
if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" {
result.MappedModel = mapped
result.Mapped = true
}
@@ -477,19 +515,15 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
return result
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制
// checkRestricted 基于已查找的渠道信息检查模型是否被限制
// antigravity 分组依次尝试所有匹配平台的定价列表。
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels {
return false
}
// 检查模型是否在定价列表中
modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
if _, exists := lk.cache.pricingByGroupModel[key]; exists {
return false
}
// 精确查找失败,尝试通配符匹配
if lk.cache.matchWildcard(groupID, lk.platform, modelLower) != nil {
// 使用与查找定价相同的跨平台逻辑
if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil {
return false
}
return true
@@ -550,6 +584,9 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
@@ -624,6 +661,9 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
@@ -756,6 +796,19 @@ func validateNoConflictingMappings(mapping map[string]map[string]string) error {
return nil
}
func validatePricingIntervals(pricingList []ChannelModelPricing) error {
for _, pricing := range pricingList {
if err := ValidateIntervals(pricing.Intervals); err != nil {
return infraerrors.BadRequest(
"INVALID_PRICING_INTERVALS",
fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v",
pricing.Platform, pricing.Models, err),
)
}
}
return nil
}
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
for i := 0; i < len(entries); i++ {