feat: 添加5h窗口费用控制和会话数量限制
- 支持Anthropic OAuth/SetupToken账号的5h窗口费用阈值控制 - 支持账号级别的并发会话数量限制 - 使用Redis缓存窗口费用(30秒TTL)减少数据库压力 - 费用计算基于标准费用(不含账号倍率)
This commit is contained in:
@@ -557,3 +557,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WindowCostSchedulability classifies how an account may be scheduled based
// on the cost it has accumulated in its current usage window.
type WindowCostSchedulability int

const (
	// WindowCostSchedulable means the account can be scheduled normally.
	WindowCostSchedulable WindowCostSchedulability = iota
	// WindowCostStickyOnly means only sticky sessions may keep using the account.
	WindowCostStickyOnly
	// WindowCostNotSchedulable means the account must not be scheduled at all.
	WindowCostNotSchedulable
)
|
||||
|
||||
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
|
||||
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
|
||||
func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["window_cost_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
|
||||
// 默认值为 10
|
||||
func (a *Account) GetWindowCostStickyReserve() float64 {
|
||||
if a.Extra == nil {
|
||||
return 10.0
|
||||
}
|
||||
if v, ok := a.Extra["window_cost_sticky_reserve"]; ok {
|
||||
val := parseExtraFloat64(v)
|
||||
if val > 0 {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return 10.0
|
||||
}
|
||||
|
||||
// GetMaxSessions 获取最大并发会话数
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetMaxSessions() int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["max_sessions"]; ok {
|
||||
return parseExtraInt(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
|
||||
// 默认值为 5 分钟
|
||||
func (a *Account) GetSessionIdleTimeoutMinutes() int {
|
||||
if a.Extra == nil {
|
||||
return 5
|
||||
}
|
||||
if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
|
||||
val := parseExtraInt(v)
|
||||
if val > 0 {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return 5
|
||||
}
|
||||
|
||||
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
|
||||
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
|
||||
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
|
||||
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
|
||||
func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
|
||||
limit := a.GetWindowCostLimit()
|
||||
if limit <= 0 {
|
||||
return WindowCostSchedulable
|
||||
}
|
||||
|
||||
if currentWindowCost < limit {
|
||||
return WindowCostSchedulable
|
||||
}
|
||||
|
||||
stickyReserve := a.GetWindowCostStickyReserve()
|
||||
if currentWindowCost < limit+stickyReserve {
|
||||
return WindowCostStickyOnly
|
||||
}
|
||||
|
||||
return WindowCostNotSchedulable
|
||||
}
|
||||
|
||||
// parseExtraFloat64 converts a loosely typed value taken from an account's
// extra data into a float64.
//
// Supported inputs are float64, float32, int, int64, json.Number and numeric
// strings (surrounding whitespace is ignored). Unsupported types and values
// that fail to parse yield 0.
//
// NaN is also mapped to 0. strconv.ParseFloat accepts the literal "NaN", and a
// NaN threshold would make every comparison against it false, which would
// leave the account permanently unschedulable instead of "not configured".
func parseExtraFloat64(value any) float64 {
	var parsed float64

	switch v := value.(type) {
	case float64:
		parsed = v
	case float32:
		parsed = float64(v)
	case int:
		parsed = float64(v)
	case int64:
		parsed = float64(v)
	case json.Number:
		f, err := v.Float64()
		if err != nil {
			return 0
		}
		parsed = f
	case string:
		f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
		if err != nil {
			return 0
		}
		parsed = f
	default:
		return 0
	}

	// NaN is the only float value that is not equal to itself.
	if parsed != parsed {
		return 0
	}
	return parsed
}
|
||||
|
||||
// parseExtraInt converts a loosely typed value taken from an account's extra
// data into an int.
//
// Supported inputs are int, int64, float64, float32, json.Number and numeric
// strings (surrounding whitespace is ignored). Fractional values are
// truncated toward zero. json.Number and string inputs written in decimal or
// exponent notation ("5.0", "1e2") are accepted too, so the result does not
// depend on whether the JSON decoder produced a float64 or a json.Number.
// Unsupported types and values that fail to parse yield 0.
func parseExtraInt(value any) int {
	switch v := value.(type) {
	case int:
		return v
	case int64:
		return int(v)
	case float64:
		return extraFloatToInt(v)
	case float32:
		return extraFloatToInt(float64(v))
	case json.Number:
		if i, err := v.Int64(); err == nil {
			return int(i)
		}
		// Not a plain integer literal; fall back to float parsing.
		if f, err := v.Float64(); err == nil {
			return extraFloatToInt(f)
		}
	case string:
		trimmed := strings.TrimSpace(v)
		if i, err := strconv.Atoi(trimmed); err == nil {
			return i
		}
		// Not a plain integer literal; fall back to float parsing.
		if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
			return extraFloatToInt(f)
		}
	}
	return 0
}

// extraFloatToInt truncates f toward zero. NaN becomes 0 and values outside
// the int range saturate at the nearest bound, because a direct float-to-int
// conversion of such values is implementation-dependent in Go.
func extraFloatToInt(f float64) int {
	const (
		maxInt = int(^uint(0) >> 1)
		minInt = -maxInt - 1
	)

	switch {
	case f != f: // NaN is the only value not equal to itself
		return 0
	case f >= float64(maxInt):
		return maxInt
	case f <= float64(minInt):
		return minInt
	}
	return int(f)
}
|
||||
|
||||
@@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
|
||||
// 用于账号列表页面显示当前窗口费用
|
||||
func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
|
||||
}
|
||||
|
||||
@@ -1052,7 +1052,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil, // No concurrency service
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1105,7 +1105,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil, // legacy path
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1137,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1169,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1203,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1239,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1266,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
@@ -1298,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1331,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
|
||||
@@ -176,6 +176,7 @@ type GatewayService struct {
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -196,6 +197,7 @@ func NewGatewayService(
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
sessionLimitCache SessionLimitCache,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -214,6 +216,7 @@ func NewGatewayService(
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -407,8 +410,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
// 提取会话 UUID(用于会话数量限制)
|
||||
sessionUUID := extractSessionUUID(metadataUserID)
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
@@ -527,7 +534,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
|
||||
// 1. 过滤出路由列表中可调度的账号
|
||||
var routingCandidates []*Account
|
||||
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping int
|
||||
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
|
||||
for _, routingAccountID := range routingAccountIDs {
|
||||
if isExcluded(routingAccountID) {
|
||||
filteredExcluded++
|
||||
@@ -554,13 +561,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
filteredModelMapping++
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
|
||||
filteredWindowCost++
|
||||
continue
|
||||
}
|
||||
routingCandidates = append(routingCandidates, account)
|
||||
}
|
||||
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d)",
|
||||
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
|
||||
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
|
||||
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping)
|
||||
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
|
||||
}
|
||||
|
||||
if len(routingCandidates) > 0 {
|
||||
@@ -573,18 +585,25 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if stickyAccount.IsSchedulable() &&
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
stickyAccount.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) {
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
@@ -657,6 +676,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, item := range routingAvailable {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
@@ -699,15 +723,21 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if ok && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
@@ -748,6 +778,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, acc)
|
||||
}
|
||||
|
||||
@@ -765,7 +799,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
@@ -814,6 +848,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
@@ -843,13 +882,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
@@ -1081,6 +1125,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// 返回 true 表示可调度,false 表示不可调度
|
||||
func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool {
|
||||
// 只检查 Anthropic OAuth/SetupToken 账号
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return true
|
||||
}
|
||||
|
||||
limit := account.GetWindowCostLimit()
|
||||
if limit <= 0 {
|
||||
return true // 未启用窗口费用限制
|
||||
}
|
||||
|
||||
// 尝试从缓存获取窗口费用
|
||||
var currentCost float64
|
||||
if s.sessionLimitCache != nil {
|
||||
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
|
||||
currentCost = cost
|
||||
goto checkSchedulability
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
{
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
// 失败开放:查询失败时允许调度
|
||||
return true
|
||||
}
|
||||
|
||||
// 使用标准费用(不含账号倍率)
|
||||
currentCost = stats.StandardCost
|
||||
|
||||
// 设置缓存(忽略错误)
|
||||
if s.sessionLimitCache != nil {
|
||||
_ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost)
|
||||
}
|
||||
}
|
||||
|
||||
checkSchedulability:
|
||||
schedulability := account.CheckWindowCostSchedulability(currentCost)
|
||||
|
||||
switch schedulability {
|
||||
case WindowCostSchedulable:
|
||||
return true
|
||||
case WindowCostStickyOnly:
|
||||
return isSticky
|
||||
case WindowCostNotSchedulable:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
|
||||
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
|
||||
// 只检查 Anthropic OAuth/SetupToken 账号
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return true
|
||||
}
|
||||
|
||||
maxSessions := account.GetMaxSessions()
|
||||
if maxSessions <= 0 || sessionUUID == "" {
|
||||
return true // 未启用会话限制或无会话ID
|
||||
}
|
||||
|
||||
if s.sessionLimitCache == nil {
|
||||
return true // 缓存不可用时允许通过
|
||||
}
|
||||
|
||||
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
|
||||
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
|
||||
if err != nil {
|
||||
// 失败开放:缓存错误时允许通过
|
||||
return true
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
|
||||
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
|
||||
// 格式: user_{64位hex}_account__session_{uuid}
|
||||
func extractSessionUUID(metadataUserID string) string {
|
||||
if metadataUserID == "" {
|
||||
return ""
|
||||
}
|
||||
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
|
||||
@@ -514,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
|
||||
if s.gatewayService == nil {
|
||||
return nil, fmt.Errorf("gateway service not available")
|
||||
}
|
||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs)
|
||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
|
||||
}
|
||||
|
||||
63
backend/internal/service/session_limit_cache.go
Normal file
63
backend/internal/service/session_limit_cache.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionLimitCache tracks the active sessions of each account and caches
// per-account window cost. It backs the session count limit and the 5h window
// cost control of Anthropic OAuth/SetupToken accounts.
//
// Session tracking layout expected from implementations:
//
//	key:   session_limit:account:{accountID}
//	value: sorted set (member = sessionUUID, score = timestamp)
//
// Sessions are expected to expire on their own once they have been idle for
// the configured timeout, so no explicit cleanup is required.
type SessionLimitCache interface {
	// RegisterSession records activity for a session.
	//
	//   - If the session already exists, its timestamp is refreshed and true
	//     is returned.
	//   - If it does not exist and fewer than maxSessions sessions are active,
	//     it is added and true is returned.
	//   - If it does not exist and maxSessions sessions are already active,
	//     false is returned (rejected).
	//
	// Parameters:
	//
	//	accountID:   account ID
	//	sessionUUID: session UUID extracted from metadata.user_id
	//	maxSessions: maximum number of concurrent sessions
	//	idleTimeout: how long a session may stay idle before it expires
	//
	// Returns:
	//
	//	allowed: true if the request is within the limit or the session already
	//	         exists; false if the limit is reached and the session is new
	//	err:     operation error
	RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)

	// RefreshSession updates the timestamp of an existing session so that an
	// active session stays alive.
	RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error

	// GetActiveSessionCount returns the number of sessions of the account that
	// have not expired.
	GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)

	// GetActiveSessionCountBatch returns the active session count of several
	// accounts as map[accountID]count. Accounts whose lookup failed are left
	// out of the map.
	GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)

	// IsSessionActive reports whether the given session exists and has not
	// expired.
	IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)

	// ---------- 5h window cost cache ----------
	//
	// Key format: window_cost:account:{accountID}
	//
	// Caches the standard cost an account has accumulated in its current 5h
	// window, to reduce the load of aggregate queries on the database.

	// GetWindowCost returns the cached window cost:
	//
	//   - (cost, true, nil) on a cache hit
	//   - (0, false, nil) on a cache miss
	//   - (0, false, err) when an error occurred
	GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)

	// SetWindowCost stores the window cost of an account in the cache.
	SetWindowCost(ctx context.Context, accountID int64, cost float64) error

	// GetWindowCostBatch returns the cached window cost of several accounts as
	// map[accountID]cost. Accounts without a cached value are left out of the
	// map.
	GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
}
|
||||
Reference in New Issue
Block a user