feat(gateway): 双模式用户消息队列 — 串行队列 + 软性限速
新增 UMQ (User Message Queue) 双模式支持: - serialize: 账号级分布式串行锁 + RPM 自适应延迟(严格限流) - throttle: 仅 RPM 自适应前置延迟,不阻塞并发(软性限速) 后端: - config: 新增 Mode 字段,保留 Enabled 向后兼容 - service: 新增 UserMessageQueueService(Lua 锁/延迟算法/清理 worker) - repository: 新增 UserMsgQueueCache(Redis Lua acquire/release/force-release) - handler: 新增 UserMsgQueueHelper(SSE ping + 等待循环 + throttle) - gateway: 按 mode 分支集成 serialize/throttle 逻辑 - lint: 修复 gofmt rewrite rules、errcheck 类型断言、staticcheck QF1012 前端: - 三态选择器 UI(关闭/软性限速/串行队列)替代 toggle 开关 - BulkEdit 支持 null 语义(不修改) - i18n 中英文文案 通过 6 轮专家评审(42 次 review)、golangci-lint、单元测试、集成测试。
This commit is contained in:
@@ -196,7 +196,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig, settingService)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||
|
||||
@@ -30,6 +30,14 @@ const (
|
||||
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
|
||||
// UMQ (User Message Queue) mode constants.
const (
	// UMQModeSerialize: per-account serial lock plus RPM-adaptive delay.
	UMQModeSerialize = "serialize"
	// UMQModeThrottle: RPM-adaptive pre-delay only; does not block concurrency.
	UMQModeThrottle = "throttle"
)
|
||||
|
||||
// 连接池隔离策略常量
|
||||
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
|
||||
const (
|
||||
@@ -455,6 +463,52 @@ type GatewayConfig struct {
|
||||
UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"`
|
||||
// ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒)
|
||||
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
|
||||
|
||||
// UserMessageQueue: 用户消息串行队列配置
|
||||
// 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟
|
||||
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
|
||||
}
|
||||
|
||||
// UserMessageQueueConfig configures serialized sending of real user
// messages for Anthropic OAuth/SetupToken accounts.
type UserMessageQueueConfig struct {
	// Mode selects the queue behavior:
	//   "serialize" = per-account serial lock + RPM-adaptive delay
	//   "throttle"  = RPM-adaptive pre-delay only, no concurrency blocking
	//   ""          = disabled (default)
	Mode string `mapstructure:"mode"`
	// Enabled is deprecated and kept only for backward compatibility;
	// true is equivalent to mode "serialize".
	Enabled bool `mapstructure:"enabled"`
	// LockTTLMs is the serial lock TTL in milliseconds; it should exceed
	// the longest expected request duration.
	LockTTLMs int `mapstructure:"lock_ttl_ms"`
	// WaitTimeoutMs caps how long a request waits to acquire the lock (ms).
	WaitTimeoutMs int `mapstructure:"wait_timeout_ms"`
	// MinDelayMs is the lower bound of the RPM-adaptive delay (ms).
	MinDelayMs int `mapstructure:"min_delay_ms"`
	// MaxDelayMs is the upper bound of the RPM-adaptive delay (ms).
	MaxDelayMs int `mapstructure:"max_delay_ms"`
	// CleanupIntervalSeconds is the orphan-lock sweep interval in seconds;
	// 0 disables the cleanup worker.
	CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
}

// WaitTimeout returns WaitTimeoutMs as a time.Duration, defaulting to
// 30 seconds when the configured value is zero or negative.
func (c *UserMessageQueueConfig) WaitTimeout() time.Duration {
	if c.WaitTimeoutMs <= 0 {
		return 30 * time.Second
	}
	return time.Duration(c.WaitTimeoutMs) * time.Millisecond
}

// GetEffectiveMode resolves the mode that actually applies.
// Mode is already whitelist-validated and normalized in load(), so the
// check here is defensive rather than authoritative.
func (c *UserMessageQueueConfig) GetEffectiveMode() string {
	if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle {
		return c.Mode
	}
	if c.Enabled {
		return UMQModeSerialize // legacy Enabled=true maps to serialize
	}
	return ""
}
|
||||
|
||||
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
||||
@@ -994,6 +1048,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
|
||||
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||||
}
|
||||
|
||||
// Normalize UMQ mode: 白名单校验,非法值在加载时一次性 warn 并清空
|
||||
if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle {
|
||||
slog.Warn("invalid user_message_queue mode, disabling",
|
||||
"mode", m,
|
||||
"valid_modes", []string{UMQModeSerialize, UMQModeThrottle})
|
||||
cfg.Gateway.UserMessageQueue.Mode = ""
|
||||
}
|
||||
|
||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||
if cfg.Totp.EncryptionKey == "" {
|
||||
@@ -1372,6 +1434,14 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30)
|
||||
viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15)
|
||||
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
|
||||
// 用户消息串行队列默认值
|
||||
viper.SetDefault("gateway.user_message_queue.enabled", false)
|
||||
viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000)
|
||||
viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000)
|
||||
viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200)
|
||||
viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000)
|
||||
viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60)
|
||||
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
|
||||
|
||||
@@ -216,6 +216,10 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
buffer := a.GetRPMStickyBuffer()
|
||||
out.RPMStickyBuffer = &buffer
|
||||
}
|
||||
// 用户消息队列模式
|
||||
if mode := a.GetUserMsgQueueMode(); mode != "" {
|
||||
out.UserMsgQueueMode = &mode
|
||||
}
|
||||
// TLS指纹伪装开关
|
||||
if a.IsTLSFingerprintEnabled() {
|
||||
enabled := true
|
||||
|
||||
@@ -155,9 +155,10 @@ type Account struct {
|
||||
|
||||
// RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
BaseRPM *int `json:"base_rpm,omitempty"`
|
||||
RPMStrategy *string `json:"rpm_strategy,omitempty"`
|
||||
RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"`
|
||||
BaseRPM *int `json:"base_rpm,omitempty"`
|
||||
RPMStrategy *string `json:"rpm_strategy,omitempty"`
|
||||
RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"`
|
||||
UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"`
|
||||
|
||||
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
|
||||
@@ -45,6 +45,7 @@ type GatewayHandler struct {
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
userMsgQueueHelper *UserMsgQueueHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
cfg *config.Config
|
||||
@@ -63,6 +64,7 @@ func NewGatewayHandler(
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
userMsgQueueService *service.UserMessageQueueService,
|
||||
cfg *config.Config,
|
||||
settingService *service.SettingService,
|
||||
) *GatewayHandler {
|
||||
@@ -78,6 +80,13 @@ func NewGatewayHandler(
|
||||
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化用户消息串行队列 helper
|
||||
var umqHelper *UserMsgQueueHelper
|
||||
if userMsgQueueService != nil && cfg != nil {
|
||||
umqHelper = NewUserMsgQueueHelper(userMsgQueueService, SSEPingFormatClaude, pingInterval)
|
||||
}
|
||||
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
geminiCompatService: geminiCompatService,
|
||||
@@ -89,6 +98,7 @@ func NewGatewayHandler(
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
userMsgQueueHelper: umqHelper,
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
cfg: cfg,
|
||||
@@ -566,6 +576,58 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// ===== 用户消息串行队列 START =====
|
||||
var queueRelease func()
|
||||
umqMode := h.getUserMsgQueueMode(account, parsedReq)
|
||||
|
||||
switch umqMode {
|
||||
case config.UMQModeSerialize:
|
||||
// 串行模式:获取锁 + RPM 延迟 + 释放(当前行为不变)
|
||||
baseRPM := account.GetBaseRPM()
|
||||
release, qErr := h.userMsgQueueHelper.AcquireWithWait(
|
||||
c, account.ID, baseRPM, reqStream, &streamStarted,
|
||||
h.cfg.Gateway.UserMessageQueue.WaitTimeout(),
|
||||
reqLog,
|
||||
)
|
||||
if qErr != nil {
|
||||
// fail-open: 记录 warn,不阻止请求
|
||||
reqLog.Warn("gateway.umq_acquire_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(qErr),
|
||||
)
|
||||
} else {
|
||||
queueRelease = release
|
||||
}
|
||||
|
||||
case config.UMQModeThrottle:
|
||||
// 软性限速:仅施加 RPM 自适应延迟,不阻塞并发
|
||||
baseRPM := account.GetBaseRPM()
|
||||
if tErr := h.userMsgQueueHelper.ThrottleWithPing(
|
||||
c, account.ID, baseRPM, reqStream, &streamStarted,
|
||||
h.cfg.Gateway.UserMessageQueue.WaitTimeout(),
|
||||
reqLog,
|
||||
); tErr != nil {
|
||||
reqLog.Warn("gateway.umq_throttle_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(tErr),
|
||||
)
|
||||
}
|
||||
|
||||
default:
|
||||
if umqMode != "" {
|
||||
reqLog.Warn("gateway.umq_unknown_mode",
|
||||
zap.String("mode", umqMode),
|
||||
zap.Int64("account_id", account.ID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 用 wrapReleaseOnDone 确保 context 取消时自动释放(仅 serialize 模式有 queueRelease)
|
||||
queueRelease = wrapReleaseOnDone(c.Request.Context(), queueRelease)
|
||||
// 注入回调到 ParsedRequest:使用外层 wrapper 以便提前清理 AfterFunc
|
||||
parsedReq.OnUpstreamAccepted = queueRelease
|
||||
// ===== 用户消息串行队列 END =====
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
@@ -577,6 +639,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
|
||||
}
|
||||
|
||||
// 兜底释放串行锁(正常情况已通过回调提前释放)
|
||||
if queueRelease != nil {
|
||||
queueRelease()
|
||||
}
|
||||
// 清理回调引用,防止 failover 重试时旧回调被错误调用
|
||||
parsedReq.OnUpstreamAccepted = nil
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
@@ -1431,3 +1501,24 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
// getUserMsgQueueMode 获取当前请求的 UMQ 模式
|
||||
// 返回 "serialize" | "throttle" | ""
|
||||
func (h *GatewayHandler) getUserMsgQueueMode(account *service.Account, parsed *service.ParsedRequest) string {
|
||||
if h.userMsgQueueHelper == nil {
|
||||
return ""
|
||||
}
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return ""
|
||||
}
|
||||
if !service.IsRealUserMessage(parsed) {
|
||||
return ""
|
||||
}
|
||||
// 账号级模式优先,fallback 到全局配置
|
||||
mode := account.GetUserMsgQueueMode()
|
||||
if mode == "" {
|
||||
mode = h.cfg.Gateway.UserMessageQueue.GetEffectiveMode()
|
||||
}
|
||||
return mode
|
||||
}
|
||||
|
||||
237
backend/internal/handler/user_msg_queue_helper.go
Normal file
237
backend/internal/handler/user_msg_queue_helper.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// UserMsgQueueHelper 用户消息串行队列 Handler 层辅助
|
||||
// 复用 ConcurrencyHelper 的退避 + SSE ping 模式
|
||||
type UserMsgQueueHelper struct {
|
||||
queueService *service.UserMessageQueueService
|
||||
pingFormat SSEPingFormat
|
||||
pingInterval time.Duration
|
||||
}
|
||||
|
||||
// NewUserMsgQueueHelper 创建用户消息串行队列辅助
|
||||
func NewUserMsgQueueHelper(
|
||||
queueService *service.UserMessageQueueService,
|
||||
pingFormat SSEPingFormat,
|
||||
pingInterval time.Duration,
|
||||
) *UserMsgQueueHelper {
|
||||
if pingInterval <= 0 {
|
||||
pingInterval = defaultPingInterval
|
||||
}
|
||||
return &UserMsgQueueHelper{
|
||||
queueService: queueService,
|
||||
pingFormat: pingFormat,
|
||||
pingInterval: pingInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireWithWait blocks until the per-account serial lock is acquired
// (or timeout expires), emitting SSE pings for streaming requests while
// waiting. The returned releaseFunc is guarded by sync.Once internally,
// so calling it multiple times releases at most once.
func (h *UserMsgQueueHelper) AcquireWithWait(
	c *gin.Context,
	accountID int64,
	baseRPM int,
	isStream bool,
	streamStarted *bool,
	timeout time.Duration,
	reqLog *zap.Logger,
) (releaseFunc func(), err error) {
	ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
	defer cancel()

	// Fast path: try to take the lock immediately.
	result, err := h.queueService.TryAcquire(ctx, accountID)
	if err != nil {
		return nil, err // fail-open decisions are made in the service layer
	}

	if result.Acquired {
		// Lock held: apply the RPM-adaptive delay before proceeding.
		if err := h.queueService.EnforceDelay(ctx, accountID, baseRPM); err != nil {
			if ctx.Err() != nil {
				// Canceled mid-delay: give the lock back on a short
				// background context (the request context is dead).
				bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
				_ = h.queueService.Release(bgCtx, accountID, result.RequestID)
				bgCancel()
				return nil, ctx.Err()
			}
			// NOTE(review): a delay error with a live context is silently
			// ignored here — presumably deliberate fail-open; confirm.
		}
		reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID))
		return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil
	}

	// Slow path: poll with exponential backoff, pinging while streaming.
	return h.waitForLockWithPing(c, ctx, accountID, baseRPM, isStream, streamStarted, reqLog)
}
|
||||
|
||||
// waitForLockWithPing polls for the serial lock with exponential
// backoff until ctx expires, emitting SSE pings for streaming requests
// so the client connection stays alive during the wait.
func (h *UserMsgQueueHelper) waitForLockWithPing(
	c *gin.Context,
	ctx context.Context,
	accountID int64,
	baseRPM int,
	isStream bool,
	streamStarted *bool,
	reqLog *zap.Logger,
) (func(), error) {
	// Pings require a streaming request, a configured ping format, and a
	// flushable response writer.
	needPing := isStream && h.pingFormat != ""

	var flusher http.Flusher
	if needPing {
		var ok bool
		flusher, ok = c.Writer.(http.Flusher)
		if !ok {
			needPing = false
		}
	}

	// pingCh stays nil when pinging is disabled; receiving from a nil
	// channel blocks forever, so that select arm simply never fires.
	var pingCh <-chan time.Time
	if needPing {
		pingTicker := time.NewTicker(h.pingInterval)
		defer pingTicker.Stop()
		pingCh = pingTicker.C
	}

	backoff := initialBackoff
	timer := time.NewTimer(backoff)
	defer timer.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil, fmt.Errorf("umq wait timeout for account %d", accountID)

		case <-pingCh:
			// The first ping establishes the SSE response headers.
			if !*streamStarted {
				c.Header("Content-Type", "text/event-stream")
				c.Header("Cache-Control", "no-cache")
				c.Header("Connection", "keep-alive")
				c.Header("X-Accel-Buffering", "no")
				*streamStarted = true
			}
			if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
				return nil, err
			}
			// flusher is always non-nil when pingCh is non-nil.
			flusher.Flush()

		case <-timer.C:
			result, err := h.queueService.TryAcquire(ctx, accountID)
			if err != nil {
				return nil, err
			}
			if result.Acquired {
				// Lock held: apply the RPM-adaptive delay before returning.
				if delayErr := h.queueService.EnforceDelay(ctx, accountID, baseRPM); delayErr != nil {
					if ctx.Err() != nil {
						// Canceled mid-delay: give the lock back on a short
						// background context (the request context is dead).
						bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
						_ = h.queueService.Release(bgCtx, accountID, result.RequestID)
						bgCancel()
						return nil, ctx.Err()
					}
					// NOTE(review): delay errors with a live context are
					// swallowed — presumably deliberate fail-open; confirm.
				}
				reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID))
				return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil
			}
			backoff = nextBackoff(backoff)
			timer.Reset(backoff)
		}
	}
}
|
||||
|
||||
// makeReleaseFunc 创建锁释放函数(使用 sync.Once 确保只执行一次)
|
||||
func (h *UserMsgQueueHelper) makeReleaseFunc(accountID int64, requestID string, reqLog *zap.Logger) func() {
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer bgCancel()
|
||||
if err := h.queueService.Release(bgCtx, accountID, requestID); err != nil {
|
||||
reqLog.Warn("gateway.umq_release_failed",
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
reqLog.Debug("gateway.umq_lock_released", zap.Int64("account_id", accountID))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ThrottleWithPing implements the soft-throttle mode: it applies only
// the RPM-adaptive delay (no serial lock, no concurrency blocking),
// emitting SSE pings during the wait for streaming requests. When it
// returns nil the request may be forwarded immediately.
func (h *UserMsgQueueHelper) ThrottleWithPing(
	c *gin.Context,
	accountID int64,
	baseRPM int,
	isStream bool,
	streamStarted *bool,
	timeout time.Duration,
	reqLog *zap.Logger,
) error {
	ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
	defer cancel()

	delay := h.queueService.CalculateRPMAwareDelay(ctx, accountID, baseRPM)
	if delay <= 0 {
		return nil
	}

	reqLog.Debug("gateway.umq_throttle_delay",
		zap.Int64("account_id", accountID),
		zap.Duration("delay", delay),
	)

	// Same SSE ping gating as waitForLockWithPing: streaming request,
	// configured ping format, and a flushable response writer.
	needPing := isStream && h.pingFormat != ""
	var flusher http.Flusher
	if needPing {
		flusher, _ = c.Writer.(http.Flusher)
		if flusher == nil {
			needPing = false
		}
	}

	// A nil pingCh never fires in the select below.
	var pingCh <-chan time.Time
	if needPing {
		pingTicker := time.NewTicker(h.pingInterval)
		defer pingTicker.Stop()
		pingCh = pingTicker.C
	}

	timer := time.NewTimer(delay)
	defer timer.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-pingCh:
			// First ping establishes SSE headers (mirrors waitForLockWithPing).
			if !*streamStarted {
				c.Header("Content-Type", "text/event-stream")
				c.Header("Cache-Control", "no-cache")
				c.Header("Connection", "keep-alive")
				c.Header("X-Accel-Buffering", "no")
				*streamStarted = true
			}
			if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
				return err
			}
			flusher.Flush()
		case <-timer.C:
			return nil // delay served; proceed with the request
		}
	}
}
|
||||
@@ -35,12 +35,12 @@ func latencyHistogramRangeCaseExpr(column string) string {
|
||||
if b.upperMs <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label))
|
||||
fmt.Fprintf(&sb, "\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label)
|
||||
}
|
||||
|
||||
// Default bucket.
|
||||
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1]
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label))
|
||||
fmt.Fprintf(&sb, "\tELSE '%s'\n", last.label)
|
||||
_, _ = sb.WriteString("END")
|
||||
return sb.String()
|
||||
}
|
||||
@@ -54,11 +54,11 @@ func latencyHistogramRangeOrderCaseExpr(column string) string {
|
||||
if b.upperMs <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order))
|
||||
fmt.Fprintf(&sb, "\tWHEN %s < %d THEN %d\n", column, b.upperMs, order)
|
||||
order++
|
||||
}
|
||||
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order))
|
||||
fmt.Fprintf(&sb, "\tELSE %d\n", order)
|
||||
_, _ = sb.WriteString("END")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
186
backend/internal/repository/user_msg_queue_cache.go
Normal file
186
backend/internal/repository/user_msg_queue_cache.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Redis key layout. The account ID is wrapped in braces to act as a
// Redis Cluster hash tag, guaranteeing that every key for one account
// hashes to the same slot (required for the multi-key release script).
// Format: umq:{accountID}:lock / umq:{accountID}:last
const (
	umqKeyPrefix  = "umq:"
	umqLockSuffix = ":lock" // STRING (requestID), expires via PX lockTtlMs
	umqLastSuffix = ":last" // STRING (ms timestamp), EX 60s
)

