package service

import (
	"context"
	"crypto/rand"
	"encoding/binary"
	"os"
	"strconv"
	"sync/atomic"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)

// ConcurrencyCache defines the cache interface used for concurrency control.
// Slots are stored in sorted sets and expired entries are cleaned up by timestamp.
type ConcurrencyCache interface {
	// Account slot management.
	// Key format: concurrency:account:{accountID} (sorted set whose members are request IDs).
	AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
	ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
	GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
	GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)

	// Account wait queue (account level).
	IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
	DecrementAccountWaitCount(ctx context.Context, accountID int64) error
	GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)

	// User slot management.
	// Key format: concurrency:user:{userID} (sorted set whose members are request IDs).
	AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
	ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
	GetUserConcurrency(ctx context.Context, userID int64) (int, error)

	// Wait queue counter (the TTL is set only when the counter is first created).
	IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
	DecrementWaitCount(ctx context.Context, userID int64) error

	// Batch load queries (read-only).
	GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
	GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)

	// Cleanup of expired slots (background task).
	CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}

var (
	// requestIDPrefix is computed once at package initialization and shared by
	// every request ID generated in this process.
	requestIDPrefix = initRequestIDPrefix()
	// requestIDCounter is the atomically incremented sequence appended to the prefix.
	requestIDCounter atomic.Uint64
)

// initRequestIDPrefix builds the per-process request ID prefix: the letter "r"
// followed by 8 bytes from crypto/rand rendered in base 36. If reading random
// bytes fails, it falls back to the current UnixNano timestamp XORed with the
// process ID shifted left by 16 bits.
func initRequestIDPrefix() string {
	b := make([]byte, 8)
	if _, err := rand.Read(b); err == nil {
		return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
	}
	fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
	return "r" + strconv.FormatUint(fallback, 36)
}

// generateRequestID generates a unique request ID for concurrency slot tracking.
// Format: {process_random_prefix}-{base36_counter}
func generateRequestID() string {
	seq := requestIDCounter.Add(1)
	return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
}

const (
	// Default extra wait slots beyond concurrency limit
	defaultExtraWaitSlots = 20
)

// ConcurrencyService manages concurrent request limiting for accounts and users
type ConcurrencyService struct {
	cache ConcurrencyCache
}

// NewConcurrencyService creates a new ConcurrencyService
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
	return &ConcurrencyService{cache: cache}
}

// AcquireResult represents the result of acquiring a concurrency slot
type AcquireResult struct {
	Acquired    bool
	ReleaseFunc func() // Must be called when done (typically via defer)
}

// AccountWithConcurrency is the per-account input item of GetAccountsLoadBatch.
type AccountWithConcurrency struct {
	ID             int64
	MaxConcurrency int
}

// UserWithConcurrency is the per-user input item of GetUsersLoadBatch.
type UserWithConcurrency struct {
	ID             int64
	MaxConcurrency int
}

// AccountLoadInfo is the per-account result item of GetAccountsLoadBatch.
type AccountLoadInfo struct {
	AccountID          int64
	CurrentConcurrency int
	WaitingCount       int
	LoadRate           int // 0-100+ (percent)
}

// UserLoadInfo is the per-user result item of GetUsersLoadBatch.
type UserLoadInfo struct {
	UserID             int64
	CurrentConcurrency int
	WaitingCount       int
	LoadRate           int // 0-100+ (percent)
}

// AcquireAccountSlot makes a single attempt to acquire a concurrency slot for an account.
// It does not wait: when no slot is available the result has Acquired == false and a
// nil ReleaseFunc. A maxConcurrency of 0 or less means no limit.
// When Acquired is true, ReleaseFunc MUST be called when the request completes.
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { // If maxConcurrency is 0 or negative, no limit if maxConcurrency <= 0 { return &AcquireResult{ Acquired: true, ReleaseFunc: func() {}, // no-op }, nil } // Generate unique request ID for this slot requestID := generateRequestID() acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID) if err != nil { return nil, err } if acquired { return &AcquireResult{ Acquired: true, ReleaseFunc: func() { bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil { logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err) } }, }, nil } return &AcquireResult{ Acquired: false, ReleaseFunc: nil, }, nil } // AcquireUserSlot attempts to acquire a concurrency slot for a user. // If the user is at max concurrency, it waits until a slot is available or timeout. // Returns a release function that MUST be called when the request completes. 
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) { // If maxConcurrency is 0 or negative, no limit if maxConcurrency <= 0 { return &AcquireResult{ Acquired: true, ReleaseFunc: func() {}, // no-op }, nil } // Generate unique request ID for this slot requestID := generateRequestID() acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID) if err != nil { return nil, err } if acquired { return &AcquireResult{ Acquired: true, ReleaseFunc: func() { bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil { logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err) } }, }, nil } return &AcquireResult{ Acquired: false, ReleaseFunc: nil, }, nil } // ============================================ // Wait Queue Count Methods // ============================================ // IncrementWaitCount attempts to increment the wait queue counter for a user. // Returns true if successful, false if the wait queue is full. // maxWait should be user.Concurrency + defaultExtraWaitSlots func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { if s.cache == nil { // Redis not available, allow request return true, nil } result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait) if err != nil { // On error, allow the request to proceed (fail open) logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err) return true, nil } return result, nil } // DecrementWaitCount decrements the wait queue counter for a user. // Should be called when a request completes or exits the wait queue. 
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) { if s.cache == nil { return } // Use background context to ensure decrement even if original context is cancelled bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil { logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err) } } // IncrementAccountWaitCount increments the wait queue counter for an account. func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { if s.cache == nil { return true, nil } result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait) if err != nil { logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err) return true, nil } return result, nil } // DecrementAccountWaitCount decrements the wait queue counter for an account. func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) { if s.cache == nil { return } bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil { logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err) } } // GetAccountWaitingCount gets current wait queue count for an account. 
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { if s.cache == nil { return 0, nil } return s.cache.GetAccountWaitingCount(ctx, accountID) } // CalculateMaxWait calculates the maximum wait queue size for a user // maxWait = userConcurrency + defaultExtraWaitSlots func CalculateMaxWait(userConcurrency int) int { if userConcurrency <= 0 { userConcurrency = 1 } return userConcurrency + defaultExtraWaitSlots } // GetAccountsLoadBatch returns load info for multiple accounts. func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { if s.cache == nil { return map[int64]*AccountLoadInfo{}, nil } return s.cache.GetAccountsLoadBatch(ctx, accounts) } // GetUsersLoadBatch returns load info for multiple users. func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { if s.cache == nil { return map[int64]*UserLoadInfo{}, nil } return s.cache.GetUsersLoadBatch(ctx, users) } // CleanupExpiredAccountSlots removes expired slots for one account (background task). func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { if s.cache == nil { return nil } return s.cache.CleanupExpiredAccountSlots(ctx, accountID) } // StartSlotCleanupWorker starts a background cleanup worker for expired account slots. 
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) { if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 { return } runCleanup := func() { listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) accounts, err := accountRepo.ListSchedulable(listCtx) cancel() if err != nil { logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err) return } for _, account := range accounts { accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second) err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID) accountCancel() if err != nil { logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err) } } } go func() { ticker := time.NewTicker(interval) defer ticker.Stop() runCleanup() for range ticker.C { runCleanup() } }() } // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { if len(accountIDs) == 0 { return map[int64]int{}, nil } if s.cache == nil { result := make(map[int64]int, len(accountIDs)) for _, accountID := range accountIDs { result[accountID] = 0 } return result, nil } return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs) }