//go:build unit package service import ( "context" "errors" "sync" "sync/atomic" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) type billingCacheMissStub struct { setBalanceCalls atomic.Int64 } func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) { return 0, errors.New("cache miss") } func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error { s.setBalanceCalls.Add(1) return nil } func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { return nil } func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error { return nil } func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) { return nil, errors.New("cache miss") } func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error { return nil } func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { return nil } func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { return nil } type balanceLoadUserRepoStub struct { mockUserRepo calls atomic.Int64 delay time.Duration balance float64 } func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { s.calls.Add(1) if s.delay > 0 { select { case <-time.After(s.delay): case <-ctx.Done(): return nil, ctx.Err() } } return &User{ID: id, Balance: s.balance}, nil } func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { cache := &billingCacheMissStub{} userRepo := &balanceLoadUserRepoStub{ delay: 80 * time.Millisecond, balance: 12.34, } svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{}) t.Cleanup(svc.Stop) const goroutines = 16 start := make(chan struct{}) var wg sync.WaitGroup errCh := make(chan error, goroutines) balCh := make(chan float64, goroutines) for i := 0; i < goroutines; i++ { wg.Add(1) go func() { defer wg.Done() <-start bal, err := svc.GetUserBalance(context.Background(), 99) errCh <- err balCh <- bal }() } close(start) wg.Wait() close(errCh) close(balCh) for err := range errCh { require.NoError(t, err) } for bal := range balCh { require.Equal(t, 12.34, bal) } require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并") require.Eventually(t, func() bool { return cache.setBalanceCalls.Load() >= 1 }, time.Second, 10*time.Millisecond) }