perf(gateway): 优化负载感知调度

主要改进:
- 优化负载感知调度的准确性和响应速度
- 将 AccountUsageService 的包级缓存改为依赖注入
- 修复 SSE/JSON 转义和 nil 安全问题
- 恢复 Google One 功能兼容性
This commit is contained in:
ianshaw
2026-01-03 06:32:51 -08:00
parent 26106eb0ac
commit acb718d355
7 changed files with 369 additions and 310 deletions

View File

@@ -27,14 +27,8 @@ const (
accountSlotKeyPrefix = "concurrency:account:"
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keys (global structures)
// - total: integer total queue depth across all users
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
// - counts: hash of userID -> current wait count
waitQueueTotalKey = "concurrency:wait:total"
waitQueueUpdatedKey = "concurrency:wait:updated"
waitQueueCountsKey = "concurrency:wait:counts"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
@@ -100,55 +94,27 @@ var (
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = userID
// ARGV[2] = maxWait
// ARGV[3] = TTL in seconds
// ARGV[4] = cleanup limit
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local maxWait = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local cleanupLimit = tonumber(ARGV[4])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current >= maxWait then
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = current + 1
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
redis.call('INCR', totalKey)
local newVal = redis.call('INCR', KEYS[1])
-- Keep global structures from living forever in totally idle deployments.
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
@@ -178,111 +144,6 @@ var (
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local ttl = tonumber(ARGV[2])
local cleanupLimit = tonumber(ARGV[3])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current <= 0 then
return 1
end
local newVal = current - 1
if newVal <= 0 then
redis.call('HDEL', countsKey, userID)
redis.call('ZREM', updatedKey, userID)
else
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
end
redis.call('DECR', totalKey)
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
return 1
`)
// getTotalWaitScript returns the global wait depth with TTL cleanup.
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = TTL in seconds
// ARGV[2] = cleanup limit
getTotalWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local ttl = tonumber(ARGV[1])
local cleanupLimit = tonumber(ARGV[2])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
local total = redis.call('GET', totalKey)
if total == false then
total = 0
local vals = redis.call('HVALS', countsKey)
for _, v in ipairs(vals) do
total = total + tonumber(v)
end
redis.call('SET', totalKey, total)
end
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
local result = tonumber(redis.call('GET', totalKey) or '0')
if result < 0 then
result = 0
redis.call('SET', totalKey, 0)
end
return result
`)
// decrementAccountWaitScript - account-level wait queue decrement
decrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
@@ -383,9 +244,7 @@ func userSlotKey(userID int64) string {
}
func waitQueueKey(userID int64) string {
// Historical: per-user string keys were used.
// Now we use global structures keyed by userID string.
return strconv.FormatInt(userID, 10)
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
func accountWaitKey(accountID int64) string {
@@ -449,16 +308,8 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
userKey := waitQueueKey(userID)
result, err := incrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
maxWait,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Int()
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
@@ -466,35 +317,11 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
userKey := waitQueueKey(userID)
_, err := decrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Result()
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
if c.rdb == nil {
return 0, nil
}
total, err := getTotalWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
c.waitQueueTTLSeconds,
500, // cleanup limit per query (rare)
).Int64()
if err != nil {
return 0, err
}
return int(total), nil
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
@@ -508,7 +335,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result()
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
userKey := waitQueueKey(userID)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result()
require.NoError(s.T(), err, "TTL wait total key")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2)
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.NoError(s.T(), err, "HGET wait queue count")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
require.NoError(s.T(), err, "GET wait queue total")
require.Equal(s.T(), 1, total, "expected total wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
userKey := waitQueueKey(userID)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify count remains zero / absent.
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.True(s.T(), errors.Is(err, redis.Nil))
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
@@ -210,15 +210,12 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify per-user count is absent and total is non-negative.
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err)
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count")
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {