主要改进:
- 优化负载感知调度的准确性和响应速度
- 将 AccountUsageService 的包级缓存改为依赖注入
- 修复 SSE/JSON 转义和 nil 安全问题
- 恢复 Google One 功能兼容性
396 lines
12 KiB
Go
396 lines
12 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"strconv"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
// Redis key layout for concurrency control.
//
// Design note: an earlier implementation stored one key per slot
// (concurrency:account:{id}:{requestID}) and counted them with SCAN, which
// needs several round trips under high concurrency and slows down as the
// number of keys grows. Slots now live in one sorted set per account/user:
//  1. member = requestID, score = acquire/refresh time (Redis server seconds);
//  2. ZCARD returns the live slot count in O(1);
//  3. ZREMRANGEBYSCORE evicts stale slots, so individual slots need no keys or
//     TTLs of their own (the set as a whole still receives an EXPIRE);
//  4. acquiring or counting is a single script call, i.e. one round trip.
const (
	// accountSlotKeyPrefix prefixes the per-account slot sorted set:
	// concurrency:account:{accountID}. getAccountsLoadBatchScript spells this
	// prefix out literally; keep the two in sync.
	accountSlotKeyPrefix = "concurrency:account:"

	// userSlotKeyPrefix prefixes the per-user slot sorted set:
	// concurrency:user:{userID}
	userSlotKeyPrefix = "concurrency:user:"

	// waitQueueKeyPrefix prefixes the per-user wait counter:
	// concurrency:wait:{userID}
	waitQueueKeyPrefix = "concurrency:wait:"

	// accountWaitKeyPrefix prefixes the per-account wait counter:
	// wait:account:{accountID}. getAccountsLoadBatchScript spells this prefix
	// out literally; keep the two in sync.
	accountWaitKeyPrefix = "wait:account:"
)

