fix(gateway): 修复粘性会话预取分组错配并优化并发等待热路径

This commit is contained in:
yangjianbo
2026-02-22 16:43:33 +08:00
parent a89477ddf5
commit 2ee6c26676
7 changed files with 121 additions and 32 deletions

View File

@@ -604,17 +604,25 @@ func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) {
})
t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) {
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO()))
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background()))
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO(), nil))
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background(), nil))
ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123))
require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx))
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx, nil))
groupID := int64(9)
ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456)
require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2))
ctx2 = context.WithValue(ctx2, ctxkey.PrefetchedStickyGroupID, groupID)
require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2, &groupID))
ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid")
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3))
ctx3 = context.WithValue(ctx3, ctxkey.PrefetchedStickyGroupID, groupID)
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3, &groupID))
ctx4 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(789))
ctx4 = context.WithValue(ctx4, ctxkey.PrefetchedStickyGroupID, int64(10))
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx4, &groupID))
})
t.Run("window_cost_from_prefetch_context", func(t *testing.T) {
@@ -745,6 +753,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
}
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
@@ -752,4 +761,26 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
require.Equal(t, account.ID, result.Account.ID)
require.Equal(t, int64(0), cache.getCalls.Load())
})
// Subtest: the context carries a prefetched sticky account (999) tagged with
// group 77, while the selection call below passes a nil groupID. The
// assertions expect the mismatched prefetch to be ignored and the sticky
// account to come from the cache stub instead.
t.Run("with_prefetch_group_mismatch_reads_cache", func(t *testing.T) {
// The stub is seeded with account.ID as its sticky account.
cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: concurrency,
userGroupRateCache: gocache.New(time.Minute, time.Minute),
modelsListCache: gocache.New(time.Minute, time.Minute),
modelsListCacheTTL: time.Minute,
}
// Prefetched account 999 differs from account.ID, so the final ID assertion
// can only pass if the prefetched value was not used.
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
// groupID is nil here. NOTE(review): presumably a nil group resolves to 0 and
// therefore does not match 77 — confirm against derefGroupID.
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, account.ID, result.Account.ID)
// NOTE(review): getCalls presumably counts sticky-session reads on the stub;
// exactly one read is expected because the prefetch was rejected.
require.Equal(t, int64(1), cache.getCalls.Load())
})
}

View File

@@ -373,8 +373,26 @@ func modelsListCacheKey(groupID *int64, platform string) string {
return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform))
}
func prefetchedStickyAccountIDFromContext(ctx context.Context) int64 {
// prefetchedStickyGroupIDFromContext reads the value stored in ctx under
// ctxkey.PrefetchedStickyGroupID.
//
// It returns the group ID and true when the stored value is an int64, or an
// int (converted to int64). It returns (0, false) when ctx is nil, when no
// value is stored under the key, or when the stored value has any other type.
func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
	if ctx == nil {
		return 0, false
	}

	raw := ctx.Value(ctxkey.PrefetchedStickyGroupID)

	// Accept the two integer representations callers may have stored.
	if id, isInt64 := raw.(int64); isInt64 {
		return id, true
	}
	if id, isInt := raw.(int); isInt {
		return int64(id), true
	}

	// Missing key (nil interface) or unsupported type.
	return 0, false
}
func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 {
if ctx == nil {
return 0
}
prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx)
if !ok || prefetchedGroupID != derefGroupID(groupID) {
return 0
}
v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
@@ -1035,15 +1053,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
cfg := s.schedulingConfig()
var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx); prefetch > 0 {
stickyAccountID = prefetch
} else if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
}
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil {
@@ -1051,6 +1060,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
ctx = s.withGroupContext(ctx, group)
var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch
} else if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := ""
if group != nil {