diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index e5bfa515..9240e51e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -118,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) - accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) + sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) @@ -140,7 +141,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d616e44b..42084b37 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -235,6 +235,10 @@ type GatewayConfig struct { // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + // SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟 + // 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能 + // 空闲超过此时间的会话将被自动释放 + SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"` // StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用 StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 92fdf2eb..33c91dae 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -44,6 +44,7 @@ type AccountHandler struct { accountTestService *service.AccountTestService concurrencyService *service.ConcurrencyService crsSyncService *service.CRSSyncService + sessionLimitCache service.SessionLimitCache } // NewAccountHandler creates a new admin account handler @@ -58,6 +59,7 @@ func NewAccountHandler( accountTestService *service.AccountTestService, concurrencyService *service.ConcurrencyService, crsSyncService *service.CRSSyncService, + sessionLimitCache service.SessionLimitCache, ) *AccountHandler { return &AccountHandler{ adminService: adminService, @@ -70,6 +72,7 @@ func NewAccountHandler( accountTestService: accountTestService, concurrencyService: concurrencyService, crsSyncService: crsSyncService, + sessionLimitCache: sessionLimitCache, } } @@ -130,6 +133,9 @@ type BulkUpdateAccountsRequest struct { type AccountWithConcurrency struct { *dto.Account CurrentConcurrency int `json:"current_concurrency"` + // 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回 + CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用 + ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数 } // List handles listing all accounts with pagination @@ -164,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) { concurrencyCounts = make(map[int64]int) } + // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能) + windowCostAccountIDs := make([]int64, 0) + sessionLimitAccountIDs := make([]int64, 0) + for i := range accounts { + acc := &accounts[i] + if acc.IsAnthropicOAuthOrSetupToken() { + if acc.GetWindowCostLimit() > 0 { + windowCostAccountIDs = append(windowCostAccountIDs, acc.ID) + } + if acc.GetMaxSessions() > 0 { + sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID) + } + } + } + + // 并行获取窗口费用和活跃会话数 + var windowCosts map[int64]float64 + var activeSessions map[int64]int + + // 获取活跃会话数(批量查询) + if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { + activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs) + if activeSessions == nil { + activeSessions = make(map[int64]int) + } + } + + // 获取窗口费用(并行查询) + if len(windowCostAccountIDs) > 0 { + windowCosts = make(map[int64]float64) + var mu sync.Mutex + g, gctx := errgroup.WithContext(c.Request.Context()) + g.SetLimit(10) // 限制并发数 + + for i := range accounts { + acc := &accounts[i] + if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 { + continue + } + accCopy := acc // 闭包捕获 + g.Go(func() error { + var startTime time.Time + if accCopy.SessionWindowStart != nil { + startTime = *accCopy.SessionWindowStart + } else { + startTime = time.Now().Add(-5 * time.Hour) + } + stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime) + if err == nil && stats != nil { + mu.Lock() + windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用 + mu.Unlock() + } + return nil // 不返回错误,允许部分失败 + }) + } + _ = g.Wait() + } + // Build response with concurrency info result := make([]AccountWithConcurrency, len(accounts)) for i := range accounts { - result[i] = AccountWithConcurrency{ - Account: dto.AccountFromService(&accounts[i]), - CurrentConcurrency: concurrencyCounts[accounts[i].ID], + acc := &accounts[i] + item := AccountWithConcurrency{ + Account: dto.AccountFromService(acc), + CurrentConcurrency: concurrencyCounts[acc.ID], } + + // 添加窗口费用(仅当启用时) + if windowCosts != nil { + if cost, ok := windowCosts[acc.ID]; ok { + item.CurrentWindowCost = &cost + } + } + + // 添加活跃会话数(仅当启用时) + if activeSessions != nil { + if count, ok := activeSessions[acc.ID]; ok { + item.ActiveSessions = &count + } + } + + result[i] = item } response.Paginated(c, result, total, page, pageSize) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index f43fac27..d8f10e6c 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -116,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account { if a == nil { return nil } - return &Account{ + out := &Account{ ID: a.ID, Name: a.Name, Notes: a.Notes, @@ -146,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account { SessionWindowStatus: a.SessionWindowStatus, GroupIDs: a.GroupIDs, } + + // 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效) + if a.IsAnthropicOAuthOrSetupToken() { + if limit := a.GetWindowCostLimit(); limit > 0 { + out.WindowCostLimit = &limit + } + if reserve := a.GetWindowCostStickyReserve(); reserve > 0 { + out.WindowCostStickyReserve = &reserve + } + if maxSessions := a.GetMaxSessions(); maxSessions > 0 { + out.MaxSessions = &maxSessions + } + if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 { + out.SessionIdleTimeoutMin = &idleTimeout + } + } + + return out } func AccountFromService(a *service.Account) *Account { diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 5fa5a3fd..ae9da254 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -102,6 +102,16 @@ type Account struct { SessionWindowEnd *time.Time `json:"session_window_end"` SessionWindowStatus string `json:"session_window_status"` + // 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效) + // 从 extra 字段提取,方便前端显示和编辑 + WindowCostLimit *float64 `json:"window_cost_limit,omitempty"` + WindowCostStickyReserve *float64 `json:"window_cost_sticky_reserve,omitempty"` + + // 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效) + // 从 extra 字段提取,方便前端显示和编辑 + MaxSessions *int `json:"max_sessions,omitempty"` + SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index b60618a8..8c32be21 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { lastFailoverStatus := 0 for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 2dddb856..ec943e61 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { lastFailoverStatus := 0 for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { if len(failedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index c4cfabc3..68e67656 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -186,8 +186,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } - // Generate session hash (from header for OpenAI) - sessionHash := h.gatewayService.GenerateSessionHash(c) + // Generate session hash (header first; fallback to prompt_cache_key) + sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) const maxAccountSwitches = 3 switchCount := 0 diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index e251c8d8..424e8ddb 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -16,14 +16,11 @@ type ModelsListResponse struct { func DefaultModels() []Model { methods := []string{"generateContent", "streamGenerateContent"} return []Model{ - {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, - {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, - {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, - {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods}, - {Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods}, - {Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods}, - {Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods}, + {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods}, + {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, } } diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go index 922988c7..08e69886 100644 --- a/backend/internal/pkg/geminicli/models.go +++ b/backend/internal/pkg/geminicli/models.go @@ -12,10 +12,10 @@ type Model struct { // DefaultModels is the curated Gemini model list used by the admin UI "test account" flow. var DefaultModels = []Model{ {ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""}, - {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""}, - {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, + {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, + {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, } // DefaultTestModel is the default model to preselect in test flows. diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index 8005f114..d7d574e8 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { // 创建 Ent 客户端,绑定到已配置的数据库驱动。 client := ent.NewClient(ent.Driver(drv)) + + // SIMPLE 模式:启动时补齐各平台默认分组。 + // - anthropic/openai/gemini: 确保存在 -default + // - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景) + if cfg.RunMode == config.RunModeSimple { + seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer seedCancel() + if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil { + _ = client.Close() + return nil, nil, err + } + } + return client, drv.DB(), nil } diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index 613c5bd5..b04154b7 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } // View filter: errors vs excluded vs all. - // Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors. + // Excluded = business-limited errors (quota/concurrency/billing). + // Upstream 429/529 are included in errors view to match SLA calculation. view := "" if filter != nil { view = strings.ToLower(strings.TrimSpace(filter.View)) @@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { switch view { case "", "errors": clauses = append(clauses, "COALESCE(is_business_limited,false) = false") - clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)") case "excluded": - clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))") + clauses = append(clauses, "COALESCE(is_business_limited,false) = true") case "all": // no-op default: // treat unknown as default 'errors' clauses = append(clauses, "COALESCE(is_business_limited,false) = false") - clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)") } if len(filter.StatusCodes) > 0 { args = append(args, pq.Array(filter.StatusCodes)) diff --git a/backend/internal/repository/session_limit_cache.go b/backend/internal/repository/session_limit_cache.go new file mode 100644 index 00000000..16f2a69c --- /dev/null +++ b/backend/internal/repository/session_limit_cache.go @@ -0,0 +1,321 @@ +package repository + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// 会话限制缓存常量定义 +// +// 设计说明: +// 使用 Redis 有序集合(Sorted Set)跟踪每个账号的活跃会话: +// - Key: session_limit:account:{accountID} +// - Member: sessionUUID(从 metadata.user_id 中提取) +// - Score: Unix 时间戳(会话最后活跃时间) +// +// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL +const ( + // 会话限制键前缀 + // 格式: session_limit:account:{accountID} + sessionLimitKeyPrefix = "session_limit:account:" + + // 窗口费用缓存键前缀 + // 格式: window_cost:account:{accountID} + windowCostKeyPrefix = "window_cost:account:" + + // 窗口费用缓存 TTL(30秒) + windowCostCacheTTL = 30 * time.Second +) + +var ( + // registerSessionScript 注册会话活动 + // 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步 + // KEYS[1] = session_limit:account:{accountID} + // ARGV[1] = maxSessions + // ARGV[2] = idleTimeout(秒) + // ARGV[3] = sessionUUID + // 返回: 1 = 允许, 0 = 拒绝 + registerSessionScript = redis.NewScript(` + local key = KEYS[1] + local maxSessions = tonumber(ARGV[1]) + local idleTimeout = tonumber(ARGV[2]) + local sessionUUID = ARGV[3] + + -- 使用 Redis 服务器时间,确保多实例时钟一致 + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - idleTimeout + + -- 清理过期会话 + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + + -- 检查会话是否已存在(支持刷新时间戳) + local exists = redis.call('ZSCORE', key, sessionUUID) + if exists ~= false then + -- 会话已存在,刷新时间戳 + redis.call('ZADD', key, now, sessionUUID) + redis.call('EXPIRE', key, idleTimeout + 60) + return 1 + end + + -- 检查是否达到会话数量上限 + local count = redis.call('ZCARD', key) + if count < maxSessions then + -- 未达上限,添加新会话 + redis.call('ZADD', key, now, sessionUUID) + redis.call('EXPIRE', key, idleTimeout + 60) + return 1 + end + + -- 达到上限,拒绝新会话 + return 0 + `) + + // refreshSessionScript 刷新会话时间戳 + // KEYS[1] = session_limit:account:{accountID} + // ARGV[1] = idleTimeout(秒) + // ARGV[2] = sessionUUID + refreshSessionScript = redis.NewScript(` + local key = KEYS[1] + local idleTimeout = tonumber(ARGV[1]) + local sessionUUID = ARGV[2] + + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + + -- 检查会话是否存在 + local exists = redis.call('ZSCORE', key, sessionUUID) + if exists ~= false then + redis.call('ZADD', key, now, sessionUUID) + redis.call('EXPIRE', key, idleTimeout + 60) + end + return 1 + `) + + // getActiveSessionCountScript 获取活跃会话数 + // KEYS[1] = session_limit:account:{accountID} + // ARGV[1] = idleTimeout(秒) + getActiveSessionCountScript = redis.NewScript(` + local key = KEYS[1] + local idleTimeout = tonumber(ARGV[1]) + + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - idleTimeout + + -- 清理过期会话 + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + + return redis.call('ZCARD', key) + `) + + // isSessionActiveScript 检查会话是否活跃 + // KEYS[1] = session_limit:account:{accountID} + // ARGV[1] = idleTimeout(秒) + // ARGV[2] = sessionUUID + isSessionActiveScript = redis.NewScript(` + local key = KEYS[1] + local idleTimeout = tonumber(ARGV[1]) + local sessionUUID = ARGV[2] + + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - idleTimeout + + -- 获取会话的时间戳 + local score = redis.call('ZSCORE', key, sessionUUID) + if score == false then + return 0 + end + + -- 检查是否过期 + if tonumber(score) <= expireBefore then + return 0 + end + + return 1 + `) +) + +type sessionLimitCache struct { + rdb *redis.Client + defaultIdleTimeout time.Duration // 默认空闲超时(用于 GetActiveSessionCount) +} + +// NewSessionLimitCache 创建会话限制缓存 +// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询 +func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache { + if defaultIdleTimeoutMinutes <= 0 { + defaultIdleTimeoutMinutes = 5 // 默认 5 分钟 + } + return &sessionLimitCache{ + rdb: rdb, + defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute, + } +} + +// sessionLimitKey 生成会话限制的 Redis 键 +func sessionLimitKey(accountID int64) string { + return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID) +} + +// windowCostKey 生成窗口费用缓存的 Redis 键 +func windowCostKey(accountID int64) string { + return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID) +} + +// RegisterSession 注册会话活动 +func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) { + if sessionUUID == "" || maxSessions <= 0 { + return true, nil // 无效参数,默认允许 + } + + key := sessionLimitKey(accountID) + idleTimeoutSeconds := int(idleTimeout.Seconds()) + if idleTimeoutSeconds <= 0 { + idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds()) + } + + result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int() + if err != nil { + return true, err // 失败开放:缓存错误时允许请求通过 + } + return result == 1, nil +} + +// RefreshSession 刷新会话时间戳 +func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error { + if sessionUUID == "" { + return nil + } + + key := sessionLimitKey(accountID) + idleTimeoutSeconds := int(idleTimeout.Seconds()) + if idleTimeoutSeconds <= 0 { + idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds()) + } + + _, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result() + return err +} + +// GetActiveSessionCount 获取活跃会话数 +func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) { + key := sessionLimitKey(accountID) + idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds()) + + result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int() + if err != nil { + return 0, err + } + return result, nil +} + +// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数 +func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return make(map[int64]int), nil + } + + results := make(map[int64]int, len(accountIDs)) + + // 使用 pipeline 批量执行 + pipe := c.rdb.Pipeline() + idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds()) + + cmds := make(map[int64]*redis.Cmd, len(accountIDs)) + for _, accountID := range accountIDs { + key := sessionLimitKey(accountID) + cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds) + } + + // 执行 pipeline,即使部分失败也尝试获取成功的结果 + _, _ = pipe.Exec(ctx) + + for accountID, cmd := range cmds { + if result, err := cmd.Int(); err == nil { + results[accountID] = result + } + } + + return results, nil +} + +// IsSessionActive 检查会话是否活跃 +func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) { + if sessionUUID == "" { + return false, nil + } + + key := sessionLimitKey(accountID) + idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds()) + + result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int() + if err != nil { + return false, err + } + return result == 1, nil +} + +// ========== 5h窗口费用缓存实现 ========== + +// GetWindowCost 获取缓存的窗口费用 +func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) { + key := windowCostKey(accountID) + val, err := c.rdb.Get(ctx, key).Float64() + if err == redis.Nil { + return 0, false, nil // 缓存未命中 + } + if err != nil { + return 0, false, err + } + return val, true, nil +} + +// SetWindowCost 设置窗口费用缓存 +func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error { + key := windowCostKey(accountID) + return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err() +} + +// GetWindowCostBatch 批量获取窗口费用缓存 +func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) { + if len(accountIDs) == 0 { + return make(map[int64]float64), nil + } + + // 构建批量查询的 keys + keys := make([]string, len(accountIDs)) + for i, accountID := range accountIDs { + keys[i] = windowCostKey(accountID) + } + + // 使用 MGET 批量获取 + vals, err := c.rdb.MGet(ctx, keys...).Result() + if err != nil { + return nil, err + } + + results := make(map[int64]float64, len(accountIDs)) + for i, val := range vals { + if val == nil { + continue // 缓存未命中 + } + // 尝试解析为 float64 + switch v := val.(type) { + case string: + if cost, err := strconv.ParseFloat(v, 64); err == nil { + results[accountIDs[i]] = cost + } + case float64: + results[accountIDs[i]] = v + } + } + + return results, nil +} diff --git a/backend/internal/repository/simple_mode_default_groups.go b/backend/internal/repository/simple_mode_default_groups.go new file mode 100644 index 00000000..56309184 --- /dev/null +++ b/backend/internal/repository/simple_mode_default_groups.go @@ -0,0 +1,82 @@ +package repository + +import ( + "context" + "fmt" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error { + if client == nil { + return fmt.Errorf("nil ent client") + } + + requiredByPlatform := map[string]int{ + service.PlatformAnthropic: 1, + service.PlatformOpenAI: 1, + service.PlatformGemini: 1, + service.PlatformAntigravity: 2, + } + + for platform, minCount := range requiredByPlatform { + count, err := client.Group.Query(). + Where(group.PlatformEQ(platform), group.DeletedAtIsNil()). + Count(ctx) + if err != nil { + return fmt.Errorf("count groups for platform %s: %w", platform, err) + } + + if platform == service.PlatformAntigravity { + if count < minCount { + for i := count; i < minCount; i++ { + name := fmt.Sprintf("%s-default-%d", platform, i+1) + if err := createGroupIfNotExists(ctx, client, name, platform); err != nil { + return err + } + } + } + continue + } + + // Non-antigravity platforms: ensure -default exists. + name := platform + "-default" + if err := createGroupIfNotExists(ctx, client, name, platform); err != nil { + return err + } + } + + return nil +} + +func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error { + exists, err := client.Group.Query(). + Where(group.NameEQ(name), group.DeletedAtIsNil()). + Exist(ctx) + if err != nil { + return fmt.Errorf("check group exists %s: %w", name, err) + } + if exists { + return nil + } + + _, err = client.Group.Create(). + SetName(name). + SetDescription("Auto-created default group"). + SetPlatform(platform). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeStandard). + SetRateMultiplier(1.0). + SetIsExclusive(false). + Save(ctx) + if err != nil { + if dbent.IsConstraintError(err) { + // Concurrent server startups may race on creation; treat as success. + return nil + } + return fmt.Errorf("create default group %s: %w", name, err) + } + return nil +} diff --git a/backend/internal/repository/simple_mode_default_groups_integration_test.go b/backend/internal/repository/simple_mode_default_groups_integration_test.go new file mode 100644 index 00000000..3327257b --- /dev/null +++ b/backend/internal/repository/simple_mode_default_groups_integration_test.go @@ -0,0 +1,84 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + assertGroupExists := func(name string) { + exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx) + require.NoError(t, err) + require.True(t, exists, "expected group %s to exist", name) + } + + assertGroupExists(service.PlatformAnthropic + "-default") + assertGroupExists(service.PlatformOpenAI + "-default") + assertGroupExists(service.PlatformGemini + "-default") + assertGroupExists(service.PlatformAntigravity + "-default-1") + assertGroupExists(service.PlatformAntigravity + "-default-2") +} + +func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + // Create and then soft-delete an anthropic default group. + g, err := client.Group.Create(). + SetName(service.PlatformAnthropic + "-default"). + SetPlatform(service.PlatformAnthropic). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeStandard). + SetRateMultiplier(1.0). + SetIsExclusive(false). + Save(seedCtx) + require.NoError(t, err) + + _, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx) + require.NoError(t, err) + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + // New active one should exist. + count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity}) + mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity}) + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx) + require.NoError(t, err) + require.GreaterOrEqual(t, count, 2) +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 9dc91eca..7a8d85f4 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient return NewPricingRemoteClient(cfg.Update.ProxyURL) } +// ProvideSessionLimitCache 创建会话限制缓存 +// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制 +func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache { + defaultIdleTimeoutMinutes := 5 // 默认 5 分钟空闲超时 + if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 { + defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes + } + return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes) +} + // ProviderSet is the Wire provider set for all repositories var ProviderSet = wire.NewSet( NewUserRepository, @@ -62,6 +72,7 @@ var ProviderSet = wire.NewSet( NewTempUnschedCache, NewTimeoutCounterCache, ProvideConcurrencyCache, + ProvideSessionLimitCache, NewDashboardCache, NewEmailCache, NewIdentityCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 7076f8c5..0a549b19 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -441,7 +441,7 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) - adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 0d7a9cf9..36ba0bcc 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -557,3 +557,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool { } return false } + +// WindowCostSchedulability 窗口费用调度状态 +type WindowCostSchedulability int + +const ( + // WindowCostSchedulable 可正常调度 + WindowCostSchedulable WindowCostSchedulability = iota + // WindowCostStickyOnly 仅允许粘性会话 + WindowCostStickyOnly + // WindowCostNotSchedulable 完全不可调度 + WindowCostNotSchedulable +) + +// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号 +// 仅这两类账号支持 5h 窗口额度控制和会话数量控制 +func (a *Account) IsAnthropicOAuthOrSetupToken() bool { + return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken) +} + +// GetWindowCostLimit 获取 5h 窗口费用阈值(美元) +// 返回 0 表示未启用 +func (a *Account) GetWindowCostLimit() float64 { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["window_cost_limit"]; ok { + return parseExtraFloat64(v) + } + return 0 +} + +// GetWindowCostStickyReserve 获取粘性会话预留额度(美元) +// 默认值为 10 +func (a *Account) GetWindowCostStickyReserve() float64 { + if a.Extra == nil { + return 10.0 + } + if v, ok := a.Extra["window_cost_sticky_reserve"]; ok { + val := parseExtraFloat64(v) + if val > 0 { + return val + } + } + return 10.0 +} + +// GetMaxSessions 获取最大并发会话数 +// 返回 0 表示未启用 +func (a *Account) GetMaxSessions() int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["max_sessions"]; ok { + return parseExtraInt(v) + } + return 0 +} + +// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数 +// 默认值为 5 分钟 +func (a *Account) GetSessionIdleTimeoutMinutes() int { + if a.Extra == nil { + return 5 + } + if v, ok := a.Extra["session_idle_timeout_minutes"]; ok { + val := parseExtraInt(v) + if val > 0 { + return val + } + } + return 5 +} + +// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态 +// - 费用 < 阈值: WindowCostSchedulable(可正常调度) +// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话) +// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度) +func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability { + limit := a.GetWindowCostLimit() + if limit <= 0 { + return WindowCostSchedulable + } + + if currentWindowCost < limit { + return WindowCostSchedulable + } + + stickyReserve := a.GetWindowCostStickyReserve() + if currentWindowCost < limit+stickyReserve { + return WindowCostStickyOnly + } + + return WindowCostNotSchedulable +} + +// parseExtraFloat64 从 extra 字段解析 float64 值 +func parseExtraFloat64(value any) float64 { + switch v := value.(type) { + case float64: + return v + case float32: + return float64(v) + case int: + return float64(v) + case int64: + return float64(v) + case json.Number: + if f, err := v.Float64(); err == nil { + return f + } + case string: + if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil { + return f + } + } + return 0 +} + +// parseExtraInt 从 extra 字段解析 int 值 +func parseExtraInt(value any) int { + switch v := value.(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i) + } + case string: + if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return i + } + } + return 0 +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index f1c07d5e..e3c0974e 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64 }, } } + +// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计 +// 用于账号列表页面显示当前窗口费用 +func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime) +} diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index 39000e4f..179a3520 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) { {"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true}, // Gemini 前缀透传 - {"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true}, + {"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true}, {"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true}, {"Gemini前缀 - gemini-future-version", "gemini-future-version", true}, @@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { expected: "gemini-2.5-flash", }, { - name: "Gemini透传 - gemini-1.5-pro", - requestedModel: "gemini-1.5-pro", + name: "Gemini透传 - gemini-2.5-pro", + requestedModel: "gemini-2.5-pro", accountMapping: nil, - expected: "gemini-1.5-pro", + expected: "gemini-2.5-pro", }, { name: "Gemini透传 - gemini-future-model", diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 76d73286..f543ef1a 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -1052,7 +1052,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, // No concurrency service } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -1105,7 +1105,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, // legacy path } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil) + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -1137,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -1169,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { } excludedIDs := map[int64]struct{}{1: {}} - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -1203,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -1239,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -1266,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "no available accounts") @@ -1298,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -1331,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 1e3221d3..5068767c 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -176,6 +176,7 @@ type GatewayService struct { deferredService *DeferredService concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) } // NewGatewayService creates a new GatewayService @@ -196,6 +197,7 @@ func NewGatewayService( httpUpstream HTTPUpstream, deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, + sessionLimitCache SessionLimitCache, ) *GatewayService { return &GatewayService{ accountRepo: accountRepo, @@ -214,6 +216,7 @@ func NewGatewayService( httpUpstream: httpUpstream, deferredService: deferredService, claudeTokenProvider: claudeTokenProvider, + sessionLimitCache: sessionLimitCache, } } @@ -407,8 +410,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { +// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制) +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { cfg := s.schedulingConfig() + // 提取会话 UUID(用于会话数量限制) + sessionUUID := extractSessionUUID(metadataUserID) + var stickyAccountID int64 if sessionHash != "" && s.cache != nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { @@ -527,7 +534,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(routingAccountIDs) > 0 && s.concurrencyService != nil { // 1. 过滤出路由列表中可调度的账号 var routingCandidates []*Account - var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping int + var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int for _, routingAccountID := range routingAccountIDs { if isExcluded(routingAccountID) { filteredExcluded++ @@ -554,13 +561,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredModelMapping++ continue } + // 窗口费用检查(非粘性会话路径) + if !s.isAccountSchedulableForWindowCost(ctx, account, false) { + filteredWindowCost++ + continue + } routingCandidates = append(routingCandidates, account) } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d)", + log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), - filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping) + filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) } if len(routingCandidates) > 0 { @@ -573,18 +585,25 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if stickyAccount.IsSchedulable() && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && stickyAccount.IsSchedulableForModel(requestedModel) && - (requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) { + (requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) - if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) { + result.ReleaseFunc() // 释放槽位 + // 继续到负载感知选择 + } else { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) + } + return &AccountSelectionResult{ + Account: stickyAccount, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } - return &AccountSelectionResult{ - Account: stickyAccount, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) @@ -657,6 +676,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro for _, item := range routingAvailable { result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + continue + } if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } @@ -699,15 +723,21 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if ok && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && account.IsSchedulableForModel(requestedModel) && - (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) && + s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, account, sessionUUID) { + result.ReleaseFunc() // 释放槽位,继续到 Layer 2 + } else { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) @@ -748,6 +778,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { continue } + // 窗口费用检查(非粘性会话路径) + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } candidates = append(candidates, acc) } @@ -765,7 +799,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok { return result, nil } } else { @@ -814,6 +848,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro for _, item := range available { result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + continue + } if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } @@ -843,13 +882,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, errors.New("no available accounts") } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) for _, acc := range ordered { result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, acc, sessionUUID) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + continue + } if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } @@ -1081,6 +1125,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } +// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +// 返回 true 表示可调度,false 表示不可调度 +func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool { + // 只检查 Anthropic OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + + limit := account.GetWindowCostLimit() + if limit <= 0 { + return true // 未启用窗口费用限制 + } + + // 尝试从缓存获取窗口费用 + var currentCost float64 + if s.sessionLimitCache != nil { + if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { + currentCost = cost + goto checkSchedulability + } + } + + // 缓存未命中,从数据库查询 + { + var startTime time.Time + if account.SessionWindowStart != nil { + startTime = *account.SessionWindowStart + } else { + startTime = time.Now().Add(-5 * time.Hour) + } + + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) + if err != nil { + // 失败开放:查询失败时允许调度 + return true + } + + // 使用标准费用(不含账号倍率) + currentCost = stats.StandardCost + + // 设置缓存(忽略错误) + if s.sessionLimitCache != nil { + _ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost) + } + } + +checkSchedulability: + schedulability := account.CheckWindowCostSchedulability(currentCost) + + switch schedulability { + case WindowCostSchedulable: + return true + case WindowCostStickyOnly: + return isSticky + case WindowCostNotSchedulable: + return false + } + return true +} + +// checkAndRegisterSession 检查并注册会话,用于会话数量限制 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) +func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool { + // 只检查 Anthropic OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + + maxSessions := account.GetMaxSessions() + if maxSessions <= 0 || sessionUUID == "" { + return true // 未启用会话限制或无会话ID + } + + if s.sessionLimitCache == nil { + return true // 缓存不可用时允许通过 + } + + idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute + + allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout) + if err != nil { + // 失败开放:缓存错误时允许通过 + return true + } + return allowed +} + +// extractSessionUUID 从 metadata.user_id 中提取会话 UUID +// 格式: user_{64位hex}_account__session_{uuid} +func extractSessionUUID(metadataUserID string) string { + if metadataUserID == "" { + return "" + } + if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 { + return match[1] + } + return "" +} + func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { if s.schedulerSnapshot != nil { return s.schedulerSnapshot.GetAccount(ctx, accountID) diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 03f5d757..f2ea5859 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -599,7 +599,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) { name: "Gemini平台-有映射配置-只支持配置的模型", account: &Account{ Platform: PlatformGemini, - Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}}, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}}, }, model: "gemini-2.5-flash", expected: false, diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 264bdf95..48c72593 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool { } modified := false - for idx, tool := range tools { + validTools := make([]any, 0, len(tools)) + + for _, tool := range tools { toolMap, ok := tool.(map[string]any) if !ok { + // Keep unknown structure as-is to avoid breaking upstream behavior. + validTools = append(validTools, tool) continue } toolType, _ := toolMap["type"].(string) - if strings.TrimSpace(toolType) != "function" { + toolType = strings.TrimSpace(toolType) + if toolType != "function" { + validTools = append(validTools, toolMap) continue } - function, ok := toolMap["function"].(map[string]any) - if !ok { + // OpenAI Responses-style tools use top-level name/parameters. + if name, ok := toolMap["name"].(string); ok && strings.TrimSpace(name) != "" { + validTools = append(validTools, toolMap) + continue + } + + // ChatCompletions-style tools use {type:"function", function:{...}}. + functionValue, hasFunction := toolMap["function"] + function, ok := functionValue.(map[string]any) + if !hasFunction || functionValue == nil || !ok || function == nil { + // Drop invalid function tools. + modified = true continue } @@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool { } } - tools[idx] = toolMap + validTools = append(validTools, toolMap) } if modified { - reqBody["tools"] = tools + reqBody["tools"] = validTools } return modified diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 0ff9485a..4cd72ab6 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { require.False(t, hasID) } +func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) { + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "tools": []any{ + map[string]any{ + "type": "function", + "name": "bash", + "description": "desc", + "parameters": map[string]any{"type": "object"}, + }, + map[string]any{ + "type": "function", + "function": nil, + }, + }, + } + + applyCodexOAuthTransform(reqBody) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + + first, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "function", first["type"]) + require.Equal(t, "bash", first["name"]) +} + func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 setupCodexCache(t) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index c7d94882..a3c4a239 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -133,12 +133,30 @@ func NewOpenAIGatewayService( } } -// GenerateSessionHash generates session hash from header (OpenAI uses session_id header) -func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { - sessionID := c.GetHeader("session_id") +// GenerateSessionHash generates a sticky-session hash for OpenAI requests. +// +// Priority: +// 1. Header: session_id +// 2. Header: conversation_id +// 3. Body: prompt_cache_key (opencode) +func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string { + if c == nil { + return "" + } + + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && reqBody != nil { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + sessionID = strings.TrimSpace(v) + } + } if sessionID == "" { return "" } + hash := sha256.Sum256([]byte(sessionID)) return hex.EncodeToString(hash[:]) } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 42b88b7d..a34b8045 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -49,6 +49,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts return out, nil } +func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + svc := &OpenAIGatewayService{} + + // 1) session_id header wins + c.Request.Header.Set("session_id", "sess-123") + c.Request.Header.Set("conversation_id", "conv-456") + h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + if h1 == "" { + t.Fatalf("expected non-empty hash") + } + + // 2) conversation_id used when session_id absent + c.Request.Header.Del("session_id") + h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + if h2 == "" { + t.Fatalf("expected non-empty hash") + } + if h1 == h2 { + t.Fatalf("expected different hashes for different keys") + } + + // 3) prompt_cache_key used when both headers absent + c.Request.Header.Del("conversation_id") + h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + if h3 == "" { + t.Fatalf("expected non-empty hash") + } + if h2 == h3 { + t.Fatalf("expected different hashes for different keys") + } + + // 4) empty when no signals + h4 := svc.GenerateSessionHash(c, map[string]any{}) + if h4 != "" { + t.Fatalf("expected empty hash when no signals") + } +} + func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index 9c9eab84..f4719275 100644 --- a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{ "executeBash": "bash", "exec_bash": "bash", "execBash": "bash", + + // Some clients output generic fetch names. + "fetch": "webfetch", + "web_fetch": "webfetch", + "webFetch": "webfetch", } // ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化) @@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall // 根据工具名称应用特定的参数修正规则 switch toolName { case "bash": - // 移除 workdir 参数(OpenCode 不支持) - if _, exists := argsMap["workdir"]; exists { - delete(argsMap, "workdir") - corrected = true - log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool") - } - if _, exists := argsMap["work_dir"]; exists { - delete(argsMap, "work_dir") - corrected = true - log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool") + // OpenCode bash 支持 workdir;有些来源会输出 work_dir。 + if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir { + if workDir, exists := argsMap["work_dir"]; exists { + argsMap["workdir"] = workDir + delete(argsMap, "work_dir") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") + } + } else { + if _, exists := argsMap["work_dir"]; exists { + delete(argsMap, "work_dir") + corrected = true + log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") + } } case "edit": - // OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称 - // 这里可以添加参数名称的映射逻辑 - if _, exists := argsMap["file_path"]; !exists { - if path, exists := argsMap["path"]; exists { - argsMap["file_path"] = path + // OpenCode edit 参数为 filePath/oldString/newString(camelCase)。 + if _, exists := argsMap["filePath"]; !exists { + if filePath, exists := argsMap["file_path"]; exists { + argsMap["filePath"] = filePath + delete(argsMap, "file_path") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") + } else if filePath, exists := argsMap["path"]; exists { + argsMap["filePath"] = filePath delete(argsMap, "path") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool") + log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") + } else if filePath, exists := argsMap["file"]; exists { + argsMap["filePath"] = filePath + delete(argsMap, "file") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") + } + } + + if _, exists := argsMap["oldString"]; !exists { + if oldString, exists := argsMap["old_string"]; exists { + argsMap["oldString"] = oldString + delete(argsMap, "old_string") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") + } + } + + if _, exists := argsMap["newString"]; !exists { + if newString, exists := argsMap["new_string"]; exists { + argsMap["newString"] = newString + delete(argsMap, "new_string") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") + } + } + + if _, exists := argsMap["replaceAll"]; !exists { + if replaceAll, exists := argsMap["replace_all"]; exists { + argsMap["replaceAll"] = replaceAll + delete(argsMap, "replace_all") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") } } } diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index 3e885b4b..ff518ea6 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) { expected map[string]bool // key: 期待存在的参数, value: true表示应该存在 }{ { - name: "remove workdir from bash tool", + name: "rename work_dir to workdir in bash tool", input: `{ "tool_calls": [{ "function": { "name": "bash", - "arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}" + "arguments": "{\"command\":\"ls\",\"work_dir\":\"/tmp\"}" } }] }`, expected: map[string]bool{ - "command": true, - "workdir": false, + "command": true, + "workdir": true, + "work_dir": false, }, }, { - name: "rename path to file_path in edit tool", + name: "rename snake_case edit params to camelCase", input: `{ "tool_calls": [{ "function": { @@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) { }] }`, expected: map[string]bool{ - "file_path": true, + "filePath": true, "path": false, - "old_string": true, - "new_string": true, + "oldString": true, + "old_string": false, + "newString": true, + "new_string": false, }, }, } diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index 25c10af6..8d98e43f 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -514,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry if s.gatewayService == nil { return nil, fmt.Errorf("gateway service not available") } - return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs) + return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制 default: return nil, fmt.Errorf("unsupported retry type: %s", reqType) } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 392fb65c..0ade72cd 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string func normalizeModelNameForPricing(model string) string { // Common Gemini/VertexAI forms: // - models/gemini-2.0-flash-exp - // - publishers/google/models/gemini-1.5-pro - // - projects/.../locations/.../publishers/google/models/gemini-1.5-pro + // - publishers/google/models/gemini-2.5-pro + // - projects/.../locations/.../publishers/google/models/gemini-2.5-pro model = strings.TrimSpace(model) model = strings.TrimLeft(model, "/") model = strings.TrimPrefix(model, "models/") diff --git a/backend/internal/service/session_limit_cache.go b/backend/internal/service/session_limit_cache.go new file mode 100644 index 00000000..f6f0c26a --- /dev/null +++ b/backend/internal/service/session_limit_cache.go @@ -0,0 +1,63 @@ +package service + +import ( + "context" + "time" +) + +// SessionLimitCache 管理账号级别的活跃会话跟踪 +// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制 +// +// Key 格式: session_limit:account:{accountID} +// 数据结构: Sorted Set (member=sessionUUID, score=timestamp) +// +// 会话在空闲超时后自动过期,无需手动清理 +type SessionLimitCache interface { + // RegisterSession 注册会话活动 + // - 如果会话已存在,刷新其时间戳并返回 true + // - 如果会话不存在且活跃会话数 < maxSessions,添加新会话并返回 true + // - 如果会话不存在且活跃会话数 >= maxSessions,返回 false(拒绝) + // + // 参数: + // accountID: 账号 ID + // sessionUUID: 从 metadata.user_id 中提取的会话 UUID + // maxSessions: 最大并发会话数限制 + // idleTimeout: 会话空闲超时时间 + // + // 返回: + // allowed: true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) + // error: 操作错误 + RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error) + + // RefreshSession 刷新现有会话的时间戳 + // 用于活跃会话保持活动状态 + RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error + + // GetActiveSessionCount 获取当前活跃会话数 + // 返回未过期的会话数量 + GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) + + // GetActiveSessionCountBatch 批量获取多个账号的活跃会话数 + // 返回 map[accountID]count,查询失败的账号不在 map 中 + GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) + + // IsSessionActive 检查特定会话是否活跃(未过期) + IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) + + // ========== 5h窗口费用缓存 ========== + // Key 格式: window_cost:account:{accountID} + // 用于缓存账号在当前5h窗口内的标准费用,减少数据库聚合查询压力 + + // GetWindowCost 获取缓存的窗口费用 + // 返回 (cost, true, nil) 如果缓存命中 + // 返回 (0, false, nil) 如果缓存未命中 + // 返回 (0, false, err) 如果发生错误 + GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error) + + // SetWindowCost 设置窗口费用缓存 + SetWindowCost(ctx context.Context, accountID int64, cost float64) error + + // GetWindowCostBatch 批量获取窗口费用缓存 + // 返回 map[accountID]cost,缓存未命中的账号不在 map 中 + GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) +} diff --git a/frontend/src/components/account/AccountCapacityCell.vue b/frontend/src/components/account/AccountCapacityCell.vue new file mode 100644 index 00000000..ae338aca --- /dev/null +++ b/frontend/src/components/account/AccountCapacityCell.vue @@ -0,0 +1,199 @@ + + + diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 42f3c1b9..dfa1503e 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -292,8 +292,11 @@ const loadAvailableModels = async () => { if (availableModels.value.length > 0) { if (props.account.platform === 'gemini') { const preferred = + availableModels.value.find((m) => m.id === 'gemini-2.0-flash') || + availableModels.value.find((m) => m.id === 'gemini-2.5-flash') || availableModels.value.find((m) => m.id === 'gemini-2.5-pro') || - availableModels.value.find((m) => m.id === 'gemini-3-pro') + availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') || + availableModels.value.find((m) => m.id === 'gemini-3-pro-preview') selectedModelId.value = preferred?.id || availableModels.value[0].id } else { // Try to select Sonnet as default, otherwise use first model diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 00cd9b24..d27364f1 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -604,6 +604,136 @@ + +
+
+

{{ t('admin.accounts.quotaControl.title') }}

+

+ {{ t('admin.accounts.quotaControl.hint') }} +

+
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.windowCost.hint') }} +

+
+ +
+ +
+
+ +
+ $ + +
+

{{ t('admin.accounts.quotaControl.windowCost.limitHint') }}

+
+
+ +
+ $ + +
+

{{ t('admin.accounts.quotaControl.windowCost.stickyReserveHint') }}

+
+
+
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.sessionLimit.hint') }} +

+
+ +
+ +
+
+ + +

{{ t('admin.accounts.quotaControl.sessionLimit.maxSessionsHint') }}

+
+
+ +
+ + {{ t('common.minutes') }} +
+

{{ t('admin.accounts.quotaControl.sessionLimit.idleTimeoutHint') }}

+
+
+
+
+
@@ -767,6 +897,14 @@ const mixedScheduling = ref(false) // For antigravity accounts: enable mixed sch const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) +// Quota control state (Anthropic OAuth/SetupToken only) +const windowCostEnabled = ref(false) +const windowCostLimit = ref(null) +const windowCostStickyReserve = ref(null) +const sessionLimitEnabled = ref(false) +const maxSessions = ref(null) +const sessionIdleTimeout = ref(null) + // Computed: current preset mappings based on platform const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic')) const tempUnschedPresets = computed(() => [ @@ -854,6 +992,9 @@ watch( const extra = newAccount.extra as Record | undefined mixedScheduling.value = extra?.mixed_scheduling === true + // Load quota control settings (Anthropic OAuth/SetupToken only) + loadQuotaControlSettings(newAccount) + loadTempUnschedRules(credentials) // Initialize API Key fields for apikey type @@ -1087,6 +1228,35 @@ function loadTempUnschedRules(credentials?: Record) { }) } +// Load quota control settings from account (Anthropic OAuth/SetupToken only) +function loadQuotaControlSettings(account: Account) { + // Reset all quota control state first + windowCostEnabled.value = false + windowCostLimit.value = null + windowCostStickyReserve.value = null + sessionLimitEnabled.value = false + maxSessions.value = null + sessionIdleTimeout.value = null + + // Only applies to Anthropic OAuth/SetupToken accounts + if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { + return + } + + // Load from extra field (via backend DTO fields) + if (account.window_cost_limit != null && account.window_cost_limit > 0) { + windowCostEnabled.value = true + windowCostLimit.value = account.window_cost_limit + windowCostStickyReserve.value = account.window_cost_sticky_reserve ?? 10 + } + + if (account.max_sessions != null && account.max_sessions > 0) { + sessionLimitEnabled.value = true + maxSessions.value = account.max_sessions + sessionIdleTimeout.value = account.session_idle_timeout_minutes ?? 5 + } +} + function formatTempUnschedKeywords(value: unknown) { if (Array.isArray(value)) { return value @@ -1214,6 +1384,32 @@ const handleSubmit = async () => { updatePayload.extra = newExtra } + // For Anthropic OAuth/SetupToken accounts, handle quota control settings in extra + if (props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token')) { + const currentExtra = (props.account.extra as Record) || {} + const newExtra: Record = { ...currentExtra } + + // Window cost limit settings + if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) { + newExtra.window_cost_limit = windowCostLimit.value + newExtra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10 + } else { + delete newExtra.window_cost_limit + delete newExtra.window_cost_sticky_reserve + } + + // Session limit settings + if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) { + newExtra.max_sessions = maxSessions.value + newExtra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5 + } else { + delete newExtra.max_sessions + delete newExtra.session_idle_timeout_minutes + } + + updatePayload.extra = newExtra + } + await adminAPI.accounts.update(props.account.id, updatePayload) appStore.showSuccess(t('admin.accounts.accountUpdated')) emit('updated') diff --git a/frontend/src/components/admin/account/AccountTableActions.vue b/frontend/src/components/admin/account/AccountTableActions.vue index 96fceaa0..91ebd239 100644 --- a/frontend/src/components/admin/account/AccountTableActions.vue +++ b/frontend/src/components/admin/account/AccountTableActions.vue @@ -1,5 +1,6 @@