From 2ee6c26676efd19b0f30b151b2ccc7aedc7bb9b5 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 22 Feb 2026 16:43:33 +0800 Subject: [PATCH] =?UTF-8?q?fix(gateway):=20=E4=BF=AE=E5=A4=8D=E7=B2=98?= =?UTF-8?q?=E6=80=A7=E4=BC=9A=E8=AF=9D=E9=A2=84=E5=8F=96=E5=88=86=E7=BB=84?= =?UTF-8?q?=E9=94=99=E9=85=8D=E5=B9=B6=E4=BC=98=E5=8C=96=E5=B9=B6=E5=8F=91?= =?UTF-8?q?=E7=AD=89=E5=BE=85=E7=83=AD=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/handler/gateway_handler.go | 5 +++ backend/internal/handler/gateway_helper.go | 33 +++++++++------ .../handler/gateway_helper_hotpath_test.go | 27 +++++++++--- .../internal/handler/gemini_v1beta_handler.go | 5 +++ backend/internal/pkg/ctxkey/ctxkey.go | 4 ++ .../gateway_hotpath_optimization_test.go | 41 ++++++++++++++++--- backend/internal/service/gateway_service.go | 38 ++++++++++++----- 7 files changed, 121 insertions(+), 32 deletions(-) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 9bf0fcd2..4b32969f 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -244,7 +244,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) if sessionBoundAccountID > 0 { + prefetchedGroupID := int64(0) + if apiKey.GroupID != nil { + prefetchedGroupID = *apiKey.GroupID + } ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID) c.Request = c.Request.WithContext(ctx) } } diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 6127dda7..efff7997 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -230,14 +230,31 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) + return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false) } // waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. -func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { +func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) { ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() + acquireSlot := func() (*service.AcquireResult, error) { + if slotType == "user" { + return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) + } + return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) + } + + if tryImmediate { + result, err := acquireSlot() + if err != nil { + return nil, err + } + if result.Acquired { + return result.ReleaseFunc, nil + } + } + // Determine if ping is needed (streaming + ping format defined) needPing := isStream && h.pingFormat != "" @@ -286,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType case <-timer.C: // Try to acquire slot - var result *service.AcquireResult - var err error - - if slotType == "user" { - result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) - } else { - result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) - } - + result, err := acquireSlot() if err != nil { return nil, err } @@ -310,7 +319,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType // AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) + return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true) } // nextBackoff 计算下一次退避时间 diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index 2149c130..3fdf1bfc 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -176,7 +176,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) { t.Run("account_slot_acquired_after_retry", func(t *testing.T) { c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") streamStarted := false - release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted) + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true) require.NoError(t, err) require.NotNil(t, release) require.False(t, streamStarted) @@ -188,7 +188,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) { t.Run("user_slot_acquired_after_retry", func(t *testing.T) { c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") streamStarted := false - release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted) + release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true) require.NoError(t, err) require.NotNil(t, release) release() @@ -207,7 +207,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) { helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") streamStarted := false - release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted) + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true) require.Nil(t, release) var cErr *ConcurrencyError require.ErrorAs(t, err, &cErr) @@ -218,7 +218,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) { helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond) c, rec := newHelperTestContext(http.MethodPost, "/v1/messages") streamStarted := false - release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted) + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true) require.Nil(t, release) var cErr *ConcurrencyError require.ErrorAs(t, err, &cErr) @@ -236,12 +236,29 @@ func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) { helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") streamStarted := false - release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted) + release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true) require.Nil(t, release) require.Error(t, err) require.Contains(t, err.Error(), "redis unavailable") } +func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + + release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + require.GreaterOrEqual(t, cache.accountAcquireCalls, 1) +} + type helperConcurrencyCacheStubWithError struct { helperConcurrencyCacheStub err error diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index c96484a6..ea212088 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -264,7 +264,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) if sessionBoundAccountID > 0 { + prefetchedGroupID := int64(0) + if apiKey.GroupID != nil { + prefetchedGroupID = *apiKey.GroupID + } ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID) c.Request = c.Request.WithContext(ctx) } } diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index a320ee8c..b13d66cb 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -48,4 +48,8 @@ const ( // PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。 // Service 层可复用该值,避免同请求链路重复读取 Redis。 PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id" + + // PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。 + // Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。 + PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id" ) diff --git a/backend/internal/service/gateway_hotpath_optimization_test.go b/backend/internal/service/gateway_hotpath_optimization_test.go index 81824cb3..161c4ba4 100644 --- a/backend/internal/service/gateway_hotpath_optimization_test.go +++ b/backend/internal/service/gateway_hotpath_optimization_test.go @@ -604,17 +604,25 @@ func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) { }) t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) { - require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO())) - require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background())) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO(), nil)) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background(), nil)) ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123)) - require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) + require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx, nil)) + groupID := int64(9) ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456) - require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2)) + ctx2 = context.WithValue(ctx2, ctxkey.PrefetchedStickyGroupID, groupID) + require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2, &groupID)) ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid") - require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3)) + ctx3 = context.WithValue(ctx3, ctxkey.PrefetchedStickyGroupID, groupID) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3, &groupID)) + + ctx4 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(789)) + ctx4 = context.WithValue(ctx4, ctxkey.PrefetchedStickyGroupID, int64(10)) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx4, &groupID)) }) t.Run("window_cost_from_prefetch_context", func(t *testing.T) { @@ -745,6 +753,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { } ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") require.NoError(t, err) require.NotNil(t, result) @@ -752,4 +761,26 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { require.Equal(t, account.ID, result.Account.ID) require.Equal(t, int64(0), cache.getCalls.Load()) }) + + t.Run("with_prefetch_group_mismatch_reads_cache", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77)) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(1), cache.getCalls.Load()) + }) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index ae637ee3..e55940ee 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -373,8 +373,26 @@ func modelsListCacheKey(groupID *int64, platform string) string { return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) } -func prefetchedStickyAccountIDFromContext(ctx context.Context) int64 { +func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyGroupID) + switch t := v.(type) { + case int64: + return t, true + case int: + return int64(t), true + } + return 0, false +} + +func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { + if ctx == nil { + return 0 + } + prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) + if !ok || prefetchedGroupID != derefGroupID(groupID) { return 0 } v := ctx.Value(ctxkey.PrefetchedStickyAccountID) @@ -1035,15 +1053,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro cfg := s.schedulingConfig() - var stickyAccountID int64 - if prefetch := prefetchedStickyAccountIDFromContext(ctx); prefetch > 0 { - stickyAccountID = prefetch - } else if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { - stickyAccountID = accountID - } - } - // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) if err != nil { @@ -1051,6 +1060,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + var stickyAccountID int64 + if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { + stickyAccountID = prefetch + } else if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil {