From e63c83955adf1c8e4972de181e0d2d86cb133243 Mon Sep 17 00:00:00 2001 From: QTom Date: Sat, 28 Feb 2026 10:35:33 +0800 Subject: [PATCH] fix: address deep code review issues for RPM limiting - Move IncrementRPM after Forward success to prevent phantom RPM consumption during account switch retries - Add base_rpm input sanitization (clamp to 0-10000) in Create/Update - Add WindowCost scheduling checks to legacy path sticky sessions (4 check sites + 4 prefetch sites), fixing pre-existing gap - Clean up rpm_strategy/rpm_sticky_buffer when disabling RPM in BulkEditModal (JSONB merge cannot delete keys, use empty values) - Add json.Number test cases to TestGetBaseRPM/TestGetRPMStickyBuffer - Document TOCTOU race as accepted soft-limit design trade-off --- .../internal/handler/admin/account_handler.go | 45 +++++++++++++++++++ backend/internal/handler/gateway_handler.go | 34 +++++++------- backend/internal/service/account_rpm_test.go | 7 ++- backend/internal/service/gateway_service.go | 27 ++++++----- .../account/BulkEditAccountModal.vue | 6 ++- 5 files changed, 92 insertions(+), 27 deletions(-) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index c41f37c1..382d62c1 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -483,6 +483,8 @@ func (h *AccountHandler) Create(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -552,6 +554,8 @@ func (h *AccountHandler) Update(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -1736,3 +1740,44 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) { response.Success(c, domain.DefaultAntigravityModelMapping) } + +// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。 +// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。 +func sanitizeExtraBaseRPM(extra map[string]any) { + if extra == nil { + return + } + raw, ok := extra["base_rpm"] + if !ok { + return + } + v := parseExtraIntForValidation(raw) + if v < 0 { + v = 0 + } else if v > 10000 { + v = 10000 + } + extra["base_rpm"] = v +} + +// parseExtraIntForValidation 从 extra 字段的 any 值解析为 int,用于输入校验。 +// 支持 int, int64, float64, json.Number, string 类型。 +func parseExtraIntForValidation(value any) int { + switch v := value.(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i) + } + case string: + if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return i + } + } + return 0 +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index b68a46fa..3cc52839 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -366,13 +366,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - // RPM 计数递增(调度成功后、Forward 前) - if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { - if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { - reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - } - // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() @@ -410,6 +403,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + // RPM 计数递增(Forward 成功后) + // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 + // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 + if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { + reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) @@ -556,13 +558,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - // RPM 计数递增(调度成功后、Forward 前) - if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { - if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { - reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - } - // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() @@ -609,7 +604,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, status, code, message, streamStarted) return } - // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 + // 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度 ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") c.Request = c.Request.WithContext(ctx) currentAPIKey = fallbackAPIKey @@ -643,6 +638,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + // RPM 计数递增(Forward 成功后) + // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 + // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 + if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { + reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go index b08b54a2..9d91f3e0 100644 --- a/backend/internal/service/account_rpm_test.go +++ b/backend/internal/service/account_rpm_test.go @@ -1,6 +1,9 @@ package service -import "testing" +import ( + "encoding/json" + "testing" +) func TestGetBaseRPM(t *testing.T) { tests := []struct { @@ -16,6 +19,7 @@ func TestGetBaseRPM(t *testing.T) { {"string value", map[string]any{"base_rpm": "15"}, 15}, {"negative value", map[string]any{"base_rpm": -5}, 0}, {"int64 value", map[string]any{"base_rpm": int64(20)}, 20}, + {"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -103,6 +107,7 @@ func TestGetRPMStickyBuffer(t *testing.T) { {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2}, {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2}, {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, + {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 54c3a4d1..04e37f68 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -2242,6 +2242,9 @@ func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account } // IncrementAccountRPM increments the RPM counter for the given account. +// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口, +// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit +// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。 func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error { if s.rpmCache == nil { return nil @@ -2444,7 +2447,7 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool { // shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 // // 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 -// 因此这里采用“组内分区 + 分区内 shuffle”的方式: +// 因此这里采用"组内分区 + 分区内 shuffle"的方式: // - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; // - 再分别在各段内随机打散,避免热点。 func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { @@ -2584,7 +2587,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2607,7 +2610,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } accountsLoaded = true - // 提前预取 RPM 计数,确保 routing 段内的 isAccountSchedulableForRPM 调用能命中缓存 + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts) routingSet := make(map[int64]struct{}, len(routingAccountIDs)) @@ -2690,7 +2694,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { return account, nil } } @@ -2711,7 +2715,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } } - // 批量预取 RPM 计数,避免逐个账号查询(N+1) + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持) @@ -2804,7 +2809,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) @@ -2825,7 +2830,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } accountsLoaded = true - // 提前预取 RPM 计数,确保 routing 段内的 isAccountSchedulableForRPM 调用能命中缓存 + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts) routingSet := make(map[int64]struct{}, len(routingAccountIDs)) @@ -2912,7 +2918,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2931,7 +2937,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } } - // 批量预取 RPM 计数,避免逐个账号查询(N+1) + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) @@ -5304,7 +5311,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { } func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 + // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。 // 默认保守:无法识别则不切换。 msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) if msg == "" { diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index e583b981..ae16ff1a 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -1224,8 +1224,12 @@ const buildUpdatePayload = (): Record | null => { extra.rpm_sticky_buffer = bulkRpmStickyBuffer.value } } else { - // 关闭 RPM 限制 - 设置 base_rpm 为 0 + // 关闭 RPM 限制 - 设置 base_rpm 为 0,并用空值覆盖关联字段 + // 后端使用 JSONB || merge 语义,不会删除已有 key, + // 所以必须显式发送空值来重置(后端读取时会 fallback 到默认值) extra.base_rpm = 0 + extra.rpm_strategy = '' + extra.rpm_sticky_buffer = 0 } updates.extra = extra }