Merge PR #229: perf(网关): 粘性会话命中复用候选账号
This commit is contained in:
@@ -24,9 +24,11 @@ type mockAccountRepoForPlatform struct {
|
|||||||
accounts []Account
|
accounts []Account
|
||||||
accountsByID map[int64]*Account
|
accountsByID map[int64]*Account
|
||||||
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
|
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
|
||||||
|
getByIDCalls int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
|
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||||
|
m.getByIDCalls++
|
||||||
if acc, ok := m.accountsByID[id]; ok {
|
if acc, ok := m.accountsByID[id]; ok {
|
||||||
return acc, nil
|
return acc, nil
|
||||||
}
|
}
|
||||||
@@ -948,6 +950,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
|
|||||||
return m.accountWaitCounts[accountID], nil
|
return m.accountWaitCounts[accountID], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockConcurrencyCache struct {
|
||||||
|
acquireAccountCalls int
|
||||||
|
loadBatchCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
m.acquireAccountCalls++
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
m.loadBatchCalls++
|
||||||
|
result := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
result[acc.ID] = &AccountLoadInfo{
|
||||||
|
AccountID: acc.ID,
|
||||||
|
CurrentConcurrency: 0,
|
||||||
|
WaitingCount: 0,
|
||||||
|
LoadRate: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
|
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
|
||||||
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -1046,6 +1116,78 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
|||||||
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
|
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{
|
||||||
|
sessionBindings: map[string]int64{"sticky": 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||||
|
|
||||||
|
concurrencyCache := &mockConcurrencyCache{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, int64(1), result.Account.ID)
|
||||||
|
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
|
||||||
|
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForPlatform{
|
||||||
|
sessionBindings: map[string]int64{"sticky": 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||||
|
|
||||||
|
concurrencyCache := &mockConcurrencyCache{}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
|
||||||
|
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
|
||||||
|
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("无可用账号-返回错误", func(t *testing.T) {
|
t.Run("无可用账号-返回错误", func(t *testing.T) {
|
||||||
repo := &mockAccountRepoForPlatform{
|
repo := &mockAccountRepoForPlatform{
|
||||||
accounts: []Account{},
|
accounts: []Account{},
|
||||||
|
|||||||
@@ -465,8 +465,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
// 粘性命中仅在当前可调度候选集中生效。
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) &&
|
accountByID := make(map[int64]*Account, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
accountByID[accounts[i].ID] = &accounts[i]
|
||||||
|
}
|
||||||
|
account, ok := accountByID[accountID]
|
||||||
|
if ok && s.isAccountInGroup(account, groupID) &&
|
||||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||||
account.IsSchedulableForModel(requestedModel) &&
|
account.IsSchedulableForModel(requestedModel) &&
|
||||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
|
|||||||
Reference in New Issue
Block a user