fix: address deep code review issues for RPM limiting
- Move IncrementRPM after Forward success to prevent phantom RPM consumption during account switch retries - Add base_rpm input sanitization (clamp to 0-10000) in Create/Update - Add WindowCost scheduling checks to legacy path sticky sessions (4 check sites + 4 prefetch sites), fixing pre-existing gap - Clean up rpm_strategy/rpm_sticky_buffer when disabling RPM in BulkEditModal (JSONB merge cannot delete keys, use empty values) - Add json.Number test cases to TestGetBaseRPM/TestGetRPMStickyBuffer - Document TOCTOU race as accepted soft-limit design trade-off
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetBaseRPM(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -16,6 +19,7 @@ func TestGetBaseRPM(t *testing.T) {
|
||||
{"string value", map[string]any{"base_rpm": "15"}, 15},
|
||||
{"negative value", map[string]any{"base_rpm": -5}, 0},
|
||||
{"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
|
||||
{"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -103,6 +107,7 @@ func TestGetRPMStickyBuffer(t *testing.T) {
|
||||
{"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
|
||||
{"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
|
||||
{"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
|
||||
{"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -2242,6 +2242,9 @@ func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account
|
||||
}
|
||||
|
||||
// IncrementAccountRPM increments the RPM counter for the given account.
|
||||
// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口,
|
||||
// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit
|
||||
// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。
|
||||
func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error {
|
||||
if s.rpmCache == nil {
|
||||
return nil
|
||||
@@ -2444,7 +2447,7 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
|
||||
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
|
||||
//
|
||||
// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。
|
||||
// 因此这里采用“组内分区 + 分区内 shuffle”的方式:
|
||||
// 因此这里采用"组内分区 + 分区内 shuffle"的方式:
|
||||
// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前;
|
||||
// - 再分别在各段内随机打散,避免热点。
|
||||
func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
@@ -2584,7 +2587,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2607,7 +2610,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
accountsLoaded = true
|
||||
|
||||
// 提前预取 RPM 计数,确保 routing 段内的 isAccountSchedulableForRPM 调用能命中缓存
|
||||
// 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
|
||||
@@ -2690,7 +2694,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -2711,7 +2715,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// 批量预取 RPM 计数,避免逐个账号查询(N+1)
|
||||
// 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1)
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||
@@ -2804,7 +2809,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
@@ -2825,7 +2830,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
accountsLoaded = true
|
||||
|
||||
// 提前预取 RPM 计数,确保 routing 段内的 isAccountSchedulableForRPM 调用能命中缓存
|
||||
// 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
|
||||
@@ -2912,7 +2918,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
return account, nil
|
||||
}
|
||||
@@ -2931,7 +2937,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
}
|
||||
|
||||
// 批量预取 RPM 计数,避免逐个账号查询(N+1)
|
||||
// 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1)
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
@@ -5304,7 +5311,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
||||
}
|
||||
|
||||
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
||||
// 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。
|
||||
// 默认保守:无法识别则不切换。
|
||||
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
|
||||
Reference in New Issue
Block a user