feat(sync): full code sync from release
This commit is contained in:
@@ -0,0 +1,115 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheMissStub struct {
|
||||
setBalanceCalls atomic.Int64
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
s.setBalanceCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceLoadUserRepoStub struct {
|
||||
mockUserRepo
|
||||
calls atomic.Int64
|
||||
delay time.Duration
|
||||
balance float64
|
||||
}
|
||||
|
||||
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
s.calls.Add(1)
|
||||
if s.delay > 0 {
|
||||
select {
|
||||
case <-time.After(s.delay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return &User{ID: id, Balance: s.balance}, nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
cache := &billingCacheMissStub{}
|
||||
userRepo := &balanceLoadUserRepoStub{
|
||||
delay: 80 * time.Millisecond,
|
||||
balance: 12.34,
|
||||
}
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
const goroutines = 16
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, goroutines)
|
||||
balCh := make(chan float64, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
bal, err := svc.GetUserBalance(context.Background(), 99)
|
||||
errCh <- err
|
||||
balCh <- bal
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
close(balCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for bal := range balCh {
|
||||
require.Equal(t, 12.34, bal)
|
||||
}
|
||||
|
||||
require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.setBalanceCalls.Load() >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
Reference in New Issue
Block a user