fix(gateway): 修复粘性会话预取分组错配并优化并发等待热路径
This commit is contained in:
@@ -604,17 +604,25 @@ func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) {
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO()))
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background()))
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO(), nil))
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background(), nil))
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123))
|
||||
require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx))
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
|
||||
require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx, nil))
|
||||
|
||||
groupID := int64(9)
|
||||
ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456)
|
||||
require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2))
|
||||
ctx2 = context.WithValue(ctx2, ctxkey.PrefetchedStickyGroupID, groupID)
|
||||
require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2, &groupID))
|
||||
|
||||
ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid")
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3))
|
||||
ctx3 = context.WithValue(ctx3, ctxkey.PrefetchedStickyGroupID, groupID)
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3, &groupID))
|
||||
|
||||
ctx4 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(789))
|
||||
ctx4 = context.WithValue(ctx4, ctxkey.PrefetchedStickyGroupID, int64(10))
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx4, &groupID))
|
||||
})
|
||||
|
||||
t.Run("window_cost_from_prefetch_context", func(t *testing.T) {
|
||||
@@ -745,6 +753,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
@@ -752,4 +761,26 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
||||
require.Equal(t, account.ID, result.Account.ID)
|
||||
require.Equal(t, int64(0), cache.getCalls.Load())
|
||||
})
|
||||
|
||||
t.Run("with_prefetch_group_mismatch_reads_cache", func(t *testing.T) {
|
||||
cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: concurrency,
|
||||
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCacheTTL: time.Minute,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, account.ID, result.Account.ID)
|
||||
require.Equal(t, int64(1), cache.getCalls.Load())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -373,8 +373,26 @@ func modelsListCacheKey(groupID *int64, platform string) string {
|
||||
return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform))
|
||||
}
|
||||
|
||||
func prefetchedStickyAccountIDFromContext(ctx context.Context) int64 {
|
||||
func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
|
||||
if ctx == nil {
|
||||
return 0, false
|
||||
}
|
||||
v := ctx.Value(ctxkey.PrefetchedStickyGroupID)
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t, true
|
||||
case int:
|
||||
return int64(t), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 {
|
||||
if ctx == nil {
|
||||
return 0
|
||||
}
|
||||
prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx)
|
||||
if !ok || prefetchedGroupID != derefGroupID(groupID) {
|
||||
return 0
|
||||
}
|
||||
v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
|
||||
@@ -1035,15 +1053,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
cfg := s.schedulingConfig()
|
||||
|
||||
var stickyAccountID int64
|
||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx); prefetch > 0 {
|
||||
stickyAccountID = prefetch
|
||||
} else if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
|
||||
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
if err != nil {
|
||||
@@ -1051,6 +1060,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
|
||||
var stickyAccountID int64
|
||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
|
||||
stickyAccountID = prefetch
|
||||
} else if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
|
||||
if s.debugModelRoutingEnabled() && requestedModel != "" {
|
||||
groupPlatform := ""
|
||||
if group != nil {
|
||||
|
||||
Reference in New Issue
Block a user