// defaultSlotTTLMinutes is the slot lifetime (minutes) NewConcurrencyCache
// falls back to when given a non-positive slotTTLMinutes.
const defaultSlotTTLMinutes = 15
|
||
|
||
var (
|
||
// acquireScript 使用有序集合计数并在未达上限时添加槽位
|
||
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
|
||
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
|
||
// ARGV[1] = maxConcurrency
|
||
// ARGV[2] = TTL(秒)
|
||
// ARGV[3] = requestID
|
||
acquireScript = redis.NewScript(`
|
||
local key = KEYS[1]
|
||
local maxConcurrency = tonumber(ARGV[1])
|
||
local ttl = tonumber(ARGV[2])
|
||
local requestID = ARGV[3]
|
||
|
||
-- 使用 Redis 服务器时间,确保多实例时钟一致
|
||
local timeResult = redis.call('TIME')
|
||
local now = tonumber(timeResult[1])
|
||
local expireBefore = now - ttl
|
||
|
||
-- 清理过期槽位
|
||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||
|
||
-- 检查是否已存在(支持重试场景刷新时间戳)
|
||
local exists = redis.call('ZSCORE', key, requestID)
|
||
if exists ~= false then
|
||
redis.call('ZADD', key, now, requestID)
|
||
redis.call('EXPIRE', key, ttl)
|
||
return 1
|
||
end
|
||
|
||
-- 检查是否达到并发上限
|
||
local count = redis.call('ZCARD', key)
|
||
if count < maxConcurrency then
|
||
redis.call('ZADD', key, now, requestID)
|
||
redis.call('EXPIRE', key, ttl)
|
||
return 1
|
||
end
|
||
|
||
return 0
|
||
`)
|
||
|
||
// getCountScript 统计有序集合中的槽位数量并清理过期条目
|
||
// 使用 Redis TIME 命令获取服务器时间
|
||
// KEYS[1] = 有序集合键
|
||
// ARGV[1] = TTL(秒)
|
||
getCountScript = redis.NewScript(`
|
||
local key = KEYS[1]
|
||
local ttl = tonumber(ARGV[1])
|
||
|
||
-- 使用 Redis 服务器时间
|
||
local timeResult = redis.call('TIME')
|
||
local now = tonumber(timeResult[1])
|
||
local expireBefore = now - ttl
|
||
|
||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||
return redis.call('ZCARD', key)
|
||
`)
|
||
|
||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||
// KEYS[1] = wait queue key
|
||
// ARGV[1] = maxWait
|
||
// ARGV[2] = TTL in seconds
|
||
incrementWaitScript = redis.NewScript(`
|
||
local current = redis.call('GET', KEYS[1])
|
||
if current == false then
|
||
current = 0
|
||
else
|
||
current = tonumber(current)
|
||
end
|
||
|
||
if current >= tonumber(ARGV[1]) then
|
||
return 0
|
||
end
|
||
|
||
local newVal = redis.call('INCR', KEYS[1])
|
||
|
||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||
if newVal == 1 then
|
||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||
end
|
||
|
||
return 1
|
||
`)
|
||
|
||
// incrementAccountWaitScript - account-level wait queue count
|
||
incrementAccountWaitScript = redis.NewScript(`
|
||
local current = redis.call('GET', KEYS[1])
|
||
if current == false then
|
||
current = 0
|
||
else
|
||
current = tonumber(current)
|
||
end
|
||
|
||
if current >= tonumber(ARGV[1]) then
|
||
return 0
|
||
end
|
||
|
||
local newVal = redis.call('INCR', KEYS[1])
|
||
|
||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||
if newVal == 1 then
|
||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||
end
|
||
|
||
return 1
|
||
`)
|
||
|
||
// decrementWaitScript - same as before
|
||
decrementWaitScript = redis.NewScript(`
|
||
local current = redis.call('GET', KEYS[1])
|
||
if current ~= false and tonumber(current) > 0 then
|
||
redis.call('DECR', KEYS[1])
|
||
end
|
||
return 1
|
||
`)
|
||
|
||
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
|
||
// ARGV[1] = slot TTL (seconds)
|
||
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
||
getAccountsLoadBatchScript = redis.NewScript(`
|
||
local result = {}
|
||
local slotTTL = tonumber(ARGV[1])
|
||
|
||
-- Get current server time
|
||
local timeResult = redis.call('TIME')
|
||
local nowSeconds = tonumber(timeResult[1])
|
||
local cutoffTime = nowSeconds - slotTTL
|
||
|
||
local i = 2
|
||
while i <= #ARGV do
|
||
local accountID = ARGV[i]
|
||
local maxConcurrency = tonumber(ARGV[i + 1])
|
||
|
||
local slotKey = 'concurrency:account:' .. accountID
|
||
|
||
-- Clean up expired slots before counting
|
||
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||
|
||
local waitKey = 'wait:account:' .. accountID
|
||
local waitingCount = redis.call('GET', waitKey)
|
||
if waitingCount == false then
|
||
waitingCount = 0
|
||
else
|
||
waitingCount = tonumber(waitingCount)
|
||
end
|
||
|
||
local loadRate = 0
|
||
if maxConcurrency > 0 then
|
||
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
|
||
end
|
||
|
||
table.insert(result, accountID)
|
||
table.insert(result, currentConcurrency)
|
||
table.insert(result, waitingCount)
|
||
table.insert(result, loadRate)
|
||
|
||
i = i + 2
|
||
end
|
||
|
||
return result
|
||
`)
|
||
|
||
// cleanupExpiredSlotsScript - remove expired slots
|
||
// KEYS[1] = concurrency:account:{accountID}
|
||
// ARGV[1] = TTL (seconds)
|
||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||
local key = KEYS[1]
|
||
local ttl = tonumber(ARGV[1])
|
||
local timeResult = redis.call('TIME')
|
||
local now = tonumber(timeResult[1])
|
||
local expireBefore = now - ttl
|
||
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||
`)
|
||
)
|
||
|
||
type concurrencyCache struct {
|
||
rdb *redis.Client
|
||
slotTTLSeconds int // 槽位过期时间(秒)
|
||
waitQueueTTLSeconds int // 等待队列过期时间(秒)
|
||
}
|
||
|
||
// NewConcurrencyCache 创建并发控制缓存
|
||
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
||
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
|
||
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
|
||
if slotTTLMinutes <= 0 {
|
||
slotTTLMinutes = defaultSlotTTLMinutes
|
||
}
|
||
if waitQueueTTLSeconds <= 0 {
|
||
waitQueueTTLSeconds = slotTTLMinutes * 60
|
||
}
|
||
return &concurrencyCache{
|
||
rdb: rdb,
|
||
slotTTLSeconds: slotTTLMinutes * 60,
|
||
waitQueueTTLSeconds: waitQueueTTLSeconds,
|
||
}
|
||
}
|
||
|
||
// Helper functions for key generation
|
||
func accountSlotKey(accountID int64) string {
|
||
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||
}
|
||
|
||
func userSlotKey(userID int64) string {
|
||
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||
}
|
||
|
||
func waitQueueKey(userID int64) string {
|
||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||
}
|
||
|
||
func accountWaitKey(accountID int64) string {
|
||
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||
}
|
||
|
||
// Account slot operations
|
||
|
||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||
key := accountSlotKey(accountID)
|
||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return result == 1, nil
|
||
}
|
||
|
||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||
key := accountSlotKey(accountID)
|
||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||
}
|
||
|
||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||
key := accountSlotKey(accountID)
|
||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// User slot operations
|
||
|
||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||
key := userSlotKey(userID)
|
||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return result == 1, nil
|
||
}
|
||
|
||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||
key := userSlotKey(userID)
|
||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||
}
|
||
|
||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||
key := userSlotKey(userID)
|
||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// Wait queue operations
|
||
|
||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||
key := waitQueueKey(userID)
|
||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return result == 1, nil
|
||
}
|
||
|
||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||
key := waitQueueKey(userID)
|
||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||
return err
|
||
}
|
||
|
||
// Account wait queue operations
|
||
|
||
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||
key := accountWaitKey(accountID)
|
||
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return result == 1, nil
|
||
}
|
||
|
||
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||
key := accountWaitKey(accountID)
|
||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||
return err
|
||
}
|
||
|
||
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||
key := accountWaitKey(accountID)
|
||
val, err := c.rdb.Get(ctx, key).Int()
|
||
if err != nil && !errors.Is(err, redis.Nil) {
|
||
return 0, err
|
||
}
|
||
if errors.Is(err, redis.Nil) {
|
||
return 0, nil
|
||
}
|
||
return val, nil
|
||
}
|
||
|
||
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||
if len(accounts) == 0 {
|
||
return map[int64]*service.AccountLoadInfo{}, nil
|
||
}
|
||
|
||
args := []any{c.slotTTLSeconds}
|
||
for _, acc := range accounts {
|
||
args = append(args, acc.ID, acc.MaxConcurrency)
|
||
}
|
||
|
||
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
loadMap := make(map[int64]*service.AccountLoadInfo)
|
||
for i := 0; i < len(result); i += 4 {
|
||
if i+3 >= len(result) {
|
||
break
|
||
}
|
||
|
||
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
||
|
||
loadMap[accountID] = &service.AccountLoadInfo{
|
||
AccountID: accountID,
|
||
CurrentConcurrency: currentConcurrency,
|
||
WaitingCount: waitingCount,
|
||
LoadRate: loadRate,
|
||
}
|
||
}
|
||
|
||
return loadMap, nil
|
||
}
|
||
|
||
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||
key := accountSlotKey(accountID)
|
||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||
return err
|
||
}
|