diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 27404b02..31e47332 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -118,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
- accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
+ sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
+ accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -140,7 +141,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider)
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 5dc6ad19..655169cc 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -234,6 +234,10 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
+ // SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
+ // 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
+ // 空闲超过此时间的会话将被自动释放
+ SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 92fdf2eb..33c91dae 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -44,6 +44,7 @@ type AccountHandler struct {
accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
+ sessionLimitCache service.SessionLimitCache
}
// NewAccountHandler creates a new admin account handler
@@ -58,6 +59,7 @@ func NewAccountHandler(
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
+ sessionLimitCache service.SessionLimitCache,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
@@ -70,6 +72,7 @@ func NewAccountHandler(
accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
+ sessionLimitCache: sessionLimitCache,
}
}
@@ -130,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
type AccountWithConcurrency struct {
*dto.Account
CurrentConcurrency int `json:"current_concurrency"`
+ // The fields below apply only to Anthropic OAuth/SetupToken accounts and are omitted from the JSON when nil.
+ CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // cost accumulated in the current window
+ ActiveSessions *int `json:"active_sessions,omitempty"` // number of currently active sessions
}
// List handles listing all accounts with pagination
@@ -164,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
+ // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
+ windowCostAccountIDs := make([]int64, 0)
+ sessionLimitAccountIDs := make([]int64, 0)
+ for i := range accounts {
+ acc := &accounts[i]
+ if acc.IsAnthropicOAuthOrSetupToken() {
+ if acc.GetWindowCostLimit() > 0 {
+ windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
+ }
+ if acc.GetMaxSessions() > 0 {
+ sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
+ }
+ }
+ }
+
+ // 并行获取窗口费用和活跃会话数
+ var windowCosts map[int64]float64
+ var activeSessions map[int64]int
+
+ // 获取活跃会话数(批量查询)
+ if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
+ activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
+ if activeSessions == nil {
+ activeSessions = make(map[int64]int)
+ }
+ }
+
+ // 获取窗口费用(并行查询)
+ if len(windowCostAccountIDs) > 0 {
+ windowCosts = make(map[int64]float64)
+ var mu sync.Mutex
+ g, gctx := errgroup.WithContext(c.Request.Context())
+ g.SetLimit(10) // 限制并发数
+
+ for i := range accounts {
+ acc := &accounts[i]
+ if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
+ continue
+ }
+ accCopy := acc // 闭包捕获
+ g.Go(func() error {
+ var startTime time.Time
+ if accCopy.SessionWindowStart != nil {
+ startTime = *accCopy.SessionWindowStart
+ } else {
+ startTime = time.Now().Add(-5 * time.Hour)
+ }
+ stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
+ if err == nil && stats != nil {
+ mu.Lock()
+ windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
+ mu.Unlock()
+ }
+ return nil // 不返回错误,允许部分失败
+ })
+ }
+ _ = g.Wait()
+ }
+
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
- result[i] = AccountWithConcurrency{
- Account: dto.AccountFromService(&accounts[i]),
- CurrentConcurrency: concurrencyCounts[accounts[i].ID],
+ acc := &accounts[i]
+ item := AccountWithConcurrency{
+ Account: dto.AccountFromService(acc),
+ CurrentConcurrency: concurrencyCounts[acc.ID],
}
+
+ // 添加窗口费用(仅当启用时)
+ if windowCosts != nil {
+ if cost, ok := windowCosts[acc.ID]; ok {
+ item.CurrentWindowCost = &cost
+ }
+ }
+
+ // 添加活跃会话数(仅当启用时)
+ if activeSessions != nil {
+ if count, ok := activeSessions[acc.ID]; ok {
+ item.ActiveSessions = &count
+ }
+ }
+
+ result[i] = item
}
response.Paginated(c, result, total, page, pageSize)
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 4d59ddff..f5bdd008 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -116,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
- return &Account{
+ out := &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
@@ -146,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
+
+ // 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
+ if a.IsAnthropicOAuthOrSetupToken() {
+ if limit := a.GetWindowCostLimit(); limit > 0 {
+ out.WindowCostLimit = &limit
+ }
+ if reserve := a.GetWindowCostStickyReserve(); reserve > 0 {
+ out.WindowCostStickyReserve = &reserve
+ }
+ if maxSessions := a.GetMaxSessions(); maxSessions > 0 {
+ out.MaxSessions = &maxSessions
+ }
+ if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
+ out.SessionIdleTimeoutMin = &idleTimeout
+ }
+ }
+
+ return out
}
func AccountFromService(a *service.Account) *Account {
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 914f2b23..4519143c 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -102,6 +102,16 @@ type Account struct {
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
+ // 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 从 extra 字段提取,方便前端显示和编辑
+ WindowCostLimit *float64 `json:"window_cost_limit,omitempty"`
+ WindowCostStickyReserve *float64 `json:"window_cost_sticky_reserve,omitempty"`
+
+ // 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 从 extra 字段提取,方便前端显示和编辑
+ MaxSessions *int `json:"max_sessions,omitempty"`
+ SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
+
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index b60618a8..8c32be21 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index 2dddb856..ec943e61 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverStatus := 0
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
diff --git a/backend/internal/repository/session_limit_cache.go b/backend/internal/repository/session_limit_cache.go
new file mode 100644
index 00000000..16f2a69c
--- /dev/null
+++ b/backend/internal/repository/session_limit_cache.go
@@ -0,0 +1,321 @@
+package repository
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+// Session-limit cache constants.
+//
+// Design notes:
+// A Redis sorted set tracks the active sessions of each account:
+// - Key: session_limit:account:{accountID}
+// - Member: sessionUUID (extracted from metadata.user_id)
+// - Score: Unix timestamp (time the session was last active)
+//
+// Expired sessions are pruned with ZREMRANGEBYSCORE, so no per-member TTL has to be managed by hand.
+const (
+ // Key prefix for the per-account session sorted set.
+ // Format: session_limit:account:{accountID}
+ sessionLimitKeyPrefix = "session_limit:account:"
+
+ // Key prefix for the cached window cost.
+ // Format: window_cost:account:{accountID}
+ windowCostKeyPrefix = "window_cost:account:"
+
+ // TTL of a cached window cost (30 seconds).
+ windowCostCacheTTL = 30 * time.Second
+)
+
+var (
+ // registerSessionScript records session activity and enforces the session cap.
+ // It reads the clock with the Redis TIME command so that instances with skewed clocks agree.
+ // KEYS[1] = session_limit:account:{accountID}
+ // ARGV[1] = maxSessions
+ // ARGV[2] = idleTimeout (seconds)
+ // ARGV[3] = sessionUUID
+ // Returns: 1 = allowed, 0 = rejected
+ registerSessionScript = redis.NewScript(`
+ local key = KEYS[1]
+ local maxSessions = tonumber(ARGV[1])
+ local idleTimeout = tonumber(ARGV[2])
+ local sessionUUID = ARGV[3]
+
+ -- 使用 Redis 服务器时间,确保多实例时钟一致
+ local timeResult = redis.call('TIME')
+ local now = tonumber(timeResult[1])
+ local expireBefore = now - idleTimeout
+
+ -- 清理过期会话
+ redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
+
+ -- 检查会话是否已存在(支持刷新时间戳)
+ local exists = redis.call('ZSCORE', key, sessionUUID)
+ if exists ~= false then
+ -- 会话已存在,刷新时间戳
+ redis.call('ZADD', key, now, sessionUUID)
+ redis.call('EXPIRE', key, idleTimeout + 60)
+ return 1
+ end
+
+ -- 检查是否达到会话数量上限
+ local count = redis.call('ZCARD', key)
+ if count < maxSessions then
+ -- 未达上限,添加新会话
+ redis.call('ZADD', key, now, sessionUUID)
+ redis.call('EXPIRE', key, idleTimeout + 60)
+ return 1
+ end
+
+ -- 达到上限,拒绝新会话
+ return 0
+ `)
+
+ // refreshSessionScript updates the timestamp of a session that is already tracked; unknown sessions are left untouched.
+ // KEYS[1] = session_limit:account:{accountID}
+ // ARGV[1] = idleTimeout (seconds)
+ // ARGV[2] = sessionUUID
+ refreshSessionScript = redis.NewScript(`
+ local key = KEYS[1]
+ local idleTimeout = tonumber(ARGV[1])
+ local sessionUUID = ARGV[2]
+
+ local timeResult = redis.call('TIME')
+ local now = tonumber(timeResult[1])
+
+ -- 检查会话是否存在
+ local exists = redis.call('ZSCORE', key, sessionUUID)
+ if exists ~= false then
+ redis.call('ZADD', key, now, sessionUUID)
+ redis.call('EXPIRE', key, idleTimeout + 60)
+ end
+ return 1
+ `)
+
+ // getActiveSessionCountScript prunes expired sessions and returns how many remain.
+ // KEYS[1] = session_limit:account:{accountID}
+ // ARGV[1] = idleTimeout (seconds)
+ getActiveSessionCountScript = redis.NewScript(`
+ local key = KEYS[1]
+ local idleTimeout = tonumber(ARGV[1])
+
+ local timeResult = redis.call('TIME')
+ local now = tonumber(timeResult[1])
+ local expireBefore = now - idleTimeout
+
+ -- 清理过期会话
+ redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
+
+ return redis.call('ZCARD', key)
+ `)
+
+ // isSessionActiveScript reports whether a session exists and has not idled out (1 = active, 0 = not); it does not modify the set.
+ // KEYS[1] = session_limit:account:{accountID}
+ // ARGV[1] = idleTimeout (seconds)
+ // ARGV[2] = sessionUUID
+ isSessionActiveScript = redis.NewScript(`
+ local key = KEYS[1]
+ local idleTimeout = tonumber(ARGV[1])
+ local sessionUUID = ARGV[2]
+
+ local timeResult = redis.call('TIME')
+ local now = tonumber(timeResult[1])
+ local expireBefore = now - idleTimeout
+
+ -- 获取会话的时间戳
+ local score = redis.call('ZSCORE', key, sessionUUID)
+ if score == false then
+ return 0
+ end
+
+ -- 检查是否过期
+ if tonumber(score) <= expireBefore then
+ return 0
+ end
+
+ return 1
+ `)
+)
+
+type sessionLimitCache struct {
+ rdb *redis.Client
+ defaultIdleTimeout time.Duration // idle timeout used when a call supplies none (e.g. GetActiveSessionCount)
+}
+
+// NewSessionLimitCache builds a Redis-backed service.SessionLimitCache.
+// defaultIdleTimeoutMinutes is the idle timeout (minutes) used by queries that take no timeout argument; values <= 0 become 5.
+func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
+ if defaultIdleTimeoutMinutes <= 0 {
+ defaultIdleTimeoutMinutes = 5 // fall back to 5 minutes
+ }
+ return &sessionLimitCache{
+ rdb: rdb,
+ defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
+ }
+}
+
+// sessionLimitKey returns the Redis key of the session sorted set for accountID (sessionLimitKeyPrefix followed by the decimal ID).
+func sessionLimitKey(accountID int64) string {
+ return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
+}
+
+// windowCostKey returns the Redis key of the cached window cost for accountID (windowCostKeyPrefix followed by the decimal ID).
+func windowCostKey(accountID int64) string {
+ return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
+}
+
+// RegisterSession runs registerSessionScript for the account and reports whether the session is allowed; a non-positive idleTimeout (in whole seconds) falls back to the default.
+func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) {
+ if sessionUUID == "" || maxSessions <= 0 {
+ return true, nil // nothing to enforce with these arguments: allow by default
+ }
+
+ key := sessionLimitKey(accountID)
+ idleTimeoutSeconds := int(idleTimeout.Seconds())
+ if idleTimeoutSeconds <= 0 {
+ idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
+ }
+
+ result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int()
+ if err != nil {
+ return true, err // fail open: a cache error lets the request through, and the error is still returned
+ }
+ return result == 1, nil
+}
+
+// RefreshSession runs refreshSessionScript to update the session's timestamp; an empty sessionUUID is a no-op and a non-positive idleTimeout falls back to the default.
+func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error {
+ if sessionUUID == "" {
+ return nil
+ }
+
+ key := sessionLimitKey(accountID)
+ idleTimeoutSeconds := int(idleTimeout.Seconds())
+ if idleTimeoutSeconds <= 0 {
+ idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
+ }
+
+ _, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result()
+ return err
+}
+
+// GetActiveSessionCount runs getActiveSessionCountScript with the cache's default idle timeout and returns the resulting count (0 together with the error on failure).
+func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) {
+ key := sessionLimitKey(accountID)
+ idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
+
+ result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int()
+ if err != nil {
+ return 0, err
+ }
+ return result, nil
+}
+
+// GetActiveSessionCountBatch returns the active-session count of every account in accountIDs using a single pipeline round trip; accounts whose reply cannot be read are simply absent from the map.
+func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+ if len(accountIDs) == 0 {
+ return make(map[int64]int), nil
+ }
+
+ results := make(map[int64]int, len(accountIDs))
+ // NOTE(review): counting prunes with the cache-wide default idle timeout, not the per-account timeout passed to RegisterSession — confirm this is intended.
+ // Queue with Script.Eval, not Script.Run: Run's EVALSHA-to-EVAL fallback needs the reply, which a pipeline only yields at Exec, so an uncached script would fail every call with NOSCRIPT.
+ pipe := c.rdb.Pipeline()
+ idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
+
+ cmds := make(map[int64]*redis.Cmd, len(accountIDs))
+ for _, accountID := range accountIDs {
+ key := sessionLimitKey(accountID)
+ cmds[accountID] = getActiveSessionCountScript.Eval(ctx, pipe, []string{key}, idleTimeoutSeconds)
+ }
+
+ // Best effort: the Exec error is dropped on purpose because each command's own error is checked below, so partial results survive.
+ _, _ = pipe.Exec(ctx)
+
+ for accountID, cmd := range cmds {
+ if result, err := cmd.Int(); err == nil {
+ results[accountID] = result
+ }
+ }
+
+ return results, nil
+}
+
+// IsSessionActive runs isSessionActiveScript with the cache's default idle timeout; an empty sessionUUID is reported as inactive without contacting Redis.
+func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) {
+ if sessionUUID == "" {
+ return false, nil
+ }
+
+ key := sessionLimitKey(accountID)
+ idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
+
+ result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int()
+ if err != nil {
+ return false, err
+ }
+ return result == 1, nil
+}
+
+// ========== 5h window-cost cache implementation ==========
+
+// GetWindowCost reads the cached window cost for accountID; the bool result is true only on a cache hit.
+func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) {
+ key := windowCostKey(accountID)
+ val, err := c.rdb.Get(ctx, key).Float64()
+ if err == redis.Nil {
+ return 0, false, nil // cache miss
+ }
+ if err != nil {
+ return 0, false, err
+ }
+ return val, true, nil
+}
+
+// SetWindowCost stores cost as the cached window cost for accountID with a TTL of windowCostCacheTTL.
+func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
+ key := windowCostKey(accountID)
+ return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err()
+}
+
+// GetWindowCostBatch reads the cached window costs of accountIDs with one MGET; accounts with a missing or unparsable value are absent from the returned map.
+func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
+ if len(accountIDs) == 0 {
+ return make(map[int64]float64), nil
+ }
+
+ // Build the keys in the same order as accountIDs so replies can be matched by index.
+ keys := make([]string, len(accountIDs))
+ for i, accountID := range accountIDs {
+ keys[i] = windowCostKey(accountID)
+ }
+
+ // Fetch every key with a single MGET.
+ vals, err := c.rdb.MGet(ctx, keys...).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ results := make(map[int64]float64, len(accountIDs))
+ for i, val := range vals {
+ if val == nil {
+ continue // cache miss
+ }
+ // Try to interpret the reply as a float64.
+ switch v := val.(type) {
+ case string:
+ if cost, err := strconv.ParseFloat(v, 64); err == nil {
+ results[accountIDs[i]] = cost
+ }
+ case float64:
+ results[accountIDs[i]] = v
+ }
+ }
+
+ return results, nil
+}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 91ef9413..77ed37e1 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
+// ProvideSessionLimitCache is the Wire provider for the session-limit cache.
+// The cache backs concurrent-session limiting for Anthropic OAuth/SetupToken accounts.
+func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
+ defaultIdleTimeoutMinutes := 5 // default idle timeout: 5 minutes
+ if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
+ defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
+ }
+ return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
+}
+
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache,
NewTimeoutCounterCache,
ProvideConcurrencyCache,
+ ProvideSessionLimitCache,
NewDashboardCache,
NewEmailCache,
NewIdentityCache,
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 0d7a9cf9..36ba0bcc 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -557,3 +557,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
}
return false
}
+
+// WindowCostSchedulability is the scheduling state derived from an account's window cost.
+type WindowCostSchedulability int
+
+const (
+ // WindowCostSchedulable means the account can be scheduled normally.
+ WindowCostSchedulable WindowCostSchedulability = iota
+ // WindowCostStickyOnly means only sticky sessions may use the account.
+ WindowCostStickyOnly
+ // WindowCostNotSchedulable means the account must not be scheduled at all.
+ WindowCostNotSchedulable
+)
+
+// IsAnthropicOAuthOrSetupToken reports whether the account is on the Anthropic platform with type OAuth or SetupToken.
+// Only these two account types support the 5h window cost control and the session count control.
+func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
+ return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
+}
+
+// GetWindowCostLimit returns the 5h window cost threshold (USD) read from Extra["window_cost_limit"].
+// It returns 0, meaning the feature is disabled, when Extra is nil or the key is absent.
+func (a *Account) GetWindowCostLimit() float64 {
+ if a.Extra == nil {
+ return 0
+ }
+ if v, ok := a.Extra["window_cost_limit"]; ok {
+ return parseExtraFloat64(v)
+ }
+ return 0
+}
+
+// GetWindowCostStickyReserve returns the cost headroom (USD) reserved for sticky sessions, read from Extra["window_cost_sticky_reserve"].
+// It defaults to 10 when the value is missing or not positive.
+func (a *Account) GetWindowCostStickyReserve() float64 {
+ if a.Extra == nil {
+ return 10.0
+ }
+ if v, ok := a.Extra["window_cost_sticky_reserve"]; ok {
+ val := parseExtraFloat64(v)
+ if val > 0 {
+ return val
+ }
+ }
+ return 10.0
+}
+
+// GetMaxSessions returns the maximum number of concurrent sessions read from Extra["max_sessions"].
+// It returns 0, meaning the feature is disabled, when Extra is nil or the key is absent.
+func (a *Account) GetMaxSessions() int {
+ if a.Extra == nil {
+ return 0
+ }
+ if v, ok := a.Extra["max_sessions"]; ok {
+ return parseExtraInt(v)
+ }
+ return 0
+}
+
+// GetSessionIdleTimeoutMinutes returns the session idle timeout in minutes read from Extra["session_idle_timeout_minutes"].
+// It defaults to 5 minutes when the value is missing or not positive.
+func (a *Account) GetSessionIdleTimeoutMinutes() int {
+ if a.Extra == nil {
+ return 5
+ }
+ if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
+ val := parseExtraInt(v)
+ if val > 0 {
+ return val
+ }
+ }
+ return 5
+}
+
+// CheckWindowCostSchedulability maps the current window cost to a scheduling state (always WindowCostSchedulable when the limit is <= 0):
+// - cost < limit: WindowCostSchedulable (normal scheduling)
+// - limit <= cost < limit+reserve: WindowCostStickyOnly (sticky sessions only)
+// - cost >= limit+reserve: WindowCostNotSchedulable (not schedulable)
+func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
+ limit := a.GetWindowCostLimit()
+ if limit <= 0 {
+ return WindowCostSchedulable
+ }
+
+ if currentWindowCost < limit {
+ return WindowCostSchedulable
+ }
+
+ stickyReserve := a.GetWindowCostStickyReserve()
+ if currentWindowCost < limit+stickyReserve {
+ return WindowCostStickyOnly
+ }
+
+ return WindowCostNotSchedulable
+}
+
+// parseExtraFloat64 converts a value taken from the extra map to float64; unsupported types and unparsable numbers or strings yield 0.
+func parseExtraFloat64(value any) float64 {
+ switch v := value.(type) {
+ case float64:
+ return v
+ case float32:
+ return float64(v)
+ case int:
+ return float64(v)
+ case int64:
+ return float64(v)
+ case json.Number:
+ if f, err := v.Float64(); err == nil {
+ return f
+ }
+ case string:
+ if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil {
+ return f
+ }
+ }
+ return 0
+}
+
+// parseExtraInt converts a value taken from the extra map to int (a float64 loses its fractional part); unsupported types and unparsable numbers or strings yield 0.
+func parseExtraInt(value any) int {
+ switch v := value.(type) {
+ case int:
+ return v
+ case int64:
+ return int(v)
+ case float64:
+ return int(v)
+ case json.Number:
+ if i, err := v.Int64(); err == nil {
+ return int(i)
+ }
+ case string:
+ if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
+ return i
+ }
+ }
+ return 0
+}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index d9ed5609..6f012385 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}
+
+// GetAccountWindowStats returns the account's usage statistics for the window beginning at startTime by delegating to the usage log repository.
+// It is used by the account list page to show the current window cost.
+func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
+ return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
+}
diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go
index 76d73286..f543ef1a 100644
--- a/backend/internal/service/gateway_multiplatform_test.go
+++ b/backend/internal/service/gateway_multiplatform_test.go
@@ -1052,7 +1052,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // No concurrency service
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1105,7 +1105,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // legacy path
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1137,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1169,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
excludedIDs := map[int64]struct{}{1: {}}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1203,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1239,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1266,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
@@ -1298,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1331,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 1e3221d3..5068767c 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -176,6 +176,7 @@ type GatewayService struct {
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
+ sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
}
// NewGatewayService creates a new GatewayService
@@ -196,6 +197,7 @@ func NewGatewayService(
httpUpstream HTTPUpstream,
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
+ sessionLimitCache SessionLimitCache,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
@@ -214,6 +216,7 @@ func NewGatewayService(
httpUpstream: httpUpstream,
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
+ sessionLimitCache: sessionLimitCache,
}
}
@@ -407,8 +410,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
-func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
+// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
+func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
+ // 提取会话 UUID(用于会话数量限制)
+ sessionUUID := extractSessionUUID(metadataUserID)
+
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
@@ -527,7 +534,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
// 1. 过滤出路由列表中可调度的账号
var routingCandidates []*Account
- var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping int
+ var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
for _, routingAccountID := range routingAccountIDs {
if isExcluded(routingAccountID) {
filteredExcluded++
@@ -554,13 +561,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredModelMapping++
continue
}
+ // 窗口费用检查(非粘性会话路径)
+ if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
+ filteredWindowCost++
+ continue
+ }
routingCandidates = append(routingCandidates, account)
}
if s.debugModelRoutingEnabled() {
- log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d)",
+ log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
- filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping)
+ filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
}
if len(routingCandidates) > 0 {
@@ -573,18 +585,25 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccount.IsSchedulable() &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
stickyAccount.IsSchedulableForModel(requestedModel) &&
- (requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) {
+ (requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
+ s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
- if s.debugModelRoutingEnabled() {
- log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位
+ // 继续到负载感知选择
+ } else {
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
+ }
+ return &AccountSelectionResult{
+ Account: stickyAccount,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
}
- return &AccountSelectionResult{
- Account: stickyAccount,
- Acquired: true,
- ReleaseFunc: result.ReleaseFunc,
- }, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
@@ -657,6 +676,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range routingAvailable {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
+ continue
+ }
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
@@ -699,15 +723,21 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if ok && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
- (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+ (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
+ s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
- return &AccountSelectionResult{
- Account: account,
- Acquired: true,
- ReleaseFunc: result.ReleaseFunc,
- }, nil
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位,继续到 Layer 2
+ } else {
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
+ return &AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
+ }
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
@@ -748,6 +778,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
+ // 窗口费用检查(非粘性会话路径)
+ if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
+ continue
+ }
candidates = append(candidates, acc)
}
@@ -765,7 +799,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
- if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
+ if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
return result, nil
}
} else {
@@ -814,6 +848,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
+ continue
+ }
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
@@ -843,13 +882,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
-func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
+func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
+ continue
+ }
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
}
@@ -1081,6 +1125,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
+// isAccountSchedulableForWindowCost reports whether the account may be scheduled given its window cost.
+// Only Anthropic OAuth/SetupToken accounts are checked; every other account passes.
+// Returns true if schedulable, false otherwise; isSticky only decides the WindowCostStickyOnly case.
+func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool {
+ // Only Anthropic OAuth/SetupToken accounts are subject to this check.
+ if !account.IsAnthropicOAuthOrSetupToken() {
+ return true
+ }
+
+ limit := account.GetWindowCostLimit()
+ if limit <= 0 {
+ return true // window cost limit is not enabled
+ }
+
+ // Try the cached window cost first; a hit skips the repository query below.
+ var currentCost float64
+ if s.sessionLimitCache != nil {
+ if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
+ currentCost = cost
+ goto checkSchedulability
+ }
+ }
+
+ // No usable cached value (miss, error, or nil cache): query the usage-log repository.
+ {
+ var startTime time.Time
+ if account.SessionWindowStart != nil {
+ startTime = *account.SessionWindowStart
+ } else {
+ startTime = time.Now().Add(-5 * time.Hour)
+ }
+
+ stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
+ if err != nil {
+ // Fail open: allow scheduling when the query fails.
+ return true
+ }
+
+ // Use the standard cost (per the original note, without the account multiplier).
+ currentCost = stats.StandardCost
+
+ // Populate the cache; a write error is deliberately ignored.
+ if s.sessionLimitCache != nil {
+ _ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost)
+ }
+ }
+
+checkSchedulability:
+ schedulability := account.CheckWindowCostSchedulability(currentCost)
+
+ switch schedulability {
+ case WindowCostSchedulable:
+ return true
+ case WindowCostStickyOnly:
+ return isSticky
+ case WindowCostNotSchedulable:
+ return false
+ }
+ return true
+}
+
+// checkAndRegisterSession checks the session count limit and registers the session with the cache.
+// Only Anthropic OAuth/SetupToken accounts are checked; every other account passes.
+// Returns true when allowed (within the limit or the session already exists), false when rejected (over the limit and a new session).
+func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
+ // Only Anthropic OAuth/SetupToken accounts are subject to this check.
+ if !account.IsAnthropicOAuthOrSetupToken() {
+ return true
+ }
+
+ maxSessions := account.GetMaxSessions()
+ if maxSessions <= 0 || sessionUUID == "" {
+ return true // session limit not enabled, or no session ID available
+ }
+
+ if s.sessionLimitCache == nil {
+ return true // allow the request when no cache is configured
+ }
+
+ idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
+
+ allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
+ if err != nil {
+ // Fail open: allow the request when the cache returns an error.
+ return true
+ }
+ return allowed
+}
+
+// extractSessionUUID extracts the session UUID from metadata.user_id; it returns "" for empty or non-matching input.
+// Format: user_{64-char hex}_account__session_{uuid}
+func extractSessionUUID(metadataUserID string) string {
+ if metadataUserID == "" {
+ return ""
+ }
+ if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
+ return match[1]
+ }
+ return ""
+}
+
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID)
diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go
index 25c10af6..8d98e43f 100644
--- a/backend/internal/service/ops_retry.go
+++ b/backend/internal/service/ops_retry.go
@@ -514,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available")
}
- return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs)
+ return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
}
diff --git a/backend/internal/service/session_limit_cache.go b/backend/internal/service/session_limit_cache.go
new file mode 100644
index 00000000..f6f0c26a
--- /dev/null
+++ b/backend/internal/service/session_limit_cache.go
@@ -0,0 +1,63 @@
+package service
+
+import (
+ "context"
+ "time"
+)
+
+// SessionLimitCache manages account-level tracking of active sessions.
+// It backs the session count limit for Anthropic OAuth/SetupToken accounts.
+//
+// Key format: session_limit:account:{accountID}
+// Data structure: Sorted Set (member=sessionUUID, score=timestamp)
+//
+// Sessions expire automatically after the idle timeout; no manual cleanup is needed.
+type SessionLimitCache interface {
+ // RegisterSession registers session activity.
+ // - If the session already exists, refresh its timestamp and return true.
+ // - If the session does not exist and the active session count < maxSessions, add it and return true.
+ // - If the session does not exist and the active session count >= maxSessions, return false (rejected).
+ //
+ // Parameters:
+ // accountID: the account ID
+ // sessionUUID: the session UUID extracted from metadata.user_id
+ // maxSessions: the maximum number of concurrent sessions allowed
+ // idleTimeout: the session idle timeout
+ //
+ // Returns:
+ // allowed: true when allowed (within the limit or the session already exists), false when rejected (over the limit and a new session)
+ // error: the operation error, if any
+ RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)
+
+ // RefreshSession refreshes the timestamp of an existing session.
+ // It is used to keep an active session alive.
+ RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error
+
+ // GetActiveSessionCount returns the current number of active sessions.
+ // Only sessions that have not expired are counted.
+ GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
+
+ // GetActiveSessionCountBatch returns the active session counts for multiple accounts at once.
+ // It returns map[accountID]count; accounts whose lookup failed are absent from the map.
+ GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
+
+ // IsSessionActive reports whether a specific session is active (not expired).
+ IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
+
+ // ========== 5h window cost cache ==========
+ // Key format: window_cost:account:{accountID}
+ // Caches the account's standard cost within the current 5h window to reduce database aggregation load.
+
+ // GetWindowCost returns the cached window cost.
+ // Returns (cost, true, nil) on a cache hit.
+ // Returns (0, false, nil) on a cache miss.
+ // Returns (0, false, err) when an error occurs.
+ GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)
+
+ // SetWindowCost stores the window cost in the cache.
+ SetWindowCost(ctx context.Context, accountID int64, cost float64) error
+
+ // GetWindowCostBatch returns the cached window costs for multiple accounts at once.
+ // It returns map[accountID]cost; accounts with a cache miss are absent from the map.
+ GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
+}
diff --git a/frontend/src/components/account/AccountCapacityCell.vue b/frontend/src/components/account/AccountCapacityCell.vue
new file mode 100644
index 00000000..ae338aca
--- /dev/null
+++ b/frontend/src/components/account/AccountCapacityCell.vue
@@ -0,0 +1,199 @@
+
+
+ {{ t('admin.accounts.quotaControl.hint') }} +
++ {{ t('admin.accounts.quotaControl.windowCost.hint') }} +
+{{ t('admin.accounts.quotaControl.windowCost.limitHint') }}
+{{ t('admin.accounts.quotaControl.windowCost.stickyReserveHint') }}
++ {{ t('admin.accounts.quotaControl.sessionLimit.hint') }} +
+{{ t('admin.accounts.quotaControl.sessionLimit.maxSessionsHint') }}
+{{ t('admin.accounts.quotaControl.sessionLimit.idleTimeoutHint') }}
+