Files
sub2api/backend/internal/repository/concurrency_cache.go
yangjianbo 7efa8b54c4 perf(后端): 完成性能优化与连接池配置
新增 DB/Redis 连接池配置与校验,并补充单测

网关请求体大小限制与 413 处理

HTTP/req 客户端池化并调整上游连接池默认值

并发槽位改为 ZSET+Lua 与指数退避

用量统计改 SQL 聚合并新增索引迁移

计费缓存写入改工作池并补测试/基准

测试: 在 backend/ 下运行 go test ./...
2025-12-31 08:50:12 +08:00

228 lines
6.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 并发控制缓存常量定义
//
// 性能优化说明:
// 原实现使用 SCAN 命令遍历独立的槽位键concurrency:account:{id}:{requestID}
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
//
// 新实现改用 Redis 有序集合Sorted Set
// 1. 每个账号/用户只有一个键,成员为 requestID分数为时间戳
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
// 4. 单次 Redis 调用完成计数,减少网络往返
const (
// 并发槽位键前缀(有序集合)
// 格式: concurrency:account:{accountID}
accountSlotKeyPrefix = "concurrency:account:"
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15
)
var (
// acquireScript 使用有序集合计数并在未达上限时添加槽位
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL
// ARGV[3] = requestID
acquireScript = redis.NewScript(`
local key = KEYS[1]
local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local requestID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- 清理过期槽位
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查是否已存在(支持重试场景刷新时间戳)
local exists = redis.call('ZSCORE', key, requestID)
if exists ~= false then
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
-- 检查是否达到并发上限
local count = redis.call('ZCARD', key)
if count < maxConcurrency then
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
return 0
`)
// getCountScript 统计有序集合中的槽位数量并清理过期条目
// 使用 Redis TIME 命令获取服务器时间
// KEYS[1] = 有序集合键
// ARGV[1] = TTL
getCountScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
-- 使用 Redis 服务器时间
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
type concurrencyCache struct {
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
}
}
// Helper functions for key generation
func accountSlotKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64) string {
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
key := accountSlotKey(accountID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
return result, nil
}
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
key := userSlotKey(userID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
return result, nil
}
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}