diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index b5f500a3..072dfe69 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -11,54 +11,89 @@ import ( ) const ( - accountConcurrencyKeyPrefix = "concurrency:account:" - userConcurrencyKeyPrefix = "concurrency:user:" - waitQueueKeyPrefix = "concurrency:wait:" - concurrencyTTL = 5 * time.Minute + // Key prefixes for independent slot keys + // Format: concurrency:account:{accountID}:{requestID} + accountSlotKeyPrefix = "concurrency:account:" + // Format: concurrency:user:{userID}:{requestID} + userSlotKeyPrefix = "concurrency:user:" + // Wait queue keeps counter format: concurrency:wait:{userID} + waitQueueKeyPrefix = "concurrency:wait:" + + // Slot TTL - each slot expires independently + slotTTL = 5 * time.Minute ) var ( + // acquireScript uses SCAN to count existing slots and creates new slot if under limit + // KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*") + // KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx") + // ARGV[1] = maxConcurrency + // ARGV[2] = TTL in seconds acquireScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current == false then - current = 0 - else - current = tonumber(current) - end - if current < tonumber(ARGV[1]) then - redis.call('INCR', KEYS[1]) - redis.call('EXPIRE', KEYS[1], ARGV[2]) + local pattern = KEYS[1] + local slotKey = KEYS[2] + local maxConcurrency = tonumber(ARGV[1]) + local ttl = tonumber(ARGV[2]) + + -- Count existing slots using SCAN + local cursor = "0" + local count = 0 + repeat + local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100) + cursor = result[1] + count = count + #result[2] + until cursor == "0" + + -- Check if we can acquire a slot + if count < maxConcurrency then + redis.call('SET', slotKey, '1', 'EX', ttl) return 1 end + return 0 `) - releaseScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current ~= false and tonumber(current) > 0 then - redis.call('DECR', KEYS[1]) - end - return 1 + // getCountScript counts slots using SCAN + // KEYS[1] = pattern for SCAN + getCountScript = redis.NewScript(` + local pattern = KEYS[1] + local cursor = "0" + local count = 0 + repeat + local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100) + cursor = result[1] + count = count + #result[2] + until cursor == "0" + return count `) + // incrementWaitScript - only sets TTL on first creation to avoid refreshing + // KEYS[1] = wait queue key + // ARGV[1] = maxWait + // ARGV[2] = TTL in seconds incrementWaitScript = redis.NewScript(` - local waitKey = KEYS[1] - local maxWait = tonumber(ARGV[1]) - local ttl = tonumber(ARGV[2]) - local current = redis.call('GET', waitKey) + local current = redis.call('GET', KEYS[1]) if current == false then current = 0 else current = tonumber(current) end - if current >= maxWait then + + if current >= tonumber(ARGV[1]) then return 0 end - redis.call('INCR', waitKey) - redis.call('EXPIRE', waitKey, ttl) + + local newVal = redis.call('INCR', KEYS[1]) + + -- Only set TTL on first creation to avoid refreshing zombie data + if newVal == 1 then + redis.call('EXPIRE', KEYS[1], ARGV[2]) + end + return 1 `) + // decrementWaitScript - same as before decrementWaitScript = redis.NewScript(` local current = redis.call('GET', KEYS[1]) if current ~= false and tonumber(current) > 0 then @@ -76,49 +111,86 @@ func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache { return &concurrencyCache{rdb: rdb} } -func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) { - key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) - result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int() +// Helper functions for key generation +func accountSlotKey(accountID int64, requestID string) string { + return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID) +} + +func accountSlotPattern(accountID int64) string { + return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID) +} + +func userSlotKey(userID int64, requestID string) string { + return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID) +} + +func userSlotPattern(userID int64) string { + return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID) +} + +func waitQueueKey(userID int64) string { + return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) +} + +// Account slot operations + +func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + pattern := accountSlotPattern(accountID) + slotKey := accountSlotKey(accountID, requestID) + + result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int() if err != nil { return false, err } return result == 1, nil } -func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error { - key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) - _, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result() - return err +func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + slotKey := accountSlotKey(accountID, requestID) + return c.rdb.Del(ctx, slotKey).Err() } func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { - key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) - return c.rdb.Get(ctx, key).Int() + pattern := accountSlotPattern(accountID) + result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int() + if err != nil { + return 0, err + } + return result, nil } -func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) { - key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) - result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int() +// User slot operations + +func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + pattern := userSlotPattern(userID) + slotKey := userSlotKey(userID, requestID) + + result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int() if err != nil { return false, err } return result == 1, nil } -func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error { - key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) - _, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result() - return err +func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + slotKey := userSlotKey(userID, requestID) + return c.rdb.Del(ctx, slotKey).Err() } func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { - key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) - return c.rdb.Get(ctx, key).Int() + pattern := userSlotPattern(userID) + result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int() + if err != nil { + return 0, err + } + return result, nil } +// Wait queue operations + func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { - key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) - result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int() + key := waitQueueKey(userID) + result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int() if err != nil { return false, err } @@ -126,7 +198,7 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, } func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { - key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + key := waitQueueKey(userID) _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() return err } diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index c554624c..cec2aab8 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -2,12 +2,26 @@ package service import ( "context" + "crypto/rand" + "encoding/hex" + "fmt" "log" "time" "sub2api/internal/service/ports" ) +// generateRequestID generates a unique request ID for concurrency slot tracking +// Uses 8 random bytes (16 hex chars) for uniqueness +func generateRequestID() string { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + // Fallback to nanosecond timestamp (extremely rare case) + return fmt.Sprintf("%x", time.Now().UnixNano()) + } + return hex.EncodeToString(b) +} + const ( // Default extra wait slots beyond concurrency limit defaultExtraWaitSlots = 20 @@ -41,7 +55,10 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i }, nil } - acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency) + // Generate unique request ID for this slot + requestID := generateRequestID() + + acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID) if err != nil { return nil, err } @@ -52,8 +69,8 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i ReleaseFunc: func() { bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil { - log.Printf("Warning: failed to release account slot for %d: %v", accountID, err) + if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil { + log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err) } }, }, nil @@ -77,7 +94,10 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, }, nil } - acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency) + // Generate unique request ID for this slot + requestID := generateRequestID() + + acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID) if err != nil { return nil, err } @@ -88,8 +108,8 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, ReleaseFunc: func() { bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil { - log.Printf("Warning: failed to release user slot for %d: %v", userID, err) + if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil { + log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err) } }, }, nil diff --git a/backend/internal/service/ports/concurrency_cache.go b/backend/internal/service/ports/concurrency_cache.go index 313737b7..2344fe62 100644 --- a/backend/internal/service/ports/concurrency_cache.go +++ b/backend/internal/service/ports/concurrency_cache.go @@ -3,17 +3,21 @@ package ports import "context" // ConcurrencyCache defines cache operations for concurrency service +// Uses independent keys per request slot with native Redis TTL for automatic cleanup type ConcurrencyCache interface { - // Slot management - AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) - ReleaseAccountSlot(ctx context.Context, accountID int64) error + // Account slot management - each slot is a separate key with independent TTL + // Key format: concurrency:account:{accountID}:{requestID} + AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) + ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) - AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) - ReleaseUserSlot(ctx context.Context, userID int64) error + // User slot management - each slot is a separate key with independent TTL + // Key format: concurrency:user:{userID}:{requestID} + AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) + ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error GetUserConcurrency(ctx context.Context, userID int64) (int, error) - // Wait queue + // Wait queue - uses counter with TTL set only on creation IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) DecrementWaitCount(ctx context.Context, userID int64) error }