fix(gateway): 修复粘性会话预取分组错配并优化并发等待热路径
This commit is contained in:
@@ -244,7 +244,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
if sessionBoundAccountID > 0 {
|
||||
prefetchedGroupID := int64(0)
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,14 +230,31 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false)
|
||||
}
|
||||
|
||||
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
acquireSlot := func() (*service.AcquireResult, error) {
|
||||
if slotType == "user" {
|
||||
return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
if tryImmediate {
|
||||
result, err := acquireSlot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
@@ -286,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
|
||||
case <-timer.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
result, err := acquireSlot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -310,7 +319,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
|
||||
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true)
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
|
||||
@@ -176,7 +176,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted)
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, release)
|
||||
require.False(t, streamStarted)
|
||||
@@ -188,7 +188,7 @@ func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted)
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, release)
|
||||
release()
|
||||
@@ -207,7 +207,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted)
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
@@ -218,7 +218,7 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
|
||||
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted)
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
@@ -236,12 +236,29 @@ func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted)
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "redis unavailable")
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
|
||||
release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
|
||||
}
|
||||
|
||||
type helperConcurrencyCacheStubWithError struct {
|
||||
helperConcurrencyCacheStub
|
||||
err error
|
||||
|
||||
@@ -264,7 +264,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
if sessionBoundAccountID > 0 {
|
||||
prefetchedGroupID := int64(0)
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user