fix(gateway): 修复粘性会话预取分组错配并优化并发等待热路径

This commit is contained in:
yangjianbo
2026-02-22 16:43:33 +08:00
parent a89477ddf5
commit 2ee6c26676
7 changed files with 121 additions and 32 deletions

View File

@@ -244,7 +244,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if sessionKey != "" { if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 { if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
} }
} }

View File

@@ -230,14 +230,31 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller). // streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false)
} }
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. // waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel() defer cancel()
acquireSlot := func() (*service.AcquireResult, error) {
if slotType == "user" {
return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
}
return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if tryImmediate {
result, err := acquireSlot()
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
// Determine if ping is needed (streaming + ping format defined) // Determine if ping is needed (streaming + ping format defined)
needPing := isStream && h.pingFormat != "" needPing := isStream && h.pingFormat != ""
@@ -286,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
case <-timer.C: case <-timer.C:
// Try to acquire slot // Try to acquire slot
var result *service.AcquireResult result, err := acquireSlot()
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -310,7 +319,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). // AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true)
} }
// nextBackoff 计算下一次退避时间 // nextBackoff 计算下一次退避时间

View File

@@ -176,7 +176,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
t.Run("account_slot_acquired_after_retry", func(t *testing.T) { t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted) release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, release) require.NotNil(t, release)
require.False(t, streamStarted) require.False(t, streamStarted)
@@ -188,7 +188,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
t.Run("user_slot_acquired_after_retry", func(t *testing.T) { t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted) release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, release) require.NotNil(t, release)
release() release()
@@ -207,7 +207,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted) release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
require.Nil(t, release) require.Nil(t, release)
var cErr *ConcurrencyError var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr) require.ErrorAs(t, err, &cErr)
@@ -218,7 +218,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond) helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages") c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted) release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
require.Nil(t, release) require.Nil(t, release)
var cErr *ConcurrencyError var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr) require.ErrorAs(t, err, &cErr)
@@ -236,12 +236,29 @@ func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted) release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
require.Nil(t, release) require.Nil(t, release)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "redis unavailable") require.Contains(t, err.Error(), "redis unavailable")
} }
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false},
}
concurrency := service.NewConcurrencyService(cache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
}
type helperConcurrencyCacheStubWithError struct { type helperConcurrencyCacheStubWithError struct {
helperConcurrencyCacheStub helperConcurrencyCacheStub
err error err error

View File

@@ -264,7 +264,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if sessionKey != "" { if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 { if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
} }
} }

View File

@@ -48,4 +48,8 @@ const (
// PrefetchedStickyAccountID 标识上游(通常 handler预取到的 sticky session 账号 ID。 // PrefetchedStickyAccountID 标识上游(通常 handler预取到的 sticky session 账号 ID。
// Service 层可复用该值,避免同请求链路重复读取 Redis。 // Service 层可复用该值,避免同请求链路重复读取 Redis。
PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id" PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id"
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID避免分组切换重试误用旧 sticky。
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
) )

View File

@@ -604,17 +604,25 @@ func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) {
}) })
t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) { t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) {
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO())) require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO(), nil))
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background())) require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background(), nil))
ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123)) ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123))
require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx)) ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx, nil))
groupID := int64(9)
ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456) ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456)
require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2)) ctx2 = context.WithValue(ctx2, ctxkey.PrefetchedStickyGroupID, groupID)
require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2, &groupID))
ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid") ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid")
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3)) ctx3 = context.WithValue(ctx3, ctxkey.PrefetchedStickyGroupID, groupID)
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3, &groupID))
ctx4 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(789))
ctx4 = context.WithValue(ctx4, ctxkey.PrefetchedStickyGroupID, int64(10))
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx4, &groupID))
}) })
t.Run("window_cost_from_prefetch_context", func(t *testing.T) { t.Run("window_cost_from_prefetch_context", func(t *testing.T) {
@@ -745,6 +753,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
} }
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
@@ -752,4 +761,26 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
require.Equal(t, account.ID, result.Account.ID) require.Equal(t, account.ID, result.Account.ID)
require.Equal(t, int64(0), cache.getCalls.Load()) require.Equal(t, int64(0), cache.getCalls.Load())
}) })
t.Run("with_prefetch_group_mismatch_reads_cache", func(t *testing.T) {
cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: concurrency,
userGroupRateCache: gocache.New(time.Minute, time.Minute),
modelsListCache: gocache.New(time.Minute, time.Minute),
modelsListCacheTTL: time.Minute,
}
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, account.ID, result.Account.ID)
require.Equal(t, int64(1), cache.getCalls.Load())
})
} }

View File

@@ -373,8 +373,26 @@ func modelsListCacheKey(groupID *int64, platform string) string {
return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform))
} }
func prefetchedStickyAccountIDFromContext(ctx context.Context) int64 { func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
if ctx == nil { if ctx == nil {
return 0, false
}
v := ctx.Value(ctxkey.PrefetchedStickyGroupID)
switch t := v.(type) {
case int64:
return t, true
case int:
return int64(t), true
}
return 0, false
}
func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 {
if ctx == nil {
return 0
}
prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx)
if !ok || prefetchedGroupID != derefGroupID(groupID) {
return 0 return 0
} }
v := ctx.Value(ctxkey.PrefetchedStickyAccountID) v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
@@ -1035,15 +1053,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx); prefetch > 0 {
stickyAccountID = prefetch
} else if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
}
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil { if err != nil {
@@ -1051,6 +1060,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
ctx = s.withGroupContext(ctx, group) ctx = s.withGroupContext(ctx, group)
var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch
} else if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.debugModelRoutingEnabled() && requestedModel != "" { if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := "" groupPlatform := ""
if group != nil { if group != nil {