Merge branch 'main' of https://github.com/mt21625457/aicodex2api
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -245,9 +246,6 @@ var (
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||
|
||||
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
|
||||
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
"accept": true,
|
||||
@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{
|
||||
// GatewayCache 定义网关服务的缓存操作接口。
|
||||
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
||||
//
|
||||
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
|
||||
// Model load info for Antigravity scheduling
|
||||
type ModelLoadInfo struct {
|
||||
CallCount int64 // 当前分钟调用次数 / Call count in current minute
|
||||
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
|
||||
}
|
||||
|
||||
// GatewayCache defines cache operations for gateway service.
|
||||
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
||||
type GatewayCache interface {
|
||||
@@ -295,24 +286,6 @@ type GatewayCache interface {
|
||||
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
||||
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
||||
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||
|
||||
// IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用)
|
||||
// Increment model call count and update last scheduling time (Antigravity only)
|
||||
// 返回更新后的调用次数
|
||||
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
|
||||
|
||||
// GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用)
|
||||
// Batch get model load info for accounts (Antigravity only)
|
||||
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配)
|
||||
// Find Gemini session using MGET reverse order matching
|
||||
// 返回最长匹配的会话信息(uuid, accountID)
|
||||
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话
|
||||
// Save Gemini session binding
|
||||
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
|
||||
}
|
||||
|
||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||
@@ -323,21 +296,15 @@ func derefGroupID(groupID *int64) int64 {
|
||||
return *groupID
|
||||
}
|
||||
|
||||
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
|
||||
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
|
||||
// 低于此阈值时保持粘性会话,等待短暂限流结束。
|
||||
const stickySessionRateLimitThreshold = 10 * time.Second
|
||||
|
||||
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
|
||||
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
|
||||
// 或请求的模型处于限流状态时,返回 true。
|
||||
// 这确保后续请求不会继续使用不可用的账号。
|
||||
//
|
||||
// shouldClearStickySession checks if an account is in an unschedulable state
|
||||
// and the sticky session binding should be cleared.
|
||||
// Returns true when account status is error/disabled, schedulable is false,
|
||||
// within temporary unschedulable period, or model rate limit remaining time
|
||||
// exceeds stickySessionRateLimitThreshold.
|
||||
// within temporary unschedulable period, or the requested model is rate-limited.
|
||||
// This ensures subsequent requests won't continue using unavailable accounts.
|
||||
func shouldClearStickySession(account *Account, requestedModel string) bool {
|
||||
if account == nil {
|
||||
@@ -349,8 +316,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
return true
|
||||
}
|
||||
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
|
||||
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
|
||||
// 检查模型限流和 scope 限流,有限流即清除粘性会话
|
||||
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -413,6 +380,7 @@ type GatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
digestStore *DigestSessionStore
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
@@ -446,6 +414,7 @@ func NewGatewayService(
|
||||
deferredService *DeferredService,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
sessionLimitCache SessionLimitCache,
|
||||
digestStore *DigestSessionStore,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -455,6 +424,7 @@ func NewGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
digestStore: digestStore,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
@@ -488,23 +458,45 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
return s.hashContent(cacheableContent)
|
||||
}
|
||||
|
||||
// 3. Fallback: 使用 system 内容
|
||||
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
|
||||
var combined strings.Builder
|
||||
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
|
||||
if parsed.SessionContext != nil {
|
||||
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
|
||||
_, _ = combined.WriteString(":")
|
||||
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
|
||||
_, _ = combined.WriteString(":")
|
||||
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
|
||||
_, _ = combined.WriteString("|")
|
||||
}
|
||||
if parsed.System != nil {
|
||||
systemText := s.extractTextFromSystem(parsed.System)
|
||||
if systemText != "" {
|
||||
return s.hashContent(systemText)
|
||||
_, _ = combined.WriteString(systemText)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 最后 fallback: 使用第一条消息
|
||||
if len(parsed.Messages) > 0 {
|
||||
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
|
||||
msgText := s.extractTextFromContent(firstMsg["content"])
|
||||
if msgText != "" {
|
||||
return s.hashContent(msgText)
|
||||
for _, msg := range parsed.Messages {
|
||||
if m, ok := msg.(map[string]any); ok {
|
||||
if content, exists := m["content"]; exists {
|
||||
// Anthropic: messages[].content
|
||||
if msgText := s.extractTextFromContent(content); msgText != "" {
|
||||
_, _ = combined.WriteString(msgText)
|
||||
}
|
||||
} else if parts, ok := m["parts"].([]any); ok {
|
||||
// Gemini: contents[].parts[].text
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
_, _ = combined.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if combined.Len() > 0 {
|
||||
return s.hashContent(combined.String())
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -532,19 +524,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
|
||||
|
||||
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
|
||||
// 返回最长匹配的会话信息(uuid, accountID)
|
||||
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
||||
if digestChain == "" || s.cache == nil {
|
||||
return "", 0, false
|
||||
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return "", 0, "", false
|
||||
}
|
||||
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
|
||||
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||
}
|
||||
|
||||
// SaveGeminiSession 保存 Gemini 会话
|
||||
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
||||
if digestChain == "" || s.cache == nil {
|
||||
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。
|
||||
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
|
||||
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
|
||||
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return "", 0, "", false
|
||||
}
|
||||
return s.digestStore.Find(groupID, prefixHash, digestChain)
|
||||
}
|
||||
|
||||
// SaveAnthropicSession 保存 Anthropic 会话
|
||||
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
|
||||
if digestChain == "" || s.digestStore == nil {
|
||||
return nil
|
||||
}
|
||||
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
@@ -629,8 +639,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
|
||||
}
|
||||
|
||||
func (s *GatewayService) hashContent(content string) string {
|
||||
hash := sha256.Sum256([]byte(content))
|
||||
return hex.EncodeToString(hash[:16]) // 32字符
|
||||
h := xxhash.Sum64String(content)
|
||||
return strconv.FormatUint(h, 36)
|
||||
}
|
||||
|
||||
// replaceModelInBody 替换请求体中的model字段
|
||||
@@ -989,13 +999,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||
}
|
||||
|
||||
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
|
||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1110,7 +1113,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
@@ -1190,6 +1192,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinSortGroups(routingAvailable)
|
||||
|
||||
// 4. 尝试获取槽位
|
||||
for _, item := range routingAvailable {
|
||||
@@ -1264,7 +1267,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
@@ -1344,10 +1346,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
// Antigravity 平台:获取模型负载信息
|
||||
var modelLoadMap map[int64]*ModelLoadInfo
|
||||
isAntigravity := platform == PlatformAntigravity
|
||||
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
@@ -1362,109 +1360,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
|
||||
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
|
||||
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
|
||||
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
|
||||
modelToAccountIDs := make(map[string][]int64)
|
||||
for _, item := range available {
|
||||
mappedModel := mapAntigravityModel(item.account, requestedModel)
|
||||
if mappedModel == "" {
|
||||
continue
|
||||
}
|
||||
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
|
||||
// 分层过滤选择:优先级 → 负载率 → LRU
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 取负载率最低的集合
|
||||
candidates = filterByMinLoadRate(candidates)
|
||||
// 3. LRU 选择最久未用的账号
|
||||
selected := selectByLRU(candidates, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
for model, ids := range modelToAccountIDs {
|
||||
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for id, info := range batch {
|
||||
modelLoadMap[id] = info
|
||||
}
|
||||
}
|
||||
if len(modelLoadMap) == 0 {
|
||||
modelLoadMap = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
|
||||
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
|
||||
if isAntigravity {
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合(硬过滤)
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
|
||||
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 移除已尝试的账号,重新选择
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
} else {
|
||||
for len(available) > 0 {
|
||||
// 1. 取优先级最小的集合
|
||||
candidates := filterByMinPriority(available)
|
||||
// 2. 取负载率最低的集合
|
||||
candidates = filterByMinLoadRate(candidates)
|
||||
// 3. LRU 选择最久未用的账号
|
||||
selected := selectByLRU(candidates, preferOAuth)
|
||||
if selected == nil {
|
||||
break
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
} else {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
// 移除已尝试的账号,重新进行分层过滤
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
|
||||
// 移除已尝试的账号,重新进行分层过滤
|
||||
selectedID := selected.account.ID
|
||||
newAvailable := make([]accountWithLoad, 0, len(available)-1)
|
||||
for _, acc := range available {
|
||||
if acc.account.ID != selectedID {
|
||||
newAvailable = append(newAvailable, acc)
|
||||
}
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
available = newAvailable
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2000,87 +1933,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
return a.LastUsedAt.Before(*b.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinPriorityAndLastUsed(accounts)
|
||||
}
|
||||
|
||||
// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用)
|
||||
// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调
|
||||
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
|
||||
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
|
||||
if len(accounts) == 0 {
|
||||
return nil
|
||||
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
|
||||
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
|
||||
func shuffleWithinSortGroups(accounts []accountWithLoad) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
if len(accounts) == 1 {
|
||||
return &accounts[0]
|
||||
}
|
||||
|
||||
// 如果没有负载信息,回退到 LRU
|
||||
if modelLoadMap == nil {
|
||||
return selectByLRU(accounts, preferOAuth)
|
||||
}
|
||||
|
||||
// 1. 计算平均调用次数(用于新账号冷启动)
|
||||
var totalCallCount int64
|
||||
var countWithCalls int
|
||||
for _, acc := range accounts {
|
||||
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
|
||||
totalCallCount += info.CallCount
|
||||
countWithCalls++
|
||||
i := 0
|
||||
for i < len(accounts) {
|
||||
j := i + 1
|
||||
for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
var avgCallCount int64
|
||||
if countWithCalls > 0 {
|
||||
avgCallCount = totalCallCount / int64(countWithCalls)
|
||||
}
|
||||
|
||||
// 2. 获取每个账号的有效调用次数
|
||||
getEffectiveCallCount := func(acc accountWithLoad) int64 {
|
||||
if acc.account == nil {
|
||||
return 0
|
||||
if j-i > 1 {
|
||||
mathrand.Shuffle(j-i, func(a, b int) {
|
||||
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
|
||||
})
|
||||
}
|
||||
info := modelLoadMap[acc.account.ID]
|
||||
if info == nil || info.CallCount == 0 {
|
||||
return avgCallCount // 新账号使用平均值
|
||||
}
|
||||
return info.CallCount
|
||||
i = j
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 找到最小调用次数
|
||||
minCount := getEffectiveCallCount(accounts[0])
|
||||
for _, acc := range accounts[1:] {
|
||||
if c := getEffectiveCallCount(acc); c < minCount {
|
||||
minCount = c
|
||||
}
|
||||
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
|
||||
func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 4. 收集所有具有最小调用次数的账号
|
||||
var candidateIdxs []int
|
||||
for i, acc := range accounts {
|
||||
if getEffectiveCallCount(acc) == minCount {
|
||||
candidateIdxs = append(candidateIdxs, i)
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return false
|
||||
}
|
||||
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
|
||||
}
|
||||
|
||||
// 5. 如果只有一个候选,直接返回
|
||||
if len(candidateIdxs) == 1 {
|
||||
return &accounts[candidateIdxs[0]]
|
||||
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
|
||||
func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// 6. preferOAuth 处理
|
||||
if preferOAuth {
|
||||
var oauthIdxs []int
|
||||
for _, idx := range candidateIdxs {
|
||||
if accounts[idx].account.Type == AccountTypeOAuth {
|
||||
oauthIdxs = append(oauthIdxs, idx)
|
||||
}
|
||||
i := 0
|
||||
for i < len(accounts) {
|
||||
j := i + 1
|
||||
for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
|
||||
j++
|
||||
}
|
||||
if len(oauthIdxs) > 0 {
|
||||
candidateIdxs = oauthIdxs
|
||||
if j-i > 1 {
|
||||
mathrand.Shuffle(j-i, func(a, b int) {
|
||||
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
|
||||
})
|
||||
}
|
||||
i = j
|
||||
}
|
||||
}
|
||||
|
||||
// 7. 随机选择
|
||||
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
|
||||
// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt)
|
||||
func sameAccountGroup(a, b *Account) bool {
|
||||
if a.Priority != b.Priority {
|
||||
return false
|
||||
}
|
||||
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
|
||||
}
|
||||
|
||||
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
|
||||
func sameLastUsedAt(a, b *time.Time) bool {
|
||||
switch {
|
||||
case a == nil && b == nil:
|
||||
return true
|
||||
case a == nil || b == nil:
|
||||
return false
|
||||
default:
|
||||
return a.Unix() == b.Unix()
|
||||
}
|
||||
}
|
||||
|
||||
// sortCandidatesForFallback 根据配置选择排序策略
|
||||
@@ -2135,13 +2060,6 @@ func shuffleWithinPriority(accounts []*Account) {
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
|
||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
@@ -2169,9 +2087,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2272,9 +2187,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -2383,9 +2295,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2488,9 +2397,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -5159,27 +5065,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
|
||||
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
|
||||
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return nil // 无法解析 scope,跳过检查
|
||||
}
|
||||
|
||||
group, err := s.resolveGroupByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil // 查询失败时放行
|
||||
}
|
||||
if group == nil {
|
||||
return nil // 分组不存在时放行
|
||||
}
|
||||
|
||||
if !IsScopeSupported(group.SupportedModelScopes, scope) {
|
||||
return ErrModelScopeNotSupported
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableModels returns the list of models available for a group
|
||||
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||
|
||||
Reference in New Issue
Block a user