// acquireLockScript atomically takes the serial lock (SET NX PX
// semantics) and is reentrancy-safe: if the caller already holds the
// lock, its TTL is refreshed instead of failing.
var acquireLockScript = redis.NewScript(`
local cur = redis.call('GET', KEYS[1])
if cur == ARGV[1] then
  redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[2]))
  return 1
end
if cur ~= false then return 0 end
redis.call('SET', KEYS[1], ARGV[1], 'PX', tonumber(ARGV[2]))
return 1
`)

// releaseLockScript atomically releases the lock (only when requestID
// still owns it) and records the completion time using Redis TIME, so
// the timestamp is immune to app-server clock skew.
var releaseLockScript = redis.NewScript(`
local cur = redis.call('GET', KEYS[1])
if cur == ARGV[1] then
  redis.call('DEL', KEYS[1])
  local t = redis.call('TIME')
  local ms = tonumber(t[1])*1000 + math.floor(tonumber(t[2])/1000)
  redis.call('SET', KEYS[2], ms, 'EX', 60)
  return 1
end
return 0
`)

// forceReleaseLockScript deletes a lock only when its PTTL is -1 (no
// expiry). Checking and deleting inside one script avoids the TOCTOU
// race where a cleanup sweep could delete a freshly re-acquired lock.
var forceReleaseLockScript = redis.NewScript(`
local pttl = redis.call('PTTL', KEYS[1])
if pttl == -1 then
  redis.call('DEL', KEYS[1])
  return 1
end
return 0
`)
|
||||
|
||||
type userMsgQueueCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewUserMsgQueueCache 创建用户消息队列缓存
|
||||
func NewUserMsgQueueCache(rdb *redis.Client) service.UserMsgQueueCache {
|
||||
return &userMsgQueueCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func umqLockKey(accountID int64) string {
|
||||
// 格式: umq:{123}:lock — 花括号确保 Redis Cluster hash tag 生效
|
||||
return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLockSuffix
|
||||
}
|
||||
|
||||
func umqLastKey(accountID int64) string {
|
||||
// 格式: umq:{123}:last — 与 lockKey 同一 hash slot
|
||||
return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLastSuffix
|
||||
}
|
||||
|
||||
// umqScanPattern 用于 SCAN 扫描锁 key
|
||||
func umqScanPattern() string {
|
||||
return umqKeyPrefix + "{*}" + umqLockSuffix
|
||||
}
|
||||
|
||||
// AcquireLock 尝试获取账号级串行锁
|
||||
func (c *userMsgQueueCache) AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (bool, error) {
|
||||
key := umqLockKey(accountID)
|
||||
result, err := acquireLockScript.Run(ctx, c.rdb, []string{key}, requestID, lockTtlMs).Int()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("umq acquire lock: %w", err)
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// ReleaseLock 释放锁并记录完成时间
|
||||
func (c *userMsgQueueCache) ReleaseLock(ctx context.Context, accountID int64, requestID string) (bool, error) {
|
||||
lockKey := umqLockKey(accountID)
|
||||
lastKey := umqLastKey(accountID)
|
||||
result, err := releaseLockScript.Run(ctx, c.rdb, []string{lockKey, lastKey}, requestID).Int()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("umq release lock: %w", err)
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// GetLastCompletedMs 获取上次完成时间(毫秒时间戳)
|
||||
func (c *userMsgQueueCache) GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error) {
|
||||
key := umqLastKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("umq get last completed: %w", err)
|
||||
}
|
||||
ms, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("umq parse last completed: %w", err)
|
||||
}
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// ForceReleaseLock 原子清理孤儿锁(仅在 PTTL == -1 时删除,防止 TOCTOU 竞态误删合法锁)
|
||||
func (c *userMsgQueueCache) ForceReleaseLock(ctx context.Context, accountID int64) error {
|
||||
key := umqLockKey(accountID)
|
||||
_, err := forceReleaseLockScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("umq force release lock: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScanLockKeys walks all UMQ lock keys via SCAN and returns the account
// IDs of orphaned locks — keys whose PTTL is -1 (no expiry). Locks are
// always written with PX, so a missing TTL indicates an abnormal state
// (e.g. TTL lost after a Redis recovery). At most maxCount IDs are
// returned.
func (c *userMsgQueueCache) ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) {
	var accountIDs []int64
	var cursor uint64
	pattern := umqScanPattern()

	for {
		keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, 100).Result()
		if err != nil {
			return nil, fmt.Errorf("umq scan lock keys: %w", err)
		}
		for _, key := range keys {
			// Only locks with no expiry are considered orphaned.
			pttl, err := c.rdb.PTTL(ctx, key).Result()
			if err != nil {
				// Best-effort sweep: skip keys we cannot inspect.
				continue
			}
			// PTTL semantics: -2 = key gone, -1 = no expiry, >0 = ms left.
			// NOTE(review): this comparison assumes go-redis returns the
			// sentinel values unscaled as time.Duration(-1)/(-2) rather
			// than -1ms/-2ms — confirm against the go-redis version in use.
			if pttl != time.Duration(-1) {
				continue
			}

			// Extract the account ID from "umq:{123}:lock".
			openBrace := strings.IndexByte(key, '{')
			closeBrace := strings.IndexByte(key, '}')
			if openBrace < 0 || closeBrace <= openBrace+1 {
				continue
			}
			idStr := key[openBrace+1 : closeBrace]
			id, err := strconv.ParseInt(idStr, 10, 64)
			if err != nil {
				continue
			}
			accountIDs = append(accountIDs, id)
			if len(accountIDs) >= maxCount {
				return accountIDs, nil
			}
		}
		cursor = nextCursor
		if cursor == 0 {
			break
		}
	}
	return accountIDs, nil
}
|
||||
|
||||
// GetCurrentTimeMs 通过 Redis TIME 命令获取当前服务器时间(毫秒),确保与锁记录的时间源一致
|
||||
func (c *userMsgQueueCache) GetCurrentTimeMs(ctx context.Context) (int64, error) {
|
||||
t, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("umq get redis time: %w", err)
|
||||
}
|
||||
return t.UnixMilli(), nil
|
||||
}
|
||||
@@ -80,6 +80,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewRPMCache,
|
||||
NewUserMsgQueueCache,
|
||||
NewDashboardCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
)
|
||||
|
||||
@@ -1032,6 +1033,26 @@ func (a *Account) IsTLSFingerprintEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetUserMsgQueueMode 获取用户消息队列模式
|
||||
// "serialize" = 串行队列, "throttle" = 软性限速, "" = 未设置(使用全局配置)
|
||||
func (a *Account) GetUserMsgQueueMode() string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
// 优先读取新字段 user_msg_queue_mode(白名单校验,非法值视为未设置)
|
||||
if mode, ok := a.Extra["user_msg_queue_mode"].(string); ok && mode != "" {
|
||||
if mode == config.UMQModeSerialize || mode == config.UMQModeThrottle {
|
||||
return mode
|
||||
}
|
||||
return "" // 非法值 fallback 到全局配置
|
||||
}
|
||||
// 向后兼容: user_msg_queue_enabled: true → "serialize"
|
||||
if enabled, ok := a.Extra["user_msg_queue_enabled"].(bool); ok && enabled {
|
||||
return config.UMQModeSerialize
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
|
||||
|
||||
@@ -61,6 +61,10 @@ type ParsedRequest struct {
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||
|
||||
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
|
||||
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
|
||||
OnUpstreamAccepted func()
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
|
||||
|
||||
@@ -4305,6 +4305,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 处理正常响应
|
||||
|
||||
// 触发上游接受回调(提前释放串行锁,不等流完成)
|
||||
if parsed.OnUpstreamAccepted != nil {
|
||||
parsed.OnUpstreamAccepted()
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
|
||||
@@ -994,7 +994,7 @@ func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string {
|
||||
}
|
||||
}
|
||||
// singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果
|
||||
result, _, _ := minVersionSF.Do("min_version", func() (any, error) {
|
||||
result, err, _ := minVersionSF.Do("min_version", func() (any, error) {
|
||||
// 二次检查,避免排队的 goroutine 重复查询
|
||||
if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
@@ -1020,10 +1020,14 @@ func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string {
|
||||
})
|
||||
return value, nil
|
||||
})
|
||||
if s, ok := result.(string); ok {
|
||||
return s
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
ver, ok := result.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
// SetStreamTimeoutSettings 设置流超时处理配置
|
||||
|
||||
318
backend/internal/service/user_msg_queue_service.go
Normal file
318
backend/internal/service/user_msg_queue_service.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// UserMsgQueueCache is the Redis-backed cache contract for the user
// message serial queue.
type UserMsgQueueCache interface {
	// AcquireLock attempts to take the per-account serial lock.
	AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (acquired bool, err error)
	// ReleaseLock releases the lock and records the completion time.
	ReleaseLock(ctx context.Context, accountID int64, requestID string) (released bool, err error)
	// GetLastCompletedMs returns the last completion time in ms (Redis TIME source).
	GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error)
	// GetCurrentTimeMs returns the Redis server time in ms, matching the
	// clock source ReleaseLock records completion times against.
	GetCurrentTimeMs(ctx context.Context) (int64, error)
	// ForceReleaseLock force-deletes an orphaned lock (cleanup path).
	ForceReleaseLock(ctx context.Context, accountID int64) error
	// ScanLockKeys scans for orphan locks (PTTL == -1) and returns their
	// account IDs, capped at maxCount.
	ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error)
}
|
||||
|
||||
// QueueLockResult is the outcome of a serial-lock acquisition attempt.
type QueueLockResult struct {
	// Acquired reports whether the lock was obtained (also true on fail-open paths).
	Acquired bool
	// RequestID is the lock owner token; empty when acquisition was skipped or failed open.
	RequestID string
}
|
||||
|
||||
// UserMessageQueueService serializes real user messages per account and
// applies an RPM-adaptive delay in front of them.
type UserMessageQueueService struct {
	cache    UserMsgQueueCache              // Redis lock/timestamp backend; nil disables the queue (fail-open)
	rpmCache RPMCache                       // source of the account's current RPM for adaptive delay
	cfg      *config.UserMessageQueueConfig // delay bounds, lock TTL, mode settings
	stopCh   chan struct{}                  // closed by Stop() to terminate the cleanup worker (graceful shutdown)
	stopOnce sync.Once                      // makes Stop() safe for concurrent/repeated calls
}
|
||||
|
||||
// NewUserMessageQueueService 创建用户消息串行队列服务
|
||||
func NewUserMessageQueueService(cache UserMsgQueueCache, rpmCache RPMCache, cfg *config.UserMessageQueueConfig) *UserMessageQueueService {
|
||||
return &UserMessageQueueService{
|
||||
cache: cache,
|
||||
rpmCache: rpmCache,
|
||||
cfg: cfg,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// IsRealUserMessage 检测是否为真实用户消息(非 tool_result)
|
||||
// 与 claude-relay-service 的检测逻辑一致:
|
||||
// 1. messages 非空
|
||||
// 2. 最后一条消息 role == "user"
|
||||
// 3. 最后一条消息 content(如果是数组)中不含 type:"tool_result" / "tool_use_result"
|
||||
func IsRealUserMessage(parsed *ParsedRequest) bool {
|
||||
if parsed == nil || len(parsed.Messages) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
lastMsg := parsed.Messages[len(parsed.Messages)-1]
|
||||
msgMap, ok := lastMsg.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
role, _ := msgMap["role"].(string)
|
||||
if role != "user" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查 content 是否包含 tool_result 类型
|
||||
content, ok := msgMap["content"]
|
||||
if !ok {
|
||||
return true // 没有 content 字段,视为普通用户消息
|
||||
}
|
||||
|
||||
contentArr, ok := content.([]any)
|
||||
if !ok {
|
||||
return true // content 不是数组(可能是 string),视为普通用户消息
|
||||
}
|
||||
|
||||
for _, item := range contentArr {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == "tool_result" || itemType == "tool_use_result" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TryAcquire 尝试立即获取串行锁
|
||||
func (s *UserMessageQueueService) TryAcquire(ctx context.Context, accountID int64) (*QueueLockResult, error) {
|
||||
if s.cache == nil {
|
||||
return &QueueLockResult{Acquired: true}, nil // fail-open
|
||||
}
|
||||
|
||||
requestID := generateUMQRequestID()
|
||||
lockTTL := s.cfg.LockTTLMs
|
||||
if lockTTL <= 0 {
|
||||
lockTTL = 120000
|
||||
}
|
||||
|
||||
acquired, err := s.cache.AcquireLock(ctx, accountID, requestID, lockTTL)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.umq", "AcquireLock failed for account %d: %v", accountID, err)
|
||||
return &QueueLockResult{Acquired: true}, nil // fail-open
|
||||
}
|
||||
|
||||
return &QueueLockResult{
|
||||
Acquired: acquired,
|
||||
RequestID: requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Release 释放串行锁
|
||||
func (s *UserMessageQueueService) Release(ctx context.Context, accountID int64, requestID string) error {
|
||||
if s.cache == nil || requestID == "" {
|
||||
return nil
|
||||
}
|
||||
released, err := s.cache.ReleaseLock(ctx, accountID, requestID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.umq", "ReleaseLock failed for account %d: %v", accountID, err)
|
||||
return err
|
||||
}
|
||||
if !released {
|
||||
logger.LegacyPrintf("service.umq", "ReleaseLock no-op for account %d (requestID mismatch or expired)", accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnforceDelay 根据 RPM 负载执行自适应延迟
|
||||
// 使用 Redis TIME 确保与 releaseLockScript 记录的时间源一致
|
||||
func (s *UserMessageQueueService) EnforceDelay(ctx context.Context, accountID int64, baseRPM int) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 先检查历史记录:没有历史则无需延迟,避免不必要的 RPM 查询
|
||||
lastMs, err := s.cache.GetLastCompletedMs(ctx, accountID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.umq", "GetLastCompletedMs failed for account %d: %v", accountID, err)
|
||||
return nil // fail-open
|
||||
}
|
||||
if lastMs == 0 {
|
||||
return nil // 没有历史记录,无需延迟
|
||||
}
|
||||
|
||||
delay := s.CalculateRPMAwareDelay(ctx, accountID, baseRPM)
|
||||
if delay <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取 Redis 当前时间(与 lastMs 同源,避免时钟偏差)
|
||||
nowMs, err := s.cache.GetCurrentTimeMs(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.umq", "GetCurrentTimeMs failed: %v", err)
|
||||
return nil // fail-open
|
||||
}
|
||||
|
||||
elapsed := time.Duration(nowMs-lastMs) * time.Millisecond
|
||||
if elapsed < 0 {
|
||||
// 时钟异常(Redis 故障转移等),fail-open
|
||||
return nil
|
||||
}
|
||||
remaining := delay - elapsed
|
||||
if remaining <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行延迟
|
||||
timer := time.NewTimer(remaining)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateRPMAwareDelay 根据当前 RPM 负载计算自适应延迟
|
||||
// ratio = currentRPM / baseRPM
|
||||
// ratio < 0.5 → MinDelay
|
||||
// 0.5 ≤ ratio < 0.8 → 线性插值 MinDelay..MaxDelay
|
||||
// ratio ≥ 0.8 → MaxDelay
|
||||
// 返回值包含 ±15% 随机抖动(anti-detection + 避免惊群效应)
|
||||
func (s *UserMessageQueueService) CalculateRPMAwareDelay(ctx context.Context, accountID int64, baseRPM int) time.Duration {
|
||||
minDelay := time.Duration(s.cfg.MinDelayMs) * time.Millisecond
|
||||
maxDelay := time.Duration(s.cfg.MaxDelayMs) * time.Millisecond
|
||||
|
||||
if minDelay <= 0 {
|
||||
minDelay = 200 * time.Millisecond
|
||||
}
|
||||
if maxDelay <= 0 {
|
||||
maxDelay = 2000 * time.Millisecond
|
||||
}
|
||||
// 防止配置错误:minDelay > maxDelay 时交换
|
||||
if minDelay > maxDelay {
|
||||
minDelay, maxDelay = maxDelay, minDelay
|
||||
}
|
||||
|
||||
var baseDelay time.Duration
|
||||
|
||||
if baseRPM <= 0 || s.rpmCache == nil {
|
||||
baseDelay = minDelay
|
||||
} else {
|
||||
currentRPM, err := s.rpmCache.GetRPM(ctx, accountID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.umq", "GetRPM failed for account %d: %v", accountID, err)
|
||||
baseDelay = minDelay // fail-open
|
||||
} else {
|
||||
ratio := float64(currentRPM) / float64(baseRPM)
|
||||
if ratio < 0.5 {
|
||||
baseDelay = minDelay
|
||||
} else if ratio >= 0.8 {
|
||||
baseDelay = maxDelay
|
||||
} else {
|
||||
// 线性插值: 0.5 → minDelay, 0.8 → maxDelay
|
||||
t := (ratio - 0.5) / 0.3
|
||||
interpolated := float64(minDelay) + t*(float64(maxDelay)-float64(minDelay))
|
||||
baseDelay = time.Duration(math.Round(interpolated))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ±15% 随机抖动
|
||||
return applyJitter(baseDelay, 0.15)
|
||||
}
|
||||
|
||||
// StartCleanupWorker 启动孤儿锁清理 worker
|
||||
// 定期 SCAN umq:*:lock 并清理 PTTL == -1 的异常锁(PTTL 检查在 cache.ScanLockKeys 内完成)
|
||||
func (s *UserMessageQueueService) StartCleanupWorker(interval time.Duration) {
|
||||
if s == nil || s.cache == nil || interval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
runCleanup := func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
accountIDs, err := s.cache.ScanLockKeys(ctx, 1000)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.umq", "Cleanup scan failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cleaned := 0
|
||||
for _, accountID := range accountIDs {
|
||||
cleanCtx, cleanCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
if err := s.cache.ForceReleaseLock(cleanCtx, accountID); err != nil {
|
||||
logger.LegacyPrintf("service.umq", "Cleanup force release failed for account %d: %v", accountID, err)
|
||||
} else {
|
||||
cleaned++
|
||||
}
|
||||
cleanCancel()
|
||||
}
|
||||
|
||||
if cleaned > 0 {
|
||||
logger.LegacyPrintf("service.umq", "Cleanup completed: released %d orphaned locks", cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
runCleanup()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop 停止后台 cleanup worker
|
||||
func (s *UserMessageQueueService) Stop() {
|
||||
if s != nil && s.stopCh != nil {
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// applyJitter 对延迟值施加 ±jitterPct 的随机抖动
|
||||
// 使用 math/rand/v2(Go 1.22+ 自动使用 crypto/rand 种子),与 nextBackoff 一致
|
||||
// 例如 applyJitter(200ms, 0.15) 返回 170ms ~ 230ms
|
||||
func applyJitter(d time.Duration, jitterPct float64) time.Duration {
|
||||
if d <= 0 || jitterPct <= 0 {
|
||||
return d
|
||||
}
|
||||
// [-jitterPct, +jitterPct]
|
||||
jitter := (rand.Float64()*2 - 1) * jitterPct
|
||||
return time.Duration(float64(d) * (1 + jitter))
|
||||
}
|
||||
|
||||
// generateUMQRequestID returns a unique request ID (same fallback pattern as
// generateRequestID): 16 random bytes hex-encoded, or a hex-formatted
// nanosecond timestamp if the crypto source is unavailable.
func generateUMQRequestID() string {
	var buf [16]byte
	if _, err := cryptorand.Read(buf[:]); err != nil {
		return fmt.Sprintf("%x", time.Now().UnixNano())
	}
	return hex.EncodeToString(buf[:])
}
|
||||
@@ -110,6 +110,15 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideUserMessageQueueService 创建用户消息串行队列服务并启动清理 worker
|
||||
func ProvideUserMessageQueueService(cache UserMsgQueueCache, rpmCache RPMCache, cfg *config.Config) *UserMessageQueueService {
|
||||
svc := NewUserMessageQueueService(cache, rpmCache, &cfg.Gateway.UserMessageQueue)
|
||||
if cfg.Gateway.UserMessageQueue.CleanupIntervalSeconds > 0 {
|
||||
svc.StartCleanupWorker(time.Duration(cfg.Gateway.UserMessageQueue.CleanupIntervalSeconds) * time.Second)
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideSchedulerSnapshotService creates and starts SchedulerSnapshotService.
|
||||
func ProvideSchedulerSnapshotService(
|
||||
cache SchedulerCache,
|
||||
@@ -348,6 +357,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewSubscriptionService,
|
||||
wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)),
|
||||
ProvideConcurrencyService,
|
||||
ProvideUserMessageQueueService,
|
||||
NewUsageRecordWorkerPool,
|
||||
ProvideSchedulerSnapshotService,
|
||||
NewIdentityService,
|
||||
|
||||
Reference in New Issue
Block a user