fix(concurrency): 重构并发管理使用独立Key+原生TTL

问题:旧方案使用计数器模式,每次acquire都刷新TTL,导致僵尸数据永不过期

解决方案:
- 每个槽位使用独立Redis Key: concurrency:account:{id}:{requestID}
- 利用Redis原生TTL,每个槽位独立5分钟过期
- 服务崩溃后僵尸数据自动清理,无需手动干预
- 兼容多实例K8s部署

技术改动:
- 新增SCAN脚本统计活跃槽位数量
- 移除冗余的releaseScript,直接使用DEL命令
- Wait队列TTL只在首次创建时设置,避免刷新
This commit is contained in:
shaw
2025-12-24 21:00:29 +08:00
parent aaadd6ed04
commit e65e9587b4
3 changed files with 155 additions and 59 deletions

View File

@@ -11,54 +11,89 @@ import (
) )
const ( const (
accountConcurrencyKeyPrefix = "concurrency:account:" // Key prefixes for independent slot keys
userConcurrencyKeyPrefix = "concurrency:user:" // Format: concurrency:account:{accountID}:{requestID}
waitQueueKeyPrefix = "concurrency:wait:" accountSlotKeyPrefix = "concurrency:account:"
concurrencyTTL = 5 * time.Minute // Format: concurrency:user:{userID}:{requestID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Slot TTL - each slot expires independently
slotTTL = 5 * time.Minute
) )
var ( var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds
acquireScript = redis.NewScript(` acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local pattern = KEYS[1]
if current == false then local slotKey = KEYS[2]
current = 0 local maxConcurrency = tonumber(ARGV[1])
else local ttl = tonumber(ARGV[2])
current = tonumber(current)
end -- Count existing slots using SCAN
if current < tonumber(ARGV[1]) then local cursor = "0"
redis.call('INCR', KEYS[1]) local count = 0
redis.call('EXPIRE', KEYS[1], ARGV[2]) repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- Check if we can acquire a slot
if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl)
return 1 return 1
end end
return 0 return 0
`) `)
releaseScript = redis.NewScript(` // getCountScript counts slots using SCAN
local current = redis.call('GET', KEYS[1]) // KEYS[1] = pattern for SCAN
if current ~= false and tonumber(current) > 0 then getCountScript = redis.NewScript(`
redis.call('DECR', KEYS[1]) local pattern = KEYS[1]
end local cursor = "0"
return 1 local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
return count
`) `)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(` incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1] local current = redis.call('GET', KEYS[1])
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then if current == false then
current = 0 current = 0
else else
current = tonumber(current) current = tonumber(current)
end end
if current >= maxWait then
if current >= tonumber(ARGV[1]) then
return 0 return 0
end end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl) local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1 return 1
`) `)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(` decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then if current ~= false and tonumber(current) > 0 then
@@ -76,49 +111,86 @@ func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
return &concurrencyCache{rdb: rdb} return &concurrencyCache{rdb: rdb}
} }
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) { // Helper functions for key generation
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) func accountSlotKey(accountID int64, requestID string) string {
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int() return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
}
func accountSlotPattern(accountID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID)
slotKey := accountSlotKey(accountID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
return result == 1, nil return result == 1, nil
} }
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error { func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) slotKey := accountSlotKey(accountID, requestID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result() return c.rdb.Del(ctx, slotKey).Err()
return err
} }
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) pattern := accountSlotPattern(accountID)
return c.rdb.Get(ctx, key).Int() result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
} }
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) { // User slot operations
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int() func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID)
slotKey := userSlotKey(userID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
return result == 1, nil return result == 1, nil
} }
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error { func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) slotKey := userSlotKey(userID, requestID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result() return c.rdb.Del(ctx, slotKey).Err()
return err
} }
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) pattern := userSlotPattern(userID)
return c.rdb.Get(ctx, key).Int() result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
} }
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int() result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -126,7 +198,7 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
} }
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err return err
} }

View File

@@ -2,12 +2,26 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt"
"log" "log"
"time" "time"
"sub2api/internal/service/ports" "sub2api/internal/service/ports"
) )
// generateRequestID generates a unique request ID for concurrency slot tracking
// Uses 8 random bytes (16 hex chars) for uniqueness
func generateRequestID() string {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
// Fallback to nanosecond timestamp (extremely rare case)
return fmt.Sprintf("%x", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}
const ( const (
// Default extra wait slots beyond concurrency limit // Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20 defaultExtraWaitSlots = 20
@@ -41,7 +55,10 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
}, nil }, nil
} }
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency) // Generate unique request ID for this slot
requestID := generateRequestID()
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -52,8 +69,8 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
ReleaseFunc: func() { ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil { if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err) log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
} }
}, },
}, nil }, nil
@@ -77,7 +94,10 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
}, nil }, nil
} }
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency) // Generate unique request ID for this slot
requestID := generateRequestID()
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -88,8 +108,8 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
ReleaseFunc: func() { ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil { if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
log.Printf("Warning: failed to release user slot for %d: %v", userID, err) log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
} }
}, },
}, nil }, nil

View File

@@ -3,17 +3,21 @@ package ports
import "context" import "context"
// ConcurrencyCache defines cache operations for concurrency service // ConcurrencyCache defines cache operations for concurrency service
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
type ConcurrencyCache interface { type ConcurrencyCache interface {
// Slot management // Account slot management - each slot is a separate key with independent TTL
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) // Key format: concurrency:account:{accountID}:{requestID}
ReleaseAccountSlot(ctx context.Context, accountID int64) error AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) // User slot management - each slot is a separate key with independent TTL
ReleaseUserSlot(ctx context.Context, userID int64) error // Key format: concurrency:user:{userID}:{requestID}
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error) GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue // Wait queue - uses counter with TTL set only on creation
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
} }