//go:build unit package service import ( "context" "errors" "testing" "github.com/stretchr/testify/require" ) // stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 type stubConcurrencyCacheForTest struct { acquireResult bool acquireErr error releaseErr error concurrency int concurrencyErr error waitAllowed bool waitErr error waitCount int waitCountErr error loadBatch map[int64]*AccountLoadInfo loadBatchErr error usersLoadBatch map[int64]*UserLoadInfo usersLoadErr error cleanupErr error // 记录调用 releasedAccountIDs []int64 releasedRequestIDs []string } var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil) func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { return c.acquireResult, c.acquireErr } func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error { c.releasedAccountIDs = append(c.releasedAccountIDs, accountID) c.releasedRequestIDs = append(c.releasedRequestIDs, requestID) return c.releaseErr } func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { return c.concurrency, c.concurrencyErr } func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { return c.waitAllowed, c.waitErr } func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error { return nil } func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) { return c.waitCount, c.waitCountErr } func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { return c.acquireResult, c.acquireErr } func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error { return c.releaseErr } func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) { return c.concurrency, c.concurrencyErr } func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) { return c.waitAllowed, c.waitErr } func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error { return nil } func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { return c.loadBatch, c.loadBatchErr } func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { return c.usersLoadBatch, c.usersLoadErr } func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { return c.cleanupErr } func TestAcquireAccountSlot_Success(t *testing.T) { cache := &stubConcurrencyCacheForTest{acquireResult: true} svc := NewConcurrencyService(cache) result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) require.NoError(t, err) require.True(t, result.Acquired) require.NotNil(t, result.ReleaseFunc) } func TestAcquireAccountSlot_Failure(t *testing.T) { cache := &stubConcurrencyCacheForTest{acquireResult: false} svc := NewConcurrencyService(cache) result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) require.NoError(t, err) require.False(t, result.Acquired) require.Nil(t, result.ReleaseFunc) } func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) { svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) for _, maxConcurrency := range []int{0, -1} { result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency) require.NoError(t, err) require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency) require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数") } } func TestAcquireAccountSlot_CacheError(t *testing.T) { cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")} svc := NewConcurrencyService(cache) result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) require.Error(t, err) require.Nil(t, result) } func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) { cache := &stubConcurrencyCacheForTest{acquireResult: true} svc := NewConcurrencyService(cache) result, err := svc.AcquireAccountSlot(context.Background(), 42, 5) require.NoError(t, err) require.True(t, result.Acquired) // 调用 ReleaseFunc 应释放槽位 result.ReleaseFunc() require.Len(t, cache.releasedAccountIDs, 1) require.Equal(t, int64(42), cache.releasedAccountIDs[0]) require.Len(t, cache.releasedRequestIDs, 1) require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空") } func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) { cache := &stubConcurrencyCacheForTest{acquireResult: true} svc := NewConcurrencyService(cache) // 用户槽位获取应独立于账户槽位 result, err := svc.AcquireUserSlot(context.Background(), 100, 3) require.NoError(t, err) require.True(t, result.Acquired) require.NotNil(t, result.ReleaseFunc) } func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) result, err := svc.AcquireUserSlot(context.Background(), 1, 0) require.NoError(t, err) require.True(t, result.Acquired) } func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { expected := map[int64]*AccountLoadInfo{ 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, 2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100}, } cache := &stubConcurrencyCacheForTest{loadBatch: expected} svc := NewConcurrencyService(cache) accounts := []AccountWithConcurrency{ {ID: 1, MaxConcurrency: 5}, {ID: 2, MaxConcurrency: 5}, } result, err := svc.GetAccountsLoadBatch(context.Background(), accounts) require.NoError(t, err) require.Equal(t, expected, result) } func TestGetAccountsLoadBatch_NilCache(t *testing.T) { svc := &ConcurrencyService{cache: nil} result, err := svc.GetAccountsLoadBatch(context.Background(), nil) require.NoError(t, err) require.Empty(t, result) } func TestIncrementWaitCount_Success(t *testing.T) { cache := &stubConcurrencyCacheForTest{waitAllowed: true} svc := NewConcurrencyService(cache) allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) require.NoError(t, err) require.True(t, allowed) } func TestIncrementWaitCount_QueueFull(t *testing.T) { cache := &stubConcurrencyCacheForTest{waitAllowed: false} svc := NewConcurrencyService(cache) allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) require.NoError(t, err) require.False(t, allowed) } func TestIncrementWaitCount_FailOpen(t *testing.T) { // Redis 错误时应 fail-open(允许请求通过) cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")} svc := NewConcurrencyService(cache) allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) require.NoError(t, err, "Redis 错误不应传播") require.True(t, allowed, "Redis 错误时应 fail-open") } func TestIncrementWaitCount_NilCache(t *testing.T) { svc := &ConcurrencyService{cache: nil} allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) require.NoError(t, err) require.True(t, allowed, "nil cache 应 fail-open") } func TestCalculateMaxWait(t *testing.T) { tests := []struct { concurrency int expected int }{ {5, 25}, // 5 + 20 {1, 21}, // 1 + 20 {0, 21}, // min(1) + 20 {-1, 21}, // min(1) + 20 {10, 30}, // 10 + 20 } for _, tt := range tests { result := CalculateMaxWait(tt.concurrency) require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) } } func TestGetAccountWaitingCount(t *testing.T) { cache := &stubConcurrencyCacheForTest{waitCount: 5} svc := NewConcurrencyService(cache) count, err := svc.GetAccountWaitingCount(context.Background(), 1) require.NoError(t, err) require.Equal(t, 5, count) } func TestGetAccountWaitingCount_NilCache(t *testing.T) { svc := &ConcurrencyService{cache: nil} count, err := svc.GetAccountWaitingCount(context.Background(), 1) require.NoError(t, err) require.Equal(t, 0, count) } func TestGetAccountConcurrencyBatch(t *testing.T) { cache := &stubConcurrencyCacheForTest{concurrency: 3} svc := NewConcurrencyService(cache) result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3}) require.NoError(t, err) require.Len(t, result, 3) for _, id := range []int64{1, 2, 3} { require.Equal(t, 3, result[id]) } } func TestIncrementAccountWaitCount_FailOpen(t *testing.T) { cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")} svc := NewConcurrencyService(cache) allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) require.NoError(t, err, "Redis 错误不应传播") require.True(t, allowed, "Redis 错误时应 fail-open") } func TestIncrementAccountWaitCount_NilCache(t *testing.T) { svc := &ConcurrencyService{cache: nil} allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) require.NoError(t, err) require.True(t, allowed) }