refactor(backend): 添加 service 缓存端口
This commit is contained in:
@@ -2,22 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefixes
|
||||
accountConcurrencyKey = "concurrency:account:"
|
||||
userConcurrencyKey = "concurrency:user:"
|
||||
userWaitCountKey = "concurrency:wait:"
|
||||
|
||||
// TTL for concurrency keys (auto-release safety net)
|
||||
concurrencyKeyTTL = 10 * time.Minute
|
||||
|
||||
// Wait polling interval
|
||||
waitPollInterval = 100 * time.Millisecond
|
||||
|
||||
@@ -28,70 +19,14 @@ const (
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
|
||||
// Pre-compiled Lua scripts for better performance
|
||||
var (
|
||||
// acquireScript: increment counter if below max, return 1 if successful
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
// releaseScript: decrement counter, but don't go below 0
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementWaitScript: increment wait counter if below max, return 1 if successful
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript: decrement wait counter, but don't go below 0
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
type ConcurrencyService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.ConcurrencyCache
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
|
||||
return &ConcurrencyService{rdb: rdb}
|
||||
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
@@ -104,20 +39,6 @@ type AcquireResult struct {
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// acquireSlot is the core implementation for acquiring a concurrency slot
|
||||
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
@@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Try to acquire immediately
|
||||
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
|
||||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: s.makeReleaseFunc(key),
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Not acquired, return with Acquired=false
|
||||
// The caller (gateway handler) will handle waiting with ping support
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tryAcquire attempts to increment the counter if below max
|
||||
// Uses pre-compiled Lua script for atomicity and performance
|
||||
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
|
||||
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("acquire slot failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// makeReleaseFunc creates a function to release a concurrency slot
|
||||
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
|
||||
return func() {
|
||||
// Use background context to ensure release even if original context is cancelled
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
|
||||
// Log error but don't panic - TTL will eventually clean up
|
||||
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
|
||||
}
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetCurrentCount returns the current concurrency count for debugging/monitoring
|
||||
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetAccountCurrentCount returns current concurrency count for an account
|
||||
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// GetUserCurrentCount returns current concurrency count for a user
|
||||
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
@@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int
|
||||
// Returns true if successful, false if the wait queue is full.
|
||||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
// Redis not available, allow request
|
||||
return true, nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if err != nil {
|
||||
// On error, allow the request to proceed (fail open)
|
||||
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result == 1, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||||
// Should be called when a request completes or exits the wait queue.
|
||||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
// Use background context to ensure decrement even if original context is cancelled
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
|
||||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserWaitCount returns current wait queue count for a user
|
||||
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
|
||||
Reference in New Issue
Block a user