feat(antigravity): comprehensive enhancements - model mapping, rate limiting, scheduling & ops

Key changes:
- Upgrade model mapping: Opus 4.5 → Opus 4.6-thinking with precise matching
- Unified rate limiting: scope-level → model-level with Redis snapshot sync
- Load-balanced scheduling by call count with smart retry mechanism
- Force cache billing support
- Model identity injection in prompts with leak prevention
- Thinking mode auto-handling (max_tokens/budget_tokens fix)
- Frontend: whitelist mode toggle, model mapping validation, status indicators
- Gemini session fallback with Redis Trie O(L) matching
- Ops: enhanced concurrency monitoring, account availability, retry logic
- Migration scripts: 049-051 for model mapping unification
This commit is contained in:
erio
2026-02-07 12:31:10 +08:00
parent e617b45ba3
commit 5e98445b22
73 changed files with 8553 additions and 1926 deletions

View File

@@ -22,6 +22,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
@@ -49,6 +50,29 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
// IsForceCacheBilling 检查是否启用强制缓存计费
func IsForceCacheBilling(ctx context.Context) bool {
v, _ := ctx.Value(ForceCacheBillingContextKey).(bool)
return v
}
// WithForceCacheBilling 返回带有强制缓存计费标记的上下文
func WithForceCacheBilling(ctx context.Context) context.Context {
return context.WithValue(ctx, ForceCacheBillingContextKey, true)
}
func (s *GatewayService) debugModelRoutingEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
return v == "1" || v == "true" || v == "yes" || v == "on"
@@ -250,6 +274,13 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话Sticky Session的存储、查询、刷新和删除功能。
//
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type ModelLoadInfo struct {
CallCount int64 // 当前分钟调用次数 / Call count in current minute
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
@@ -265,6 +296,24 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息uuid, accountID
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -275,16 +324,23 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
// 低于此阈值时保持粘性会话,等待短暂限流结束。
const stickySessionRateLimitThreshold = 10 * time.Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 当账号状态为错误、禁用、不可调度处于临时不可调度期间
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// within temporary unschedulable period, or model rate limit remaining time
// exceeds stickySessionRateLimitThreshold.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool {
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
return false
}
@@ -294,6 +350,10 @@ func shouldClearStickySession(account *Account) bool {
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
return true
}
return false
}
@@ -336,8 +396,9 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
}
func (e *UpstreamFailoverError) Error() string {
@@ -470,6 +531,23 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
return accountID, nil
}
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息uuid, accountID
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
}
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
}
// SaveGeminiSession 保存 Gemini 会话
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
return nil
}
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
@@ -968,6 +1046,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 1. 过滤出路由列表中可调度的账号
var routingCandidates []*Account
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
var modelScopeSkippedIDs []int64 // 记录因模型限流被跳过的账号 ID
for _, routingAccountID := range routingAccountIDs {
if isExcluded(routingAccountID) {
filteredExcluded++
@@ -986,12 +1065,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredPlatform++
continue
}
if !account.IsSchedulableForModel(requestedModel) {
filteredModelScope++
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) {
filteredModelMapping++
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
filteredModelMapping++
if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
filteredModelScope++
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
continue
}
// 窗口费用检查(非粘性会话路径)
@@ -1006,6 +1086,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
if len(modelScopeSkippedIDs) > 0 {
log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v",
derefGroupID(groupID), requestedModel, modelScopeSkippedIDs)
}
}
if len(routingCandidates) > 0 {
@@ -1017,8 +1101,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if stickyAccount.IsSchedulable() &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
stickyAccount.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
@@ -1075,10 +1159,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
// 3. 按负载感知排序
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var routingAvailable []accountWithLoad
for _, acc := range routingCandidates {
loadInfo := routingLoadMap[acc.ID]
@@ -1169,14 +1249,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
account.IsSchedulableForModelWithContext(ctx, requestedModel) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
@@ -1234,10 +1314,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
// 窗口费用检查(非粘性会话路径)
@@ -1265,10 +1345,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
// Antigravity 平台:获取模型负载信息
var modelLoadMap map[int64]*ModelLoadInfo
isAntigravity := platform == PlatformAntigravity
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
@@ -1283,47 +1363,108 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
modelToAccountIDs := make(map[string][]int64)
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
mappedModel := mapAntigravityModel(item.account, requestedModel)
if mappedModel == "" {
continue
}
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
}
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
}
for id, info := range batch {
modelLoadMap[id] = info
}
}
if len(modelLoadMap) == 0 {
modelLoadMap = nil
}
}
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 移除已尝试的账号,重新选择
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
} else {
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
}
}
@@ -1740,6 +1881,106 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID)
}
// filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
minPriority := accounts[0].account.Priority
for _, acc := range accounts[1:] {
if acc.account.Priority < minPriority {
minPriority = acc.account.Priority
}
}
result := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
if acc.account.Priority == minPriority {
result = append(result, acc)
}
}
return result
}
// filterByMinLoadRate 过滤出负载率最低的账号集合
func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
minLoadRate := accounts[0].loadInfo.LoadRate
for _, acc := range accounts[1:] {
if acc.loadInfo.LoadRate < minLoadRate {
minLoadRate = acc.loadInfo.LoadRate
}
}
result := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
if acc.loadInfo.LoadRate == minLoadRate {
result = append(result, acc)
}
}
return result
}
// selectByLRU 从集合中选择最久未用的账号
// 如果有多个账号具有相同的最小 LastUsedAt则随机选择一个
func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
}
if len(accounts) == 1 {
return &accounts[0]
}
// 1. 找到最小的 LastUsedAtnil 被视为最小)
var minTime *time.Time
hasNil := false
for _, acc := range accounts {
if acc.account.LastUsedAt == nil {
hasNil = true
break
}
if minTime == nil || acc.account.LastUsedAt.Before(*minTime) {
minTime = acc.account.LastUsedAt
}
}
// 2. 收集所有具有最小 LastUsedAt 的账号索引
var candidateIdxs []int
for i, acc := range accounts {
if hasNil {
if acc.account.LastUsedAt == nil {
candidateIdxs = append(candidateIdxs, i)
}
} else {
if acc.account.LastUsedAt != nil && acc.account.LastUsedAt.Equal(*minTime) {
candidateIdxs = append(candidateIdxs, i)
}
}
}
// 3. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
}
// 4. 如果有多个候选且 preferOAuth优先选择 OAuth 类型
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
}
}
// 5. 随机选择一个
selectedIdx := candidateIdxs[mathrand.Intn(len(candidateIdxs))]
return &accounts[selectedIdx]
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
@@ -1762,6 +2003,87 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
})
}
// selectByCallCount 从候选账号中选择调用次数最少的账号Antigravity 专用)
// 新账号CallCount=0使用平均调用次数作为虚拟值避免冷启动被猛调
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
}
if len(accounts) == 1 {
return &accounts[0]
}
// 如果没有负载信息,回退到 LRU
if modelLoadMap == nil {
return selectByLRU(accounts, preferOAuth)
}
// 1. 计算平均调用次数(用于新账号冷启动)
var totalCallCount int64
var countWithCalls int
for _, acc := range accounts {
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
totalCallCount += info.CallCount
countWithCalls++
}
}
var avgCallCount int64
if countWithCalls > 0 {
avgCallCount = totalCallCount / int64(countWithCalls)
}
// 2. 获取每个账号的有效调用次数
getEffectiveCallCount := func(acc accountWithLoad) int64 {
if acc.account == nil {
return 0
}
info := modelLoadMap[acc.account.ID]
if info == nil || info.CallCount == 0 {
return avgCallCount // 新账号使用平均值
}
return info.CallCount
}
// 3. 找到最小调用次数
minCount := getEffectiveCallCount(accounts[0])
for _, acc := range accounts[1:] {
if c := getEffectiveCallCount(acc); c < minCount {
minCount = c
}
}
// 4. 收集所有具有最小调用次数的账号
var candidateIdxs []int
for i, acc := range accounts {
if getEffectiveCallCount(acc) == minCount {
candidateIdxs = append(candidateIdxs, i)
}
}
// 5. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
}
// 6. preferOAuth 处理
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
}
}
// 7. 随机选择
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
}
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
@@ -1843,11 +2165,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
@@ -1894,10 +2216,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
@@ -1946,11 +2268,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
@@ -1986,10 +2308,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
@@ -2056,11 +2378,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
@@ -2109,10 +2431,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
@@ -2161,11 +2483,11 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil {
clearSticky := shouldClearStickySession(account)
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
@@ -2203,10 +2525,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
continue
}
if selected == nil {
@@ -2250,11 +2572,44 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context
// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
if strings.TrimSpace(requestedModel) == "" {
return true
}
if !IsAntigravityModelSupported(requestedModel) {
return false
}
// 先用默认映射获取基础模型名,再应用 thinking 后缀
defaultMapped, exists := domain.DefaultAntigravityModelMapping[requestedModel]
if !exists || defaultMapped == "" {
return false
}
finalModel := defaultMapped
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
finalModel = applyThinkingModelSuffix(finalModel, enabled)
}
// 使用最终模型名检查 model_mapping 支持
return account.IsModelSupported(finalModel)
}
return s.isModelSupportedByAccount(account, requestedModel)
}
// isModelSupportedByAccount 根据账户平台检查模型支持(无 context用于非 Antigravity 平台)
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 应使用 isModelSupportedByAccountWithContext
// 这里作为兼容保留,使用原始模型名检查
if strings.TrimSpace(requestedModel) == "" {
return true
}
if !IsAntigravityModelSupported(requestedModel) {
return false
}
return account.IsModelSupported(requestedModel)
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射短ID → 长ID
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
@@ -2269,10 +2624,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
// 只有在默认映射DefaultAntigravityModelMapping中配置的模型才被支持
func IsAntigravityModelSupported(requestedModel string) bool {
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
// 检查是否在默认映射的 key 中
_, exists := domain.DefaultAntigravityModelMapping[requestedModel]
return exists
}
// GetAccessToken 获取账号凭证
@@ -3563,34 +3919,6 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
)
}
// 非 failover 错误也支持错误透传规则匹配。
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
account.Platform,
resp.StatusCode,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
summary := upstreamMsg
if summary == "" {
summary = errMsg
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary)
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
var statusCode int
@@ -3722,33 +4050,6 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
)
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
account.Platform,
resp.StatusCode,
respBody,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed after retries",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
summary := upstreamMsg
if summary == "" {
summary = errMsg
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary)
}
// 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
@@ -4162,14 +4463,15 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota
@@ -4185,6 +4487,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account
subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
@@ -4345,6 +4656,7 @@ type RecordUsageLongContextInput struct {
IPAddress string // 请求的客户端 IP 地址
LongContextThreshold int // 长上下文阈值(如 200000
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService *APIKeyService // API Key 配额服务(可选)
}
@@ -4356,6 +4668,15 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account
subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {