diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index e89986b5..50deba8b 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -855,6 +855,13 @@ func (s *AccountUsageService) GetAntigravityCredits(ctx context.Context, account return s.getAntigravityUsage(ctx, account) } +// InvalidateAntigravityCreditsCache 清除指定账号的 Antigravity 用量缓存, +// 使下次调用 GetAntigravityCredits 时强制重新拉取。 +// 用于 credits 降级响应重试场景:避免重试命中同一个降级缓存。 +func (s *AccountUsageService) InvalidateAntigravityCreditsCache(accountID int64) { + s.cache.antigravityCache.Delete(accountID) +} + // recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds // 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数 func recalcAntigravityRemainingSeconds(info *UsageInfo) { diff --git a/backend/internal/service/antigravity_credits_overages.go b/backend/internal/service/antigravity_credits_overages.go index 99ec7d08..3e19c563 100644 --- a/backend/internal/service/antigravity_credits_overages.go +++ b/backend/internal/service/antigravity_credits_overages.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "io" + "log/slog" "net/http" "strings" "time" @@ -17,33 +18,116 @@ const ( // 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。 creditsExhaustedKey = "AICredits" creditsExhaustedDuration = 5 * time.Hour + + // credits 降级响应重试参数 + creditsRetryMaxAttempts = 3 + creditsRetryBaseInterval = 500 * time.Millisecond ) +// creditsRetryableErrorCodes 是降级响应中可重试的错误码集合。 +// forbidden 是稳定的封号状态,不属于可恢复的瞬态错误,不重试。 +var creditsRetryableErrorCodes = map[string]bool{ + errorCodeUnauthenticated: true, + errorCodeRateLimited: true, + errorCodeNetworkError: true, +} + +// isAntigravityDegradedResponse 检查 UsageInfo 是否为可重试的降级响应。 +// 仅检测 3 个瞬态错误码(unauthenticated/rate_limited/network_error), +// forbidden 是稳定的封号状态,不属于降级。 +func isAntigravityDegradedResponse(info *UsageInfo) bool { + if info == nil || info.ErrorCode == "" { + return false + } + return creditsRetryableErrorCodes[info.ErrorCode] +} + // checkAccountCredits 通过共享的 AccountUsageService 缓存检查账号是否有足够的 AI Credits。 // 缓存 TTL 不足时会自动从 Google loadCodeAssist API 刷新。 -// 返回 true 表示积分可用。 +// 检测到降级响应时会清除缓存并重试,最终 fail-open(返回 true)。 func (s *AntigravityGatewayService) checkAccountCredits( ctx context.Context, account *Account, ) bool { if account == nil || account.ID == 0 { return false } - if s.accountUsageService == nil { return true // 无 usage service 时不阻断 } usageInfo, err := s.accountUsageService.GetAntigravityCredits(ctx, account) if err != nil { - logger.LegacyPrintf("service.antigravity_gateway", - "check_credits: get_credits_failed account=%d err=%v", account.ID, err) - return true // 出错时假设有积分,不阻断 + slog.Error("check_credits: get_credits_failed", + "account_id", account.ID, "error", err) + return true // 出错时 fail-open } - hasCredits := hasEnoughCredits(usageInfo) + // 非降级响应:直接检查积分余额 + if !isAntigravityDegradedResponse(usageInfo) { + return s.logCreditsResult(account, usageInfo) + } + + // 降级响应:清除缓存后重试 + return s.retryCreditsOnDegraded(ctx, account, usageInfo) +} + +// retryCreditsOnDegraded 在检测到降级响应后,清除缓存并重试获取 credits。 +// 使用指数退避(500ms → 1s → 2s),最多重试 creditsRetryMaxAttempts 次。 +// 所有重试失败后 fail-open(返回 true),不做熔断。 +func (s *AntigravityGatewayService) retryCreditsOnDegraded( + ctx context.Context, account *Account, lastInfo *UsageInfo, +) bool { + for attempt := 1; attempt <= creditsRetryMaxAttempts; attempt++ { + delay := creditsRetryBaseInterval << (attempt - 1) // 指数退避:500ms, 1s, 2s + slog.Warn("check_credits: degraded response, retrying", + "account_id", account.ID, + "attempt", attempt, + "max_attempts", creditsRetryMaxAttempts, + "error_code", lastInfo.ErrorCode, + "delay", delay, + ) + + select { + case <-ctx.Done(): + slog.Warn("check_credits: context cancelled during retry, fail-open", + "account_id", account.ID, "attempt", attempt) + return true + case <-time.After(delay): + } + + // 清除缓存,强制下次 GetAntigravityCredits 重新拉取 + s.accountUsageService.InvalidateAntigravityCreditsCache(account.ID) + + info, err := s.accountUsageService.GetAntigravityCredits(ctx, account) + if err != nil { + slog.Error("check_credits: retry get_credits_failed", + "account_id", account.ID, "attempt", attempt, "error", err) + continue + } + + // 重试成功(不再是降级响应):检查积分余额 + if !isAntigravityDegradedResponse(info) { + slog.Info("check_credits: retry succeeded", + "account_id", account.ID, "attempt", attempt) + return s.logCreditsResult(account, info) + } + lastInfo = info + } + + // 所有重试失败:fail-open,不做熔断 + slog.Warn("check_credits: all retries exhausted, fail-open", + "account_id", account.ID, + "last_error_code", lastInfo.ErrorCode, + ) + return true +} + +// logCreditsResult 检查积分并记录不足日志,返回是否有积分。 +func (s *AntigravityGatewayService) logCreditsResult(account *Account, info *UsageInfo) bool { + hasCredits := hasEnoughCredits(info) if !hasCredits { - logger.LegacyPrintf("service.antigravity_gateway", - "check_credits: account=%d has_credits=false", account.ID) + slog.Warn("check_credits: insufficient credits", + "account_id", account.ID) } return hasCredits } diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 05f96bc9..1697ed6f 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -1,6 +1,8 @@ package service import ( + "fmt" + "sort" "strings" "time" ) @@ -177,6 +179,94 @@ func (c *Channel) Clone() *Channel { return &cp } +// ValidateIntervals 校验区间列表的合法性。 +// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens; +// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义); +// 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。 +func ValidateIntervals(intervals []PricingInterval) error { + if len(intervals) == 0 { + return nil + } + sorted := make([]PricingInterval, len(intervals)) + copy(sorted, intervals) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].MinTokens < sorted[j].MinTokens + }) + + for i := range sorted { + if err := validateSingleInterval(&sorted[i], i); err != nil { + return err + } + } + return validateIntervalOverlap(sorted) +} + +// validateSingleInterval 校验单个区间的字段合法性 +func validateSingleInterval(iv *PricingInterval, idx int) error { + if iv.MinTokens < 0 { + return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens) + } + if iv.MaxTokens != nil { + if *iv.MaxTokens <= 0 { + return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens) + } + if *iv.MaxTokens <= iv.MinTokens { + return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)", + idx+1, *iv.MaxTokens, iv.MinTokens) + } + } + return validateIntervalPrices(iv, idx) +} + +// validateIntervalPrices 校验区间内所有价格字段 >= 0 +func validateIntervalPrices(iv *PricingInterval, idx int) error { + prices := []struct { + name string + val *float64 + }{ + {"input_price", iv.InputPrice}, + {"output_price", iv.OutputPrice}, + {"cache_write_price", iv.CacheWritePrice}, + {"cache_read_price", iv.CacheReadPrice}, + {"per_request_price", iv.PerRequestPrice}, + } + for _, p := range prices { + if p.val != nil && *p.val < 0 { + return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name) + } + } + return nil +} + +// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后 +func validateIntervalOverlap(sorted []PricingInterval) error { + for i, iv := range sorted { + // 无界区间必须是最后一个 + if iv.MaxTokens == nil && i < len(sorted)-1 { + return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one", + i+1) + } + if i == 0 { + continue + } + prev := sorted[i-1] + // 检查重叠:前一个区间的上界 > 当前区间的下界则重叠 + // (min, max] 语义:prev 覆盖 (prev.Min, prev.Max],cur 覆盖 (cur.Min, cur.Max] + if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens { + return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d", + i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens) + } + } + return nil +} + +func formatMaxTokensLabel(max *int) string { + if max == nil { + return "∞" + } + return fmt.Sprintf("%d", *max) +} + // ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中) type ChannelUsageFields struct { ChannelID int64 // 渠道 ID(0 = 无渠道) diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 25a8d39b..cbab9bfe 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -198,13 +198,18 @@ func newEmptyChannelCache() *channelCache { // expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。 // antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。 +// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台, +// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。 +// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。 func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { for j := range ch.ModelPricing { pricing := &ch.ModelPricing[j] if !isPlatformPricingMatch(platform, pricing.Platform) { continue // 跳过非本平台的定价 } - gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} + // 使用定价条目的原始平台作为缓存 key,防止跨平台同名模型冲突 + pricingPlatform := pricing.Platform + gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform} for _, model := range pricing.Models { if strings.HasSuffix(model, "*") { prefix := strings.ToLower(strings.TrimSuffix(model, "*")) @@ -213,7 +218,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform pricing: pricing, }) } else { - key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)} + key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)} cache.pricingByGroupModel[key] = pricing } } @@ -222,13 +227,15 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform // expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。 // antigravity 平台同时服务 Claude 和 Gemini 模型。 +// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。 func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { for _, mappingPlatform := range matchingPlatforms(platform) { platformMapping, ok := ch.ModelMapping[mappingPlatform] if !ok { continue } - gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} + // 使用映射条目的原始平台作为缓存 key,防止跨平台同名映射冲突 + gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform} for src, dst := range platformMapping { if strings.HasSuffix(src, "*") { prefix := strings.ToLower(strings.TrimSuffix(src, "*")) @@ -237,7 +244,7 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform target: dst, }) } else { - key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)} + key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)} cache.mappingByGroupModel[key] = dst } } @@ -349,6 +356,43 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower return "" } +// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。 +// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试 +// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini), +// 返回第一个命中的结果。非 antigravity 平台只尝试自身。 +func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing { + for _, p := range matchingPlatforms(groupPlatform) { + key := channelModelKey{groupID: groupID, platform: p, model: modelLower} + if pricing, ok := cache.pricingByGroupModel[key]; ok { + return pricing + } + } + // 精确查找全部失败,依次尝试通配符匹配 + for _, p := range matchingPlatforms(groupPlatform) { + if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil { + return pricing + } + } + return nil +} + +// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。 +// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。 +func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string { + for _, p := range matchingPlatforms(groupPlatform) { + key := channelModelKey{groupID: groupID, platform: p, model: modelLower} + if mapped, ok := cache.mappingByGroupModel[key]; ok { + return mapped + } + } + for _, p := range matchingPlatforms(groupPlatform) { + if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" { + return mapped + } + } + return "" +} + // GetChannelForGroup 获取分组关联的渠道(热路径 O(1)) func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { cache, err := s.loadCache(ctx) @@ -389,7 +433,9 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) }, nil } -// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1)) +// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。 +// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini), +// 确保跨平台同名模型各自独立匹配。 func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { lk, err := s.lookupGroupChannel(ctx, groupID) if err != nil { @@ -401,14 +447,9 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int } modelLower := strings.ToLower(model) - key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} - pricing, ok := lk.cache.pricingByGroupModel[key] - if !ok { - // 精确查找失败,尝试通配符匹配 - pricing = lk.cache.matchWildcard(groupID, lk.platform, modelLower) - if pricing == nil { - return nil - } + pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) + if pricing == nil { + return nil } cp := pricing.Clone() @@ -453,7 +494,8 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g return resolveMapping(lk, *groupID, model), false } -// resolveMapping 基于已查找的渠道信息解析模型映射 +// resolveMapping 基于已查找的渠道信息解析模型映射。 +// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。 func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult { result := ChannelMappingResult{ MappedModel: model, @@ -465,11 +507,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi } modelLower := strings.ToLower(model) - key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} - if mapped, ok := lk.cache.mappingByGroupModel[key]; ok { - result.MappedModel = mapped - result.Mapped = true - } else if mapped := lk.cache.matchWildcardMapping(groupID, lk.platform, modelLower); mapped != "" { + if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" { result.MappedModel = mapped result.Mapped = true } @@ -477,19 +515,15 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi return result } -// checkRestricted 基于已查找的渠道信息检查模型是否被限制 +// checkRestricted 基于已查找的渠道信息检查模型是否被限制。 +// antigravity 分组依次尝试所有匹配平台的定价列表。 func checkRestricted(lk *channelLookup, groupID int64, model string) bool { if !lk.channel.RestrictModels { return false } - // 检查模型是否在定价列表中 modelLower := strings.ToLower(model) - key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} - if _, exists := lk.cache.pricingByGroupModel[key]; exists { - return false - } - // 精确查找失败,尝试通配符匹配 - if lk.cache.matchWildcard(groupID, lk.platform, modelLower) != nil { + // 使用与查找定价相同的跨平台逻辑 + if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil { return false } return true @@ -550,6 +584,9 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) if err := validateNoConflictingModels(channel.ModelPricing); err != nil { return nil, err } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { return nil, err } @@ -624,6 +661,9 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan if err := validateNoConflictingModels(channel.ModelPricing); err != nil { return nil, err } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { return nil, err } @@ -756,6 +796,19 @@ func validateNoConflictingMappings(mapping map[string]map[string]string) error { return nil } +func validatePricingIntervals(pricingList []ChannelModelPricing) error { + for _, pricing := range pricingList { + if err := ValidateIntervals(pricing.Intervals); err != nil { + return infraerrors.BadRequest( + "INVALID_PRICING_INTERVALS", + fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v", + pricing.Platform, pricing.Models, err), + ) + } + } + return nil +} + // detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误 func detectConflicts(entries []modelEntry, platform, errCode, label string) error { for i := 0; i < len(entries); i++ { diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go index 93cafa67..0232062c 100644 --- a/backend/internal/service/channel_service_test.go +++ b/backend/internal/service/channel_service_test.go @@ -1401,6 +1401,32 @@ func TestCreate_DuplicateModel(t *testing.T) { require.Contains(t, err.Error(), "claude-opus-4") } +func TestCreate_InvalidPricingIntervals(t *testing.T) { + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + ModelPricing: []ChannelModelPricing{ + { + Platform: "anthropic", + Models: []string{"claude-opus-4"}, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(2000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 1000, MaxTokens: testPtrInt(3000), InputPrice: testPtrFloat64(2e-6)}, + }, + }, + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS") + require.Contains(t, err.Error(), "overlap") +} + func TestCreate_DefaultBillingModelSource(t *testing.T) { var capturedChannel *Channel repo := &mockChannelRepository{ @@ -1592,6 +1618,37 @@ func TestUpdate_DuplicateModel(t *testing.T) { require.Contains(t, err.Error(), "claude-opus-4") } +func TestUpdate_InvalidPricingIntervals(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + } + svc := newTestChannelService(repo) + + invalidPricing := []ChannelModelPricing{ + { + Platform: "anthropic", + Models: []string{"claude-opus-4"}, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 2000, MaxTokens: testPtrInt(4000), InputPrice: testPtrFloat64(2e-6)}, + }, + }, + } + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + ModelPricing: &invalidPricing, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS") + require.Contains(t, err.Error(), "unbounded") +} + func TestUpdate_InvalidatesChannelCache(t *testing.T) { existing := &Channel{ ID: 1, @@ -1984,3 +2041,144 @@ func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) { require.Equal(t, "claude-opus-4-6", result.MappedModel) require.Equal(t, int64(1), result.ChannelID) } + +// =========================================================================== +// 11. Antigravity cross-platform same-name model — no overwrite +// =========================================================================== + +func TestGetChannelModelPricing_AntigravitySameModelDifferentPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。 + // antigravity 分组应能分别查到各自的定价,而不是后者覆盖前者。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 200, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 201, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + // antigravity 分组查找 "shared-model":应命中第一个匹配(按 matchingPlatforms 顺序 antigravity→anthropic→gemini) + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.NotNil(t, result, "antigravity group should find pricing for shared-model") + // 第一个匹配应该是 anthropic(matchingPlatforms 返回 [antigravity, anthropic, gemini]) + require.Equal(t, int64(200), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_AntigravityOnlyGeminiPricing(t *testing.T) { + // 只有 gemini 平台定义了模型 "gemini-model"。 + // antigravity 分组应能查到 gemini 的定价。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 300, Platform: PlatformGemini, Models: []string{"gemini-model"}, InputPrice: testPtrFloat64(2e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "gemini-model") + require.NotNil(t, result, "antigravity group should find gemini pricing") + require.Equal(t, int64(300), result.ID) + require.InDelta(t, 2e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_AntigravityWildcardCrossPlatformNoOverwrite(t *testing.T) { + // anthropic 和 gemini 都有 "shared-*" 通配符定价,价格不同。 + // antigravity 分组查找 "shared-model" 应命中第一个匹配而非被覆盖。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 400, Platform: PlatformAnthropic, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 401, Platform: PlatformGemini, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.NotNil(t, result, "antigravity group should find wildcard pricing for shared-model") + // 两个通配符都存在,应命中 anthropic 的(matchingPlatforms 顺序) + require.Equal(t, int64(400), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) +} + +func TestResolveChannelMapping_AntigravitySameModelDifferentPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。 + // antigravity 分组应命中 anthropic 的映射(按 matchingPlatforms 顺序)。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: {"alias": "anthropic-target"}, + PlatformGemini: {"alias": "gemini-target"}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "alias") + require.True(t, result.Mapped) + require.Equal(t, "anthropic-target", result.MappedModel) +} + +func TestCheckRestricted_AntigravitySameModelDifferentPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型 "shared-model"。 + // antigravity 分组启用了 RestrictModels,"shared-model" 应不被限制。 + ch := Channel{ + ID: 1, + Status: StatusActive, + RestrictModels: true, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 500, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 501, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "shared-model") + require.False(t, restricted, "shared-model should not be restricted for antigravity") + + // 未定义的模型应被限制 + restricted = svc.IsModelRestricted(context.Background(), 10, "unknown-model") + require.True(t, restricted, "unknown-model should be restricted for antigravity") +} + +func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) { + // 确保非 antigravity 平台的行为不受影响。 + // anthropic 分组只能看到 anthropic 的定价,看不到 gemini 的。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + {ID: 600, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 601, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic, 20: PlatformGemini}) + svc := newTestChannelService(repo) + + // anthropic 分组应该只看到 anthropic 的定价 + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.NotNil(t, result) + require.Equal(t, int64(600), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) + + // gemini 分组应该只看到 gemini 的定价 + result = svc.GetChannelModelPricing(context.Background(), 20, "shared-model") + require.NotNil(t, result) + require.Equal(t, int64(601), result.ID) + require.InDelta(t, 5e-6, *result.InputPrice, 1e-12) +} diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go index d01c252b..deac64d6 100644 --- a/backend/internal/service/channel_test.go +++ b/backend/internal/service/channel_test.go @@ -307,3 +307,129 @@ func TestChannelClone_EdgeCases(t *testing.T) { require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"]) }) } + +// --- ValidateIntervals --- + +func TestValidateIntervals_Empty(t *testing.T) { + require.NoError(t, ValidateIntervals(nil)) + require.NoError(t, ValidateIntervals([]PricingInterval{})) +} + +func TestValidateIntervals_ValidIntervals(t *testing.T) { + tests := []struct { + name string + intervals []PricingInterval + }{ + { + name: "single bounded interval", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + }, + }, + { + name: "two intervals with gap", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + }, + }, + { + name: "two contiguous intervals", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + }, + }, + { + name: "unsorted input (auto-sorted by validator)", + intervals: []PricingInterval{ + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + }, + }, + { + name: "single unbounded interval", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.NoError(t, ValidateIntervals(tt.intervals)) + }) + } +} + +func TestValidateIntervals_NegativeMinTokens(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "min_tokens") + require.Contains(t, err.Error(), ">= 0") +} + +func TestValidateIntervals_MaxTokensZero(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> 0") +} + +func TestValidateIntervals_MaxLessThanMin(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> min_tokens") +} + +func TestValidateIntervals_MaxEqualsMin(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> min_tokens") +} + +func TestValidateIntervals_NegativePrice(t *testing.T) { + negPrice := -0.01 + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "input_price") + require.Contains(t, err.Error(), ">= 0") +} + +func TestValidateIntervals_OverlappingIntervals(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "overlap") +} + +func TestValidateIntervals_UnboundedNotLast(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "unbounded") + require.Contains(t, err.Error(), "last") +} diff --git a/backend/internal/service/gateway_channel_restriction_fallback_test.go b/backend/internal/service/gateway_channel_restriction_fallback_test.go new file mode 100644 index 00000000..d3196419 --- /dev/null +++ b/backend/internal/service/gateway_channel_restriction_fallback_test.go @@ -0,0 +1,130 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) { + t.Parallel() + + groupID := int64(10) + fallbackID := int64(11) + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{fallbackID}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{ + fallbackID: PlatformAnthropic, + })) + accountRepo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range accountRepo.accounts { + accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i] + } + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &fallbackID, + Hydrated: true, + }, + fallbackID: { + ID: fallbackID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + channelService: channelSvc, + cfg: testConfig(), + } + + ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID]) + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(1), account.ID) +} + +func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) { + t.Parallel() + + groupID := int64(10) + fallbackID := int64(11) + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{fallbackID}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{ + fallbackID: PlatformAnthropic, + })) + accountRepo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range accountRepo.accounts { + accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i] + } + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &fallbackID, + Hydrated: true, + }, + fallbackID: { + ID: fallbackID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + channelService: channelSvc, + cfg: testConfig(), + } + + ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID]) + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(1), result.Account.ID) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d272b8de..14937dd4 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1178,11 +1178,6 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 渠道定价限制预检查(requested / channel_mapped 基准) - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - // 优先检查 context 中的强制平台(/antigravity 路由) var platform string forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) @@ -1201,6 +1196,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context platform = PlatformAnthropic } + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { @@ -1217,11 +1218,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // metadataUserID: 用于客户端亲和调度,从中提取客户端 ID // sub2apiUserID: 系统用户 ID,用于二维亲和调度 func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { - // 渠道定价限制预检查(requested / channel_mapped 基准) - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { @@ -1242,6 +1238,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + var stickyAccountID int64 if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch diff --git a/backend/internal/service/openai_channel_restriction_test.go b/backend/internal/service/openai_channel_restriction_test.go new file mode 100644 index 00000000..c9dbceab --- /dev/null +++ b/backend/internal/service/openai_channel_restriction_test.go @@ -0,0 +1,140 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpenAISelectAccountForModelWithExclusions_ChannelMappedRestrictionRejectsEarly(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceChannelMapped, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + PlatformOpenAI: {"gpt-4.1": "o3-mini"}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true}, + }}, + channelService: channelSvc, + } + + groupID := int64(10) + _, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil) + require.ErrorIs(t, err, ErrNoAvailableAccounts) + require.Contains(t, err.Error(), "channel pricing restriction") +} + +func TestOpenAISelectAccountForModelWithExclusions_UpstreamRestrictionSkipsDisallowedAccount(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceUpstream, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"o3-mini"}}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 10, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "gpt-4o"}, + }, + }, + { + ID: 2, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 20, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "o3-mini"}, + }, + }, + }}, + channelService: channelSvc, + } + + groupID := int64(10) + account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(2), account.ID) +} + +func TestOpenAISelectAccountForModelWithExclusions_StickyRestrictedUpstreamFallsBack(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceUpstream, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"o3-mini"}}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{"openai:sticky-session": 1}, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 10, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "gpt-4o"}, + }, + }, + { + ID: 2, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 20, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "o3-mini"}, + }, + }, + }}, + channelService: channelSvc, + cache: cache, + } + + groupID := int64(10) + account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "sticky-session", "gpt-4.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(2), account.ID) + require.Equal(t, 1, cache.deletedSessions["openai:sticky-session"]) + require.Equal(t, int64(2), cache.sessionBindings["openai:sticky-session"]) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ef0aaa5b..d1c52ea5 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log/slog" "math/rand" "net/http" "sort" @@ -423,6 +424,44 @@ func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Cont return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) } +func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool { + if groupID == nil || s.channelService == nil || requestedModel == "" { + return false + } + mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel) + billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel) + if billingModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) +} + +func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { + if s.channelService == nil { + return false + } + upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "") + if upstreamModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel) +} + +func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool { + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil { + slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err) + return false + } + if ch == nil || !ch.RestrictModels { + return false + } + return ch.BillingModelSource == BillingModelSourceUpstream +} + // ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。 func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { return ReplaceModelInBody(body, newModel) @@ -1162,6 +1201,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C } func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 1. 尝试粘性会话命中 // Try sticky session hit if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { @@ -1177,7 +1220,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) + selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { @@ -1243,6 +1286,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } + if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) && + s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account @@ -1255,8 +1303,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // // selectBestAccount selects the best account from candidates (priority + LRU). // Returns nil if no available account. -func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) for i := range accounts { acc := &accounts[i] @@ -1275,6 +1324,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used @@ -1326,7 +1378,12 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + cfg := s.schedulingConfig() + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var stickyAccountID int64 if sessionHash != "" && s.cache != nil { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { @@ -1402,6 +1459,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { @@ -1447,6 +1506,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } candidates = append(candidates, acc) } @@ -1471,6 +1533,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { @@ -1525,6 +1590,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { @@ -1547,6 +1615,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } return &AccountSelectionResult{ Account: fresh, WaitPlan: &AccountWaitPlan{ diff --git a/frontend/src/components/admin/channel/types.ts b/frontend/src/components/admin/channel/types.ts index 431e3eeb..8d998911 100644 --- a/frontend/src/components/admin/channel/types.ts +++ b/frontend/src/components/admin/channel/types.ts @@ -113,6 +113,70 @@ export function findModelConflict(models: string[]): [string, string] | null { return null } +// ── 区间校验 ────────────────────────────────────────────── + +/** 校验区间列表的合法性,返回错误消息;通过则返回 null */ +export function validateIntervals(intervals: IntervalFormEntry[]): string | null { + if (!intervals || intervals.length === 0) return null + + // 按 min_tokens 排序(不修改原数组) + const sorted = [...intervals].sort((a, b) => a.min_tokens - b.min_tokens) + + for (let i = 0; i < sorted.length; i++) { + const err = validateSingleInterval(sorted[i], i) + if (err) return err + } + return checkIntervalOverlap(sorted) +} + +function validateSingleInterval(iv: IntervalFormEntry, idx: number): string | null { + if (iv.min_tokens < 0) { + return `区间 #${idx + 1}: 最小 token 数 (${iv.min_tokens}) 不能为负数` + } + if (iv.max_tokens != null) { + if (iv.max_tokens <= 0) { + return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于 0` + } + if (iv.max_tokens <= iv.min_tokens) { + return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于最小 token 数 (${iv.min_tokens})` + } + } + return validateIntervalPrices(iv, idx) +} + +function validateIntervalPrices(iv: IntervalFormEntry, idx: number): string | null { + const prices: [string, number | string | null][] = [ + ['输入价格', iv.input_price], + ['输出价格', iv.output_price], + ['缓存写入价格', iv.cache_write_price], + ['缓存读取价格', iv.cache_read_price], + ['单次价格', iv.per_request_price], + ] + for (const [name, val] of prices) { + if (val != null && val !== '' && Number(val) < 0) { + return `区间 #${idx + 1}: ${name}不能为负数` + } + } + return null +} + +function checkIntervalOverlap(sorted: IntervalFormEntry[]): string | null { + for (let i = 0; i < sorted.length; i++) { + // 无上限区间必须是最后一个 + if (sorted[i].max_tokens == null && i < sorted.length - 1) { + return `区间 #${i + 1}: 无上限区间(最大 token 数为空)只能是最后一个` + } + if (i === 0) continue + const prev = sorted[i - 1] + // (min, max] 语义:前一个区间上界 > 当前区间下界则重叠 + if (prev.max_tokens == null || prev.max_tokens > sorted[i].min_tokens) { + const prevMax = prev.max_tokens == null ? '∞' : String(prev.max_tokens) + return `区间 #${i} 和 #${i + 1} 重叠:前一个区间上界 (${prevMax}) 大于当前区间下界 (${sorted[i].min_tokens})` + } + } + return null +} + /** 平台对应的模型 tag 样式(背景+文字) */ export function getPlatformTagClass(platform: string): string { switch (platform) { diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 4d4150fb..b651be7d 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -418,7 +418,7 @@ import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels' import type { PricingFormEntry } from '@/components/admin/channel/types' -import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict } from '@/components/admin/channel/types' +import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types' import type { AdminGroup, GroupPlatform } from '@/types' import type { Column } from '@/components/common/types' import AppLayout from '@/components/layout/AppLayout.vue' @@ -922,6 +922,21 @@ async function handleSubmit() { } } + // 校验区间合法性(范围、重叠等) + for (const section of form.platforms.filter(s => s.enabled)) { + for (const entry of section.model_pricing) { + if (!entry.intervals || entry.intervals.length === 0) continue + const intervalErr = validateIntervals(entry.intervals) + if (intervalErr) { + const platformLabel = t('admin.groups.platforms.' + section.platform, section.platform) + const modelLabel = entry.models.join(', ') || '未命名' + appStore.showError(`${platformLabel} - ${modelLabel}: ${intervalErr}`) + activeTab.value = section.platform + return + } + } + } + const { group_ids, model_pricing, model_mapping } = formToAPI() submitting.value = true