fix: resolve 5 audit findings in channel/credits/scheduling
P0-1: Credits degraded response retry + fail-open
- Add isAntigravityDegradedResponse() to detect transient API failures
- Retry up to 3 times with exponential backoff (500ms/1s/2s)
- Invalidate singleflight cache between retries
- Fail-open after exhausting retries instead of 5h circuit break

P1-1: Fix channel restriction pre-check timing conflict
- Run checkClaudeCodeRestriction before checkChannelPricingRestriction
- Ensures channel restriction is checked against the final fallback groupID

P1-2: Add interval pricing validation (frontend + backend)
- Backend: ValidateIntervals() with boundary, price, overlap checks
- Frontend: validateIntervals() with Chinese error messages
- Rules: MinTokens >= 0, MaxTokens > MinTokens, prices >= 0, no overlap

P2: Fix cross-platform same-model pricing/mapping override
- Store cache keys using the original platform instead of the group platform
- Look up across matching platforms (antigravity → anthropic → gemini)
- Prevents anthropic/gemini same-name models from overwriting each other
This commit is contained in:
@@ -855,6 +855,13 @@ func (s *AccountUsageService) GetAntigravityCredits(ctx context.Context, account
|
|||||||
return s.getAntigravityUsage(ctx, account)
|
return s.getAntigravityUsage(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InvalidateAntigravityCreditsCache 清除指定账号的 Antigravity 用量缓存,
|
||||||
|
// 使下次调用 GetAntigravityCredits 时强制重新拉取。
|
||||||
|
// 用于 credits 降级响应重试场景:避免重试命中同一个降级缓存。
|
||||||
|
func (s *AccountUsageService) InvalidateAntigravityCreditsCache(accountID int64) {
|
||||||
|
s.cache.antigravityCache.Delete(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
||||||
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
||||||
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -17,33 +18,116 @@ const (
|
|||||||
// 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。
|
// 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。
|
||||||
creditsExhaustedKey = "AICredits"
|
creditsExhaustedKey = "AICredits"
|
||||||
creditsExhaustedDuration = 5 * time.Hour
|
creditsExhaustedDuration = 5 * time.Hour
|
||||||
|
|
||||||
|
// credits 降级响应重试参数
|
||||||
|
creditsRetryMaxAttempts = 3
|
||||||
|
creditsRetryBaseInterval = 500 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// creditsRetryableErrorCodes 是降级响应中可重试的错误码集合。
|
||||||
|
// forbidden 是稳定的封号状态,不属于可恢复的瞬态错误,不重试。
|
||||||
|
var creditsRetryableErrorCodes = map[string]bool{
|
||||||
|
errorCodeUnauthenticated: true,
|
||||||
|
errorCodeRateLimited: true,
|
||||||
|
errorCodeNetworkError: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAntigravityDegradedResponse 检查 UsageInfo 是否为可重试的降级响应。
|
||||||
|
// 仅检测 3 个瞬态错误码(unauthenticated/rate_limited/network_error),
|
||||||
|
// forbidden 是稳定的封号状态,不属于降级。
|
||||||
|
func isAntigravityDegradedResponse(info *UsageInfo) bool {
|
||||||
|
if info == nil || info.ErrorCode == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return creditsRetryableErrorCodes[info.ErrorCode]
|
||||||
|
}
|
||||||
|
|
||||||
// checkAccountCredits 通过共享的 AccountUsageService 缓存检查账号是否有足够的 AI Credits。
|
// checkAccountCredits 通过共享的 AccountUsageService 缓存检查账号是否有足够的 AI Credits。
|
||||||
// 缓存 TTL 不足时会自动从 Google loadCodeAssist API 刷新。
|
// 缓存 TTL 不足时会自动从 Google loadCodeAssist API 刷新。
|
||||||
// 返回 true 表示积分可用。
|
// 检测到降级响应时会清除缓存并重试,最终 fail-open(返回 true)。
|
||||||
func (s *AntigravityGatewayService) checkAccountCredits(
|
func (s *AntigravityGatewayService) checkAccountCredits(
|
||||||
ctx context.Context, account *Account,
|
ctx context.Context, account *Account,
|
||||||
) bool {
|
) bool {
|
||||||
if account == nil || account.ID == 0 {
|
if account == nil || account.ID == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.accountUsageService == nil {
|
if s.accountUsageService == nil {
|
||||||
return true // 无 usage service 时不阻断
|
return true // 无 usage service 时不阻断
|
||||||
}
|
}
|
||||||
|
|
||||||
usageInfo, err := s.accountUsageService.GetAntigravityCredits(ctx, account)
|
usageInfo, err := s.accountUsageService.GetAntigravityCredits(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.antigravity_gateway",
|
slog.Error("check_credits: get_credits_failed",
|
||||||
"check_credits: get_credits_failed account=%d err=%v", account.ID, err)
|
"account_id", account.ID, "error", err)
|
||||||
return true // 出错时假设有积分,不阻断
|
return true // 出错时 fail-open
|
||||||
}
|
}
|
||||||
|
|
||||||
hasCredits := hasEnoughCredits(usageInfo)
|
// 非降级响应:直接检查积分余额
|
||||||
|
if !isAntigravityDegradedResponse(usageInfo) {
|
||||||
|
return s.logCreditsResult(account, usageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 降级响应:清除缓存后重试
|
||||||
|
return s.retryCreditsOnDegraded(ctx, account, usageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// retryCreditsOnDegraded 在检测到降级响应后,清除缓存并重试获取 credits。
|
||||||
|
// 使用指数退避(500ms → 1s → 2s),最多重试 creditsRetryMaxAttempts 次。
|
||||||
|
// 所有重试失败后 fail-open(返回 true),不做熔断。
|
||||||
|
func (s *AntigravityGatewayService) retryCreditsOnDegraded(
|
||||||
|
ctx context.Context, account *Account, lastInfo *UsageInfo,
|
||||||
|
) bool {
|
||||||
|
for attempt := 1; attempt <= creditsRetryMaxAttempts; attempt++ {
|
||||||
|
delay := creditsRetryBaseInterval << (attempt - 1) // 指数退避:500ms, 1s, 2s
|
||||||
|
slog.Warn("check_credits: degraded response, retrying",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"attempt", attempt,
|
||||||
|
"max_attempts", creditsRetryMaxAttempts,
|
||||||
|
"error_code", lastInfo.ErrorCode,
|
||||||
|
"delay", delay,
|
||||||
|
)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
slog.Warn("check_credits: context cancelled during retry, fail-open",
|
||||||
|
"account_id", account.ID, "attempt", attempt)
|
||||||
|
return true
|
||||||
|
case <-time.After(delay):
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清除缓存,强制下次 GetAntigravityCredits 重新拉取
|
||||||
|
s.accountUsageService.InvalidateAntigravityCreditsCache(account.ID)
|
||||||
|
|
||||||
|
info, err := s.accountUsageService.GetAntigravityCredits(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("check_credits: retry get_credits_failed",
|
||||||
|
"account_id", account.ID, "attempt", attempt, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重试成功(不再是降级响应):检查积分余额
|
||||||
|
if !isAntigravityDegradedResponse(info) {
|
||||||
|
slog.Info("check_credits: retry succeeded",
|
||||||
|
"account_id", account.ID, "attempt", attempt)
|
||||||
|
return s.logCreditsResult(account, info)
|
||||||
|
}
|
||||||
|
lastInfo = info
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有重试失败:fail-open,不做熔断
|
||||||
|
slog.Warn("check_credits: all retries exhausted, fail-open",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"last_error_code", lastInfo.ErrorCode,
|
||||||
|
)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// logCreditsResult 检查积分并记录不足日志,返回是否有积分。
|
||||||
|
func (s *AntigravityGatewayService) logCreditsResult(account *Account, info *UsageInfo) bool {
|
||||||
|
hasCredits := hasEnoughCredits(info)
|
||||||
if !hasCredits {
|
if !hasCredits {
|
||||||
logger.LegacyPrintf("service.antigravity_gateway",
|
slog.Warn("check_credits: insufficient credits",
|
||||||
"check_credits: account=%d has_credits=false", account.ID)
|
"account_id", account.ID)
|
||||||
}
|
}
|
||||||
return hasCredits
|
return hasCredits
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -177,6 +179,94 @@ func (c *Channel) Clone() *Channel {
|
|||||||
return &cp
|
return &cp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateIntervals 校验区间列表的合法性。
|
||||||
|
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
|
||||||
|
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
|
||||||
|
// 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。
|
||||||
|
func ValidateIntervals(intervals []PricingInterval) error {
|
||||||
|
if len(intervals) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sorted := make([]PricingInterval, len(intervals))
|
||||||
|
copy(sorted, intervals)
|
||||||
|
sort.Slice(sorted, func(i, j int) bool {
|
||||||
|
return sorted[i].MinTokens < sorted[j].MinTokens
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := range sorted {
|
||||||
|
if err := validateSingleInterval(&sorted[i], i); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return validateIntervalOverlap(sorted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateSingleInterval 校验单个区间的字段合法性
|
||||||
|
func validateSingleInterval(iv *PricingInterval, idx int) error {
|
||||||
|
if iv.MinTokens < 0 {
|
||||||
|
return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens)
|
||||||
|
}
|
||||||
|
if iv.MaxTokens != nil {
|
||||||
|
if *iv.MaxTokens <= 0 {
|
||||||
|
return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens)
|
||||||
|
}
|
||||||
|
if *iv.MaxTokens <= iv.MinTokens {
|
||||||
|
return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)",
|
||||||
|
idx+1, *iv.MaxTokens, iv.MinTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return validateIntervalPrices(iv, idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateIntervalPrices 校验区间内所有价格字段 >= 0
|
||||||
|
func validateIntervalPrices(iv *PricingInterval, idx int) error {
|
||||||
|
prices := []struct {
|
||||||
|
name string
|
||||||
|
val *float64
|
||||||
|
}{
|
||||||
|
{"input_price", iv.InputPrice},
|
||||||
|
{"output_price", iv.OutputPrice},
|
||||||
|
{"cache_write_price", iv.CacheWritePrice},
|
||||||
|
{"cache_read_price", iv.CacheReadPrice},
|
||||||
|
{"per_request_price", iv.PerRequestPrice},
|
||||||
|
}
|
||||||
|
for _, p := range prices {
|
||||||
|
if p.val != nil && *p.val < 0 {
|
||||||
|
return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后
|
||||||
|
func validateIntervalOverlap(sorted []PricingInterval) error {
|
||||||
|
for i, iv := range sorted {
|
||||||
|
// 无界区间必须是最后一个
|
||||||
|
if iv.MaxTokens == nil && i < len(sorted)-1 {
|
||||||
|
return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one",
|
||||||
|
i+1)
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prev := sorted[i-1]
|
||||||
|
// 检查重叠:前一个区间的上界 > 当前区间的下界则重叠
|
||||||
|
// (min, max] 语义:prev 覆盖 (prev.Min, prev.Max],cur 覆盖 (cur.Min, cur.Max]
|
||||||
|
if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens {
|
||||||
|
return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d",
|
||||||
|
i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatMaxTokensLabel renders an interval's upper bound for error messages:
// "∞" when the bound is nil (unbounded), otherwise its decimal value.
//
// The parameter is named maxTokens rather than max so that it does not shadow
// the predeclared max builtin.
func formatMaxTokensLabel(maxTokens *int) string {
	if maxTokens == nil {
		return "∞"
	}
	return fmt.Sprintf("%d", *maxTokens)
}
|
||||||
|
|
||||||
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
|
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
|
||||||
type ChannelUsageFields struct {
|
type ChannelUsageFields struct {
|
||||||
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
ChannelID int64 // 渠道 ID(0 = 无渠道)
|
||||||
|
|||||||
@@ -198,13 +198,18 @@ func newEmptyChannelCache() *channelCache {
|
|||||||
|
|
||||||
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
||||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
||||||
|
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
|
||||||
|
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
|
||||||
|
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
|
||||||
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||||
for j := range ch.ModelPricing {
|
for j := range ch.ModelPricing {
|
||||||
pricing := &ch.ModelPricing[j]
|
pricing := &ch.ModelPricing[j]
|
||||||
if !isPlatformPricingMatch(platform, pricing.Platform) {
|
if !isPlatformPricingMatch(platform, pricing.Platform) {
|
||||||
continue // 跳过非本平台的定价
|
continue // 跳过非本平台的定价
|
||||||
}
|
}
|
||||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
// 使用定价条目的原始平台作为缓存 key,防止跨平台同名模型冲突
|
||||||
|
pricingPlatform := pricing.Platform
|
||||||
|
gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform}
|
||||||
for _, model := range pricing.Models {
|
for _, model := range pricing.Models {
|
||||||
if strings.HasSuffix(model, "*") {
|
if strings.HasSuffix(model, "*") {
|
||||||
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
|
||||||
@@ -213,7 +218,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
|||||||
pricing: pricing,
|
pricing: pricing,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
|
key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)}
|
||||||
cache.pricingByGroupModel[key] = pricing
|
cache.pricingByGroupModel[key] = pricing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -222,13 +227,15 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
|||||||
|
|
||||||
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
||||||
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
||||||
|
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
|
||||||
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
|
// 使用映射条目的原始平台作为缓存 key,防止跨平台同名映射冲突
|
||||||
|
gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform}
|
||||||
for src, dst := range platformMapping {
|
for src, dst := range platformMapping {
|
||||||
if strings.HasSuffix(src, "*") {
|
if strings.HasSuffix(src, "*") {
|
||||||
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
|
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
|
||||||
@@ -237,7 +244,7 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
|||||||
target: dst,
|
target: dst,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
|
key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)}
|
||||||
cache.mappingByGroupModel[key] = dst
|
cache.mappingByGroupModel[key] = dst
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -349,6 +356,43 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
|
||||||
|
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
|
||||||
|
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
|
||||||
|
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
|
||||||
|
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
|
||||||
|
for _, p := range matchingPlatforms(groupPlatform) {
|
||||||
|
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
||||||
|
if pricing, ok := cache.pricingByGroupModel[key]; ok {
|
||||||
|
return pricing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 精确查找全部失败,依次尝试通配符匹配
|
||||||
|
for _, p := range matchingPlatforms(groupPlatform) {
|
||||||
|
if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil {
|
||||||
|
return pricing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
|
||||||
|
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
|
||||||
|
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
|
||||||
|
for _, p := range matchingPlatforms(groupPlatform) {
|
||||||
|
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
||||||
|
if mapped, ok := cache.mappingByGroupModel[key]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, p := range matchingPlatforms(groupPlatform) {
|
||||||
|
if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
|
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
|
||||||
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
||||||
cache, err := s.loadCache(ctx)
|
cache, err := s.loadCache(ctx)
|
||||||
@@ -389,7 +433,9 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))
|
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
|
||||||
|
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
|
||||||
|
// 确保跨平台同名模型各自独立匹配。
|
||||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -401,14 +447,9 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelLower := strings.ToLower(model)
|
modelLower := strings.ToLower(model)
|
||||||
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower)
|
||||||
pricing, ok := lk.cache.pricingByGroupModel[key]
|
if pricing == nil {
|
||||||
if !ok {
|
return nil
|
||||||
// 精确查找失败,尝试通配符匹配
|
|
||||||
pricing = lk.cache.matchWildcard(groupID, lk.platform, modelLower)
|
|
||||||
if pricing == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cp := pricing.Clone()
|
cp := pricing.Clone()
|
||||||
@@ -453,7 +494,8 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g
|
|||||||
return resolveMapping(lk, *groupID, model), false
|
return resolveMapping(lk, *groupID, model), false
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveMapping 基于已查找的渠道信息解析模型映射
|
// resolveMapping 基于已查找的渠道信息解析模型映射。
|
||||||
|
// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。
|
||||||
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
|
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
|
||||||
result := ChannelMappingResult{
|
result := ChannelMappingResult{
|
||||||
MappedModel: model,
|
MappedModel: model,
|
||||||
@@ -465,11 +507,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelLower := strings.ToLower(model)
|
modelLower := strings.ToLower(model)
|
||||||
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" {
|
||||||
if mapped, ok := lk.cache.mappingByGroupModel[key]; ok {
|
|
||||||
result.MappedModel = mapped
|
|
||||||
result.Mapped = true
|
|
||||||
} else if mapped := lk.cache.matchWildcardMapping(groupID, lk.platform, modelLower); mapped != "" {
|
|
||||||
result.MappedModel = mapped
|
result.MappedModel = mapped
|
||||||
result.Mapped = true
|
result.Mapped = true
|
||||||
}
|
}
|
||||||
@@ -477,19 +515,15 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkRestricted 基于已查找的渠道信息检查模型是否被限制
|
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
|
||||||
|
// antigravity 分组依次尝试所有匹配平台的定价列表。
|
||||||
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
||||||
if !lk.channel.RestrictModels {
|
if !lk.channel.RestrictModels {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// 检查模型是否在定价列表中
|
|
||||||
modelLower := strings.ToLower(model)
|
modelLower := strings.ToLower(model)
|
||||||
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
|
// 使用与查找定价相同的跨平台逻辑
|
||||||
if _, exists := lk.cache.pricingByGroupModel[key]; exists {
|
if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil {
|
||||||
return false
|
|
||||||
}
|
|
||||||
// 精确查找失败,尝试通配符匹配
|
|
||||||
if lk.cache.matchWildcard(groupID, lk.platform, modelLower) != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
@@ -550,6 +584,9 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
|||||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -624,6 +661,9 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
|||||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -756,6 +796,19 @@ func validateNoConflictingMappings(mapping map[string]map[string]string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validatePricingIntervals(pricingList []ChannelModelPricing) error {
|
||||||
|
for _, pricing := range pricingList {
|
||||||
|
if err := ValidateIntervals(pricing.Intervals); err != nil {
|
||||||
|
return infraerrors.BadRequest(
|
||||||
|
"INVALID_PRICING_INTERVALS",
|
||||||
|
fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v",
|
||||||
|
pricing.Platform, pricing.Models, err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
|
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
|
||||||
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
|
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
|
||||||
for i := 0; i < len(entries); i++ {
|
for i := 0; i < len(entries); i++ {
|
||||||
|
|||||||
@@ -1401,6 +1401,32 @@ func TestCreate_DuplicateModel(t *testing.T) {
|
|||||||
require.Contains(t, err.Error(), "claude-opus-4")
|
require.Contains(t, err.Error(), "claude-opus-4")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreate_InvalidPricingIntervals(t *testing.T) {
|
||||||
|
repo := &mockChannelRepository{
|
||||||
|
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||||||
|
Name: "new-channel",
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{
|
||||||
|
Platform: "anthropic",
|
||||||
|
Models: []string{"claude-opus-4"},
|
||||||
|
Intervals: []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: testPtrInt(2000), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
{MinTokens: 1000, MaxTokens: testPtrInt(3000), InputPrice: testPtrFloat64(2e-6)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS")
|
||||||
|
require.Contains(t, err.Error(), "overlap")
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreate_DefaultBillingModelSource(t *testing.T) {
|
func TestCreate_DefaultBillingModelSource(t *testing.T) {
|
||||||
var capturedChannel *Channel
|
var capturedChannel *Channel
|
||||||
repo := &mockChannelRepository{
|
repo := &mockChannelRepository{
|
||||||
@@ -1592,6 +1618,37 @@ func TestUpdate_DuplicateModel(t *testing.T) {
|
|||||||
require.Contains(t, err.Error(), "claude-opus-4")
|
require.Contains(t, err.Error(), "claude-opus-4")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdate_InvalidPricingIntervals(t *testing.T) {
|
||||||
|
existing := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Name: "original",
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &mockChannelRepository{
|
||||||
|
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||||||
|
return existing.Clone(), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
invalidPricing := []ChannelModelPricing{
|
||||||
|
{
|
||||||
|
Platform: "anthropic",
|
||||||
|
Models: []string{"claude-opus-4"},
|
||||||
|
Intervals: []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
{MinTokens: 2000, MaxTokens: testPtrInt(4000), InputPrice: testPtrFloat64(2e-6)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||||||
|
ModelPricing: &invalidPricing,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS")
|
||||||
|
require.Contains(t, err.Error(), "unbounded")
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdate_InvalidatesChannelCache(t *testing.T) {
|
func TestUpdate_InvalidatesChannelCache(t *testing.T) {
|
||||||
existing := &Channel{
|
existing := &Channel{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
@@ -1984,3 +2041,144 @@ func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) {
|
|||||||
require.Equal(t, "claude-opus-4-6", result.MappedModel)
|
require.Equal(t, "claude-opus-4-6", result.MappedModel)
|
||||||
require.Equal(t, int64(1), result.ChannelID)
|
require.Equal(t, int64(1), result.ChannelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ===========================================================================
|
||||||
|
// 11. Antigravity cross-platform same-name model — no overwrite
|
||||||
|
// ===========================================================================
|
||||||
|
|
||||||
|
func TestGetChannelModelPricing_AntigravitySameModelDifferentPlatforms(t *testing.T) {
|
||||||
|
// anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。
|
||||||
|
// antigravity 分组应能分别查到各自的定价,而不是后者覆盖前者。
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{ID: 200, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
|
||||||
|
{ID: 201, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
// antigravity 分组查找 "shared-model":应命中第一个匹配(按 matchingPlatforms 顺序 antigravity→anthropic→gemini)
|
||||||
|
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
|
||||||
|
require.NotNil(t, result, "antigravity group should find pricing for shared-model")
|
||||||
|
// 第一个匹配应该是 anthropic(matchingPlatforms 返回 [antigravity, anthropic, gemini])
|
||||||
|
require.Equal(t, int64(200), result.ID)
|
||||||
|
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetChannelModelPricing_AntigravityOnlyGeminiPricing(t *testing.T) {
|
||||||
|
// 只有 gemini 平台定义了模型 "gemini-model"。
|
||||||
|
// antigravity 分组应能查到 gemini 的定价。
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{ID: 300, Platform: PlatformGemini, Models: []string{"gemini-model"}, InputPrice: testPtrFloat64(2e-6)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
result := svc.GetChannelModelPricing(context.Background(), 10, "gemini-model")
|
||||||
|
require.NotNil(t, result, "antigravity group should find gemini pricing")
|
||||||
|
require.Equal(t, int64(300), result.ID)
|
||||||
|
require.InDelta(t, 2e-6, *result.InputPrice, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetChannelModelPricing_AntigravityWildcardCrossPlatformNoOverwrite(t *testing.T) {
|
||||||
|
// anthropic 和 gemini 都有 "shared-*" 通配符定价,价格不同。
|
||||||
|
// antigravity 分组查找 "shared-model" 应命中第一个匹配而非被覆盖。
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{ID: 400, Platform: PlatformAnthropic, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(10e-6)},
|
||||||
|
{ID: 401, Platform: PlatformGemini, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(5e-6)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
|
||||||
|
require.NotNil(t, result, "antigravity group should find wildcard pricing for shared-model")
|
||||||
|
// 两个通配符都存在,应命中 anthropic 的(matchingPlatforms 顺序)
|
||||||
|
require.Equal(t, int64(400), result.ID)
|
||||||
|
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveChannelMapping_AntigravitySameModelDifferentPlatforms(t *testing.T) {
|
||||||
|
// anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。
|
||||||
|
// antigravity 分组应命中 anthropic 的映射(按 matchingPlatforms 顺序)。
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
ModelMapping: map[string]map[string]string{
|
||||||
|
PlatformAnthropic: {"alias": "anthropic-target"},
|
||||||
|
PlatformGemini: {"alias": "gemini-target"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
result := svc.ResolveChannelMapping(context.Background(), 10, "alias")
|
||||||
|
require.True(t, result.Mapped)
|
||||||
|
require.Equal(t, "anthropic-target", result.MappedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckRestricted_AntigravitySameModelDifferentPlatforms(t *testing.T) {
|
||||||
|
// anthropic 和 gemini 都定义了同名模型 "shared-model"。
|
||||||
|
// antigravity 分组启用了 RestrictModels,"shared-model" 应不被限制。
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
RestrictModels: true,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{ID: 500, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
|
||||||
|
{ID: 501, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
restricted := svc.IsModelRestricted(context.Background(), 10, "shared-model")
|
||||||
|
require.False(t, restricted, "shared-model should not be restricted for antigravity")
|
||||||
|
|
||||||
|
// 未定义的模型应被限制
|
||||||
|
restricted = svc.IsModelRestricted(context.Background(), 10, "unknown-model")
|
||||||
|
require.True(t, restricted, "unknown-model should be restricted for antigravity")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
|
||||||
|
// 确保非 antigravity 平台的行为不受影响。
|
||||||
|
// anthropic 分组只能看到 anthropic 的定价,看不到 gemini 的。
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10, 20},
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{ID: 600, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
|
||||||
|
{ID: 601, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic, 20: PlatformGemini})
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
// anthropic 分组应该只看到 anthropic 的定价
|
||||||
|
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, int64(600), result.ID)
|
||||||
|
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
|
||||||
|
|
||||||
|
// gemini 分组应该只看到 gemini 的定价
|
||||||
|
result = svc.GetChannelModelPricing(context.Background(), 20, "shared-model")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, int64(601), result.ID)
|
||||||
|
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
|
||||||
|
}
|
||||||
|
|||||||
@@ -307,3 +307,129 @@ func TestChannelClone_EdgeCases(t *testing.T) {
|
|||||||
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
|
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- ValidateIntervals ---
|
||||||
|
|
||||||
|
func TestValidateIntervals_Empty(t *testing.T) {
|
||||||
|
require.NoError(t, ValidateIntervals(nil))
|
||||||
|
require.NoError(t, ValidateIntervals([]PricingInterval{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateIntervals_ValidIntervals(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
intervals []PricingInterval
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single bounded interval",
|
||||||
|
intervals: []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two intervals with gap",
|
||||||
|
intervals: []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two contiguous intervals",
|
||||||
|
intervals: []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsorted input (auto-sorted by validator)",
|
||||||
|
intervals: []PricingInterval{
|
||||||
|
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
|
||||||
|
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single unbounded interval",
|
||||||
|
intervals: []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.NoError(t, ValidateIntervals(tt.intervals))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateIntervals_NegativeMinTokens(t *testing.T) {
|
||||||
|
intervals := []PricingInterval{
|
||||||
|
{MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
}
|
||||||
|
err := ValidateIntervals(intervals)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "min_tokens")
|
||||||
|
require.Contains(t, err.Error(), ">= 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateIntervals_MaxTokensZero(t *testing.T) {
|
||||||
|
intervals := []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
}
|
||||||
|
err := ValidateIntervals(intervals)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "max_tokens")
|
||||||
|
require.Contains(t, err.Error(), "> 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateIntervals_MaxLessThanMin(t *testing.T) {
|
||||||
|
intervals := []PricingInterval{
|
||||||
|
{MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
}
|
||||||
|
err := ValidateIntervals(intervals)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "max_tokens")
|
||||||
|
require.Contains(t, err.Error(), "> min_tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateIntervals_MaxEqualsMin(t *testing.T) {
|
||||||
|
intervals := []PricingInterval{
|
||||||
|
{MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
}
|
||||||
|
err := ValidateIntervals(intervals)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "max_tokens")
|
||||||
|
require.Contains(t, err.Error(), "> min_tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateIntervals_NegativePrice(t *testing.T) {
|
||||||
|
negPrice := -0.01
|
||||||
|
intervals := []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice},
|
||||||
|
}
|
||||||
|
err := ValidateIntervals(intervals)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "input_price")
|
||||||
|
require.Contains(t, err.Error(), ">= 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateIntervals_OverlappingIntervals(t *testing.T) {
|
||||||
|
intervals := []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
{MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)},
|
||||||
|
}
|
||||||
|
err := ValidateIntervals(intervals)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "overlap")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
|
||||||
|
intervals := []PricingInterval{
|
||||||
|
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
|
||||||
|
{MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)},
|
||||||
|
}
|
||||||
|
err := ValidateIntervals(intervals)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "unbounded")
|
||||||
|
require.Contains(t, err.Error(), "last")
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,130 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
groupID := int64(10)
|
||||||
|
fallbackID := int64(11)
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{fallbackID},
|
||||||
|
RestrictModels: true,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
|
||||||
|
fallbackID: PlatformAnthropic,
|
||||||
|
}))
|
||||||
|
accountRepo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range accountRepo.accounts {
|
||||||
|
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
|
||||||
|
}
|
||||||
|
groupRepo := &mockGroupRepoForGateway{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
groupID: {
|
||||||
|
ID: groupID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
ClaudeCodeOnly: true,
|
||||||
|
FallbackGroupID: &fallbackID,
|
||||||
|
Hydrated: true,
|
||||||
|
},
|
||||||
|
fallbackID: {
|
||||||
|
ID: fallbackID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
Hydrated: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
channelService: channelSvc,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(1), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
groupID := int64(10)
|
||||||
|
fallbackID := int64(11)
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{fallbackID},
|
||||||
|
RestrictModels: true,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
|
||||||
|
fallbackID: PlatformAnthropic,
|
||||||
|
}))
|
||||||
|
accountRepo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range accountRepo.accounts {
|
||||||
|
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
|
||||||
|
}
|
||||||
|
groupRepo := &mockGroupRepoForGateway{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
groupID: {
|
||||||
|
ID: groupID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
ClaudeCodeOnly: true,
|
||||||
|
FallbackGroupID: &fallbackID,
|
||||||
|
Hydrated: true,
|
||||||
|
},
|
||||||
|
fallbackID: {
|
||||||
|
ID: fallbackID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
Hydrated: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
channelService: channelSvc,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, int64(1), result.Account.ID)
|
||||||
|
}
|
||||||
@@ -1178,11 +1178,6 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
|||||||
|
|
||||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
// 渠道定价限制预检查(requested / channel_mapped 基准)
|
|
||||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
|
||||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||||
var platform string
|
var platform string
|
||||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||||
@@ -1201,6 +1196,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
|||||||
platform = PlatformAnthropic
|
platform = PlatformAnthropic
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Claude Code 限制可能已将 groupID 解析为 fallback group,
|
||||||
|
// 渠道限制预检查必须使用解析后的分组。
|
||||||
|
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||||
|
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
// 注意:强制平台模式不走混合调度
|
// 注意:强制平台模式不走混合调度
|
||||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||||
@@ -1217,11 +1218,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
|||||||
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
|
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
|
||||||
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
|
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
|
||||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
|
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
|
||||||
// 渠道定价限制预检查(requested / channel_mapped 基准)
|
|
||||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
|
||||||
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 调试日志:记录调度入口参数
|
// 调试日志:记录调度入口参数
|
||||||
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
||||||
for id := range excludedIDs {
|
for id := range excludedIDs {
|
||||||
@@ -1242,6 +1238,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
ctx = s.withGroupContext(ctx, group)
|
ctx = s.withGroupContext(ctx, group)
|
||||||
|
|
||||||
|
// Claude Code 限制可能已将 groupID 解析为 fallback group,
|
||||||
|
// 渠道限制预检查必须使用解析后的分组。
|
||||||
|
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||||
|
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
var stickyAccountID int64
|
var stickyAccountID int64
|
||||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
|
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
|
||||||
stickyAccountID = prefetch
|
stickyAccountID = prefetch
|
||||||
|
|||||||
140
backend/internal/service/openai_channel_restriction_test.go
Normal file
140
backend/internal/service/openai_channel_restriction_test.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_ChannelMappedRestrictionRejectsEarly(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
BillingModelSource: BillingModelSourceChannelMapped,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: PlatformOpenAI, Models: []string{"gpt-4o"}},
|
||||||
|
},
|
||||||
|
ModelMapping: map[string]map[string]string{
|
||||||
|
PlatformOpenAI: {"gpt-4.1": "o3-mini"},
|
||||||
|
},
|
||||||
|
}, map[int64]string{10: PlatformOpenAI}))
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true},
|
||||||
|
}},
|
||||||
|
channelService: channelSvc,
|
||||||
|
}
|
||||||
|
|
||||||
|
groupID := int64(10)
|
||||||
|
_, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
|
||||||
|
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||||
|
require.Contains(t, err.Error(), "channel pricing restriction")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_UpstreamRestrictionSkipsDisallowedAccount(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
BillingModelSource: BillingModelSourceUpstream,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
|
||||||
|
},
|
||||||
|
}, map[int64]string{10: PlatformOpenAI}))
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Priority: 10,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Priority: 20,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
channelService: channelSvc,
|
||||||
|
}
|
||||||
|
|
||||||
|
groupID := int64(10)
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(2), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_StickyRestrictedUpstreamFallsBack(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
channelSvc := newTestChannelService(makeStandardRepo(Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
RestrictModels: true,
|
||||||
|
BillingModelSource: BillingModelSourceUpstream,
|
||||||
|
ModelPricing: []ChannelModelPricing{
|
||||||
|
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
|
||||||
|
},
|
||||||
|
}, map[int64]string{10: PlatformOpenAI}))
|
||||||
|
|
||||||
|
cache := &stubGatewayCache{
|
||||||
|
sessionBindings: map[string]int64{"openai:sticky-session": 1},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Priority: 10,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Priority: 20,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
channelService: channelSvc,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
groupID := int64(10)
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "sticky-session", "gpt-4.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(2), account.ID)
|
||||||
|
require.Equal(t, 1, cache.deletedSessions["openai:sticky-session"])
|
||||||
|
require.Equal(t, int64(2), cache.sessionBindings["openai:sticky-session"])
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -423,6 +424,44 @@ func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Cont
|
|||||||
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
|
||||||
|
if groupID == nil || s.channelService == nil || requestedModel == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
|
||||||
|
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
|
||||||
|
if billingModel == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
|
||||||
|
if s.channelService == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
|
||||||
|
if upstreamModel == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
|
||||||
|
if groupID == nil || s.channelService == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ch == nil || !ch.RestrictModels {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||||
|
}
|
||||||
|
|
||||||
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
|
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
|
||||||
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
|
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||||
return ReplaceModelInBody(body, newModel)
|
return ReplaceModelInBody(body, newModel)
|
||||||
@@ -1162,6 +1201,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
|
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
|
||||||
|
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||||
|
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
// 1. 尝试粘性会话命中
|
// 1. 尝试粘性会话命中
|
||||||
// Try sticky session hit
|
// Try sticky session hit
|
||||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
|
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
|
||||||
@@ -1177,7 +1220,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
|||||||
|
|
||||||
// 3. 按优先级 + LRU 选择最佳账号
|
// 3. 按优先级 + LRU 选择最佳账号
|
||||||
// Select by priority + LRU
|
// Select by priority + LRU
|
||||||
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs)
|
selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
|
||||||
|
|
||||||
if selected == nil {
|
if selected == nil {
|
||||||
if requestedModel != "" {
|
if requestedModel != "" {
|
||||||
@@ -1243,6 +1286,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
|||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
|
||||||
|
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
|
||||||
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// 刷新会话 TTL 并返回账号
|
// 刷新会话 TTL 并返回账号
|
||||||
// Refresh session TTL and return account
|
// Refresh session TTL and return account
|
||||||
@@ -1255,8 +1303,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
|||||||
//
|
//
|
||||||
// selectBestAccount selects the best account from candidates (priority + LRU).
|
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||||
// Returns nil if no available account.
|
// Returns nil if no available account.
|
||||||
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||||
var selected *Account
|
var selected *Account
|
||||||
|
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||||
|
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
acc := &accounts[i]
|
acc := &accounts[i]
|
||||||
@@ -1275,6 +1324,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
|
|||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// 选择优先级最高且最久未使用的账号
|
// 选择优先级最高且最久未使用的账号
|
||||||
// Select highest priority and least recently used
|
// Select highest priority and least recently used
|
||||||
@@ -1326,7 +1378,12 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
|
|||||||
|
|
||||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||||
|
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||||
|
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
cfg := s.schedulingConfig()
|
cfg := s.schedulingConfig()
|
||||||
|
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||||
var stickyAccountID int64
|
var stickyAccountID int64
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
|
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
|
||||||
@@ -1402,6 +1459,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
|
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
|
||||||
if account == nil {
|
if account == nil {
|
||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
|
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
|
||||||
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
} else {
|
} else {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
@@ -1447,6 +1506,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
candidates = append(candidates, acc)
|
candidates = append(candidates, acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1471,6 +1533,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
@@ -1525,6 +1590,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
@@ -1547,6 +1615,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: fresh,
|
Account: fresh,
|
||||||
WaitPlan: &AccountWaitPlan{
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
|||||||
@@ -113,6 +113,70 @@ export function findModelConflict(models: string[]): [string, string] | null {
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── 区间校验 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
/**
 * Validates a list of pricing intervals and reports the first problem found.
 *
 * The entries are checked on a copy ordered by ascending `min_tokens`, so the
 * caller's array is never mutated. Each entry is checked on its own first, then
 * the ordered copy is checked for overlaps. The index handed to the per-entry
 * check is the entry's position in the ordered copy.
 *
 * @param intervals interval form entries to validate
 * @returns an error message, or `null` when the list is missing, empty, or valid
 */
export function validateIntervals(intervals: IntervalFormEntry[]): string | null {
  if (!intervals || !intervals.length) return null

  // Sort a copy by lower bound; the original array keeps its order.
  const ordered = Array.from(intervals).sort((lhs, rhs) => lhs.min_tokens - rhs.min_tokens)

  for (const [position, entry] of ordered.entries()) {
    const problem = validateSingleInterval(entry, position)
    if (problem) return problem
  }

  return checkIntervalOverlap(ordered)
}
|
||||||
|
|
||||||
|
/**
 * Checks the token bounds of a single interval, then delegates to the price check.
 *
 * Rules enforced here: `min_tokens` must not be negative; when `max_tokens` is
 * set it must be greater than 0 and greater than `min_tokens`.
 *
 * @param iv  the interval to check
 * @param idx zero-based position of the interval; shown 1-based in messages
 * @returns an error message, or whatever `validateIntervalPrices` returns
 */
function validateSingleInterval(iv: IntervalFormEntry, idx: number): string | null {
  if (iv.min_tokens < 0) {
    return `区间 #${idx + 1}: 最小 token 数 (${iv.min_tokens}) 不能为负数`
  }

  // No upper bound given (null/undefined): only the prices remain to be checked.
  if (iv.max_tokens == null) return validateIntervalPrices(iv, idx)

  if (iv.max_tokens <= 0) {
    return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于 0`
  }

  if (iv.max_tokens <= iv.min_tokens) {
    return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于最小 token 数 (${iv.min_tokens})`
  }

  return validateIntervalPrices(iv, idx)
}
|
||||||
|
|
||||||
|
/**
 * Rejects an interval that carries a negative price.
 *
 * The five price fields are inspected in a fixed order (input, output, cache
 * write, cache read, per request). A field that is null/undefined or an empty
 * string is skipped; any other value is converted with `Number()` and compared
 * against zero.
 *
 * @param iv  the interval whose prices are checked
 * @param idx zero-based position of the interval; shown 1-based in the message
 * @returns a message naming the first negative price, or `null` if none is negative
 */
function validateIntervalPrices(iv: IntervalFormEntry, idx: number): string | null {
  const labelled: [string, number | string | null][] = [
    ['输入价格', iv.input_price],
    ['输出价格', iv.output_price],
    ['缓存写入价格', iv.cache_write_price],
    ['缓存读取价格', iv.cache_read_price],
    ['单次价格', iv.per_request_price],
  ]

  const offending = labelled.find(
    ([, amount]) => amount != null && amount !== '' && Number(amount) < 0,
  )
  if (!offending) return null

  const [name] = offending
  return `区间 #${idx + 1}: ${name}不能为负数`
}
|
||||||
|
|
||||||
|
/**
 * Scans a list of intervals for ordering and overlap problems.
 *
 * Two conditions are reported: an interval without an upper bound
 * (`max_tokens` null/undefined) that is not the final element, and a
 * predecessor whose upper bound is greater than the current lower bound.
 * An upper bound equal to the next lower bound is accepted, matching the
 * (min, max] reading of an interval.
 *
 * @param sorted intervals to scan; the parameter name and the caller in this
 *               file indicate ascending `min_tokens` order, which is not re-checked here
 * @returns an error message for the first problem found, or `null`
 */
function checkIntervalOverlap(sorted: IntervalFormEntry[]): string | null {
  const lastIndex = sorted.length - 1
  let prev: IntervalFormEntry | undefined

  for (const [i, current] of sorted.entries()) {
    // An open-ended interval is only allowed in the final position.
    if (current.max_tokens == null && i < lastIndex) {
      return `区间 #${i + 1}: 无上限区间(最大 token 数为空)只能是最后一个`
    }

    if (prev !== undefined) {
      const prevUnbounded = prev.max_tokens == null
      // (min, max] semantics: overlap when the previous upper bound exceeds this lower bound.
      if (prevUnbounded || (prev.max_tokens as number) > current.min_tokens) {
        const prevMax = prevUnbounded ? '∞' : String(prev.max_tokens)
        return `区间 #${i} 和 #${i + 1} 重叠:前一个区间上界 (${prevMax}) 大于当前区间下界 (${current.min_tokens})`
      }
    }

    prev = current
  }

  return null
}
|
||||||
|
|
||||||
/** 平台对应的模型 tag 样式(背景+文字) */
|
/** 平台对应的模型 tag 样式(背景+文字) */
|
||||||
export function getPlatformTagClass(platform: string): string {
|
export function getPlatformTagClass(platform: string): string {
|
||||||
switch (platform) {
|
switch (platform) {
|
||||||
|
|||||||
@@ -418,7 +418,7 @@ import { useAppStore } from '@/stores/app'
|
|||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
|
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
|
||||||
import type { PricingFormEntry } from '@/components/admin/channel/types'
|
import type { PricingFormEntry } from '@/components/admin/channel/types'
|
||||||
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict } from '@/components/admin/channel/types'
|
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
|
||||||
import type { AdminGroup, GroupPlatform } from '@/types'
|
import type { AdminGroup, GroupPlatform } from '@/types'
|
||||||
import type { Column } from '@/components/common/types'
|
import type { Column } from '@/components/common/types'
|
||||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||||
@@ -922,6 +922,21 @@ async function handleSubmit() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 校验区间合法性(范围、重叠等)
|
||||||
|
for (const section of form.platforms.filter(s => s.enabled)) {
|
||||||
|
for (const entry of section.model_pricing) {
|
||||||
|
if (!entry.intervals || entry.intervals.length === 0) continue
|
||||||
|
const intervalErr = validateIntervals(entry.intervals)
|
||||||
|
if (intervalErr) {
|
||||||
|
const platformLabel = t('admin.groups.platforms.' + section.platform, section.platform)
|
||||||
|
const modelLabel = entry.models.join(', ') || '未命名'
|
||||||
|
appStore.showError(`${platformLabel} - ${modelLabel}: ${intervalErr}`)
|
||||||
|
activeTab.value = section.platform
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const { group_ids, model_pricing, model_mapping } = formToAPI()
|
const { group_ids, model_pricing, model_mapping } = formToAPI()
|
||||||
|
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
|
|||||||
Reference in New Issue
Block a user