Files
sub2api/backend/internal/service/billing_cache_service_singleflight_test.go
2026-02-28 15:01:20 +08:00

116 lines
2.7 KiB
Go

//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// billingCacheMissStub is a test stub whose balance lookups always fail with
// a "cache miss" error and whose balance writes are no-ops that report
// success. It counts SetUserBalance calls so a test can observe write-backs.
type billingCacheMissStub struct {
	// setBalanceCalls is incremented once per SetUserBalance invocation.
	setBalanceCalls atomic.Int64
}

// GetUserBalance never finds a balance: it returns 0 and a "cache miss" error.
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
	missErr := errors.New("cache miss")
	return 0, missErr
}

// SetUserBalance records that it was called and reports success; the balance
// itself is not stored anywhere.
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
	s.setBalanceCalls.Add(1)
	return nil
}

// DeductUserBalance does nothing and reports success.
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
	return nil
}

// InvalidateUserBalance does nothing and reports success.
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
	return nil
}
// GetSubscriptionCache never finds an entry: it returns a nil data pointer
// and a "cache miss" error.
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
	missErr := errors.New("cache miss")
	return nil, missErr
}

// SetSubscriptionCache discards the data and reports success.
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
	return nil
}

// UpdateSubscriptionUsage does nothing and reports success.
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
	return nil
}

// InvalidateSubscriptionCache does nothing and reports success.
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
	return nil
}
// balanceLoadUserRepoStub is a user repository stub for balance-loading
// tests. It embeds mockUserRepo (declared elsewhere in the package) and
// overrides GetByID to count calls, optionally stall, and return a fixed
// balance.
type balanceLoadUserRepoStub struct {
	mockUserRepo
	// calls is the number of GetByID invocations so far.
	calls atomic.Int64
	// delay, when positive, is how long GetByID waits before answering.
	delay time.Duration
	// balance is the Balance set on every returned User.
	balance float64
}

// GetByID counts the call and returns a User carrying the requested id and
// the configured balance. When delay is positive it first waits for that long;
// if ctx is done before the wait ends it returns ctx.Err() instead.
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
	s.calls.Add(1)

	// No artificial latency configured: answer immediately.
	if s.delay <= 0 {
		return &User{ID: id, Balance: s.balance}, nil
	}

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-time.After(s.delay):
		return &User{ID: id, Balance: s.balance}, nil
	}
}
// TestBillingCacheServiceGetUserBalance_Singleflight releases a batch of
// goroutines that all call GetUserBalance for the same user against a cache
// stub that always misses. It asserts that every caller gets the expected
// balance without error, that the user repository is queried exactly once,
// and that the balance is written to the cache at least once within a second.
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
	const (
		goroutines  = 16
		wantBalance = 12.34
	)

	missCache := &billingCacheMissStub{}
	// The repository delay is presumably there to keep the first load in
	// flight long enough for the remaining callers to pile up behind it.
	repo := &balanceLoadUserRepoStub{
		delay:   80 * time.Millisecond,
		balance: wantBalance,
	}
	svc := NewBillingCacheService(missCache, repo, nil, &config.Config{})
	t.Cleanup(svc.Stop)

	// Each worker writes only to its own slot; wg.Wait orders those writes
	// before the reads below.
	errs := make([]error, goroutines)
	balances := make([]float64, goroutines)
	gate := make(chan struct{})

	var wg sync.WaitGroup
	wg.Add(goroutines)
	for i := 0; i < goroutines; i++ {
		go func(slot int) {
			defer wg.Done()
			// Hold every worker at the gate so they start together.
			<-gate
			balances[slot], errs[slot] = svc.GetUserBalance(context.Background(), 99)
		}(i)
	}
	close(gate)
	wg.Wait()

	for _, err := range errs {
		require.NoError(t, err)
	}
	for _, bal := range balances {
		require.Equal(t, 12.34, bal)
	}

	// Failure message (Chinese): "concurrent cache penetration should be
	// merged by singleflight".
	require.Equal(t, int64(1), repo.calls.Load(), "并发穿透应被 singleflight 合并")

	// The cache write may land after GetUserBalance returns, so poll for it.
	cacheWritten := func() bool {
		return missCache.setBalanceCalls.Load() >= 1
	}
	require.Eventually(t, cacheWritten, time.Second, 10*time.Millisecond)
}