perf(gateway): optimize load-aware scheduling

Key improvements:
- Improve the accuracy and responsiveness of load-aware scheduling
- Switch AccountUsageService from a package-level cache to dependency injection
- Fix SSE/JSON escaping and nil-safety issues
- Restore Google One feature compatibility
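As context for the AccountUsageService bullet, the sketch below shows the usual shape of moving from a package-level cache to constructor injection. Every name except AccountUsageService is hypothetical, since that file is not part of the hunks shown below.

```go
// Hypothetical illustration only: apart from AccountUsageService (named in the
// commit message), none of these types come from this diff.
package service

type Usage struct {
	Requests int64
}

// UsageCache is an assumed abstraction over whatever cache the service uses.
type UsageCache interface {
	Get(accountID int64) (Usage, bool)
	Set(accountID int64, u Usage)
}

// Before: a package-level `var usageCache = ...` made the dependency implicit
// and hard to swap in tests. After: the cache is injected via the constructor.
type AccountUsageService struct {
	cache UsageCache
}

func NewAccountUsageService(c UsageCache) *AccountUsageService {
	return &AccountUsageService{cache: c}
}
```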
@@ -27,14 +27,8 @@ const (
accountSlotKeyPrefix = "concurrency:account:"
// Format: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"

// Wait queue keys (global structures)
// - total: integer total queue depth across all users
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
// - counts: hash of userID -> current wait count
waitQueueTotalKey = "concurrency:wait:total"
waitQueueUpdatedKey = "concurrency:wait:updated"
waitQueueCountsKey = "concurrency:wait:counts"
// Wait queue counter key format: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Account-level wait queue counter key format: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"

@@ -100,55 +94,27 @@ var (
`)

// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = userID
// ARGV[2] = maxWait
// ARGV[3] = TTL in seconds
// ARGV[4] = cleanup limit
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]

local userID = ARGV[1]
local maxWait = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local cleanupLimit = tonumber(ARGV[4])

redis.call('SETNX', totalKey, 0)

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl

-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end

local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current >= maxWait then
if current >= tonumber(ARGV[1]) then
return 0
end

local newVal = current + 1
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
redis.call('INCR', totalKey)
local newVal = redis.call('INCR', KEYS[1])

-- Keep global structures from living forever in totally idle deployments.
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end

return 1
`)
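For readers unfamiliar with how these embedded scripts execute, the sketch below shows the go-redis pattern this file relies on: redis.NewScript wraps the Lua source and Run evaluates it with the given KEYS and ARGV. The key name, limit, and client address are illustrative only; the toy script simply mirrors the bounded check-and-increment shape of incrementWaitScript.

```go
package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

// Toy bounded counter: return 0 when the counter already sits at ARGV[1],
// otherwise INCR it and return 1 (same shape as incrementWaitScript above).
var boundedIncr = redis.NewScript(`
local current = tonumber(redis.call('GET', KEYS[1]) or '0')
if current >= tonumber(ARGV[1]) then
return 0
end
redis.call('INCR', KEYS[1])
return 1
`)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // illustrative address

	admitted, err := boundedIncr.Run(ctx, rdb, []string{"demo:wait:42"}, 3).Int()
	if err != nil {
		panic(err)
	}
	fmt.Println("admitted:", admitted == 1) // 1 => below the limit, counter incremented
}
```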
@@ -178,111 +144,6 @@ var (

// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]

local userID = ARGV[1]
local ttl = tonumber(ARGV[2])
local cleanupLimit = tonumber(ARGV[3])

redis.call('SETNX', totalKey, 0)

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl

-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end

local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current <= 0 then
return 1
end

local newVal = current - 1
if newVal <= 0 then
redis.call('HDEL', countsKey, userID)
redis.call('ZREM', updatedKey, userID)
else
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
end
redis.call('DECR', totalKey)

local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)

return 1
`)

// getTotalWaitScript returns the global wait depth with TTL cleanup.
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = TTL in seconds
// ARGV[2] = cleanup limit
getTotalWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]

local ttl = tonumber(ARGV[1])
local cleanupLimit = tonumber(ARGV[2])

redis.call('SETNX', totalKey, 0)

local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl

-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end

-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
local total = redis.call('GET', totalKey)
if total == false then
total = 0
local vals = redis.call('HVALS', countsKey)
for _, v in ipairs(vals) do
total = total + tonumber(v)
end
redis.call('SET', totalKey, total)
end

local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)

local result = tonumber(redis.call('GET', totalKey) or '0')
if result < 0 then
result = 0
redis.call('SET', totalKey, 0)
end
return result
`)

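One plausible consumer of getTotalWaitScript is an early load-shedding check in the scheduler; the sketch below is an assumption about how the global wait depth could be used, not code from this commit (the interface name and ceiling are illustrative).

```go
package scheduler

import "context"

// totalWaitReader is assumed to be satisfied by *concurrencyCache through GetTotalWaitCount.
type totalWaitReader interface {
	GetTotalWaitCount(ctx context.Context) (int, error)
}

// maybeShedLoad sketches one way load-aware scheduling could consume the global
// wait depth: shed new work once the queue across all users exceeds a ceiling.
// The ceiling and the shed-versus-deprioritize policy are illustrative assumptions.
func maybeShedLoad(ctx context.Context, r totalWaitReader, maxGlobalWait int) (bool, error) {
	total, err := r.GetTotalWaitCount(ctx)
	if err != nil {
		return false, err // on Redis errors, let the caller decide; do not silently shed
	}
	return total >= maxGlobalWait, nil
}
```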
// decrementAccountWaitScript - account-level wait queue decrement
decrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
@@ -383,9 +244,7 @@ func userSlotKey(userID int64) string {
}

func waitQueueKey(userID int64) string {
// Historical: per-user string keys were used.
// Now we use global structures keyed by userID string.
return strconv.FormatInt(userID, 10)
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}

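To make the two key schemes in this hunk concrete: under the global-structures scheme waitQueueKey returns the bare userID, used as a field of the shared counts hash, while the per-user scheme returns a standalone counter key. A tiny illustrative snippet:

```go
package main

import (
	"fmt"
	"strconv"
)

const waitQueueKeyPrefix = "concurrency:wait:"

func main() {
	userID := int64(42)
	// Global-structures scheme: the bare userID string is a field inside the
	// shared concurrency:wait:counts hash (and a member of the updated zset).
	fmt.Println(strconv.FormatInt(userID, 10)) // "42"
	// Per-user scheme: each user gets its own string counter key.
	fmt.Printf("%s%d\n", waitQueueKeyPrefix, userID) // "concurrency:wait:42"
}
```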
func accountWaitKey(accountID int64) string {
@@ -449,16 +308,8 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
// Wait queue operations

func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
userKey := waitQueueKey(userID)
result, err := incrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
maxWait,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Int()
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
@@ -466,35 +317,11 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
}

func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
userKey := waitQueueKey(userID)
_, err := decrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Result()
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
if c.rdb == nil {
return 0, nil
}
total, err := getTotalWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
c.waitQueueTTLSeconds,
500, // cleanup limit per query (rare)
).Int64()
if err != nil {
return 0, err
}
return int(total), nil
}

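Taken together, IncrementWaitCount and DecrementWaitCount are meant to be used as an acquire/release pair around a queued request. A minimal sketch of that call pattern follows; waitForSlot and serve are hypothetical stand-ins for the gateway's real scheduling and handler code, and the interface is only assumed to be satisfied by *concurrencyCache.

```go
package scheduler

import (
	"context"
	"errors"
)

// waitQueue is an assumed view of *concurrencyCache's wait-queue methods.
type waitQueue interface {
	IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
	DecrementWaitCount(ctx context.Context, userID int64) error
}

var errWaitQueueFull = errors.New("wait queue full")

// queueAndServe shows the intended pairing: reserve a wait-queue slot, always
// release it, and only then serve. waitForSlot and serve are hypothetical.
func queueAndServe(
	ctx context.Context,
	q waitQueue,
	userID int64,
	maxWait int,
	waitForSlot func(context.Context, int64) error,
	serve func(context.Context) error,
) error {
	ok, err := q.IncrementWaitCount(ctx, userID, maxWait)
	if err != nil {
		return err
	}
	if !ok {
		return errWaitQueueFull // per-user wait queue already at maxWait
	}
	// Release the wait-queue slot on every path, including errors and timeouts.
	defer func() { _ = q.DecrementWaitCount(ctx, userID) }()

	if err := waitForSlot(ctx, userID); err != nil {
		return err
	}
	return serve(ctx)
}
```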
// Account wait queue operations

func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
@@ -508,7 +335,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun

func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result()
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}


@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {

func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
userKey := waitQueueKey(userID)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)

ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")

ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result()
require.NoError(s.T(), err, "TTL wait total key")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2)
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)

require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")

val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.NoError(s.T(), err, "HGET wait queue count")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")

total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
require.NoError(s.T(), err, "GET wait queue total")
require.Equal(s.T(), 1, total, "expected total wait count 1")
}

func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
userKey := waitQueueKey(userID)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)

// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")

// Verify count remains zero / absent.
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.True(s.T(), errors.Is(err, redis.Nil))
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")

// Set count to 1, then decrement twice
@@ -210,15 +210,12 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")

// Verify per-user count is absent and total is non-negative.
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")

total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err)
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count")
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}

func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {