feat(gateway): 引入使用量记录有界 worker 池与自动扩缩容

- 新增 UsageRecordWorkerPool,支持有界队列、溢出降级策略与自动扩缩容
- 将 Gateway/OpenAI/Sora/Gemini 使用量记录改为提交到统一任务池执行
- 增加 usage_record 配置默认值与校验规则,并补充配置与任务提交相关测试
- 注入并托管 worker 池生命周期,服务退出时统一 StopAndWait

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-22 12:56:57 +08:00
parent 50b9897182
commit 33db7a0fb6
15 changed files with 1575 additions and 70 deletions

View File

@@ -37,6 +37,7 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService
usageService *service.UsageService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
@@ -54,6 +55,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *GatewayHandler {
@@ -77,6 +79,7 @@ func NewGatewayHandler(
billingCacheService: billingCacheService,
usageService: usageService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
@@ -431,19 +434,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fcb,
ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -452,10 +453,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID),
zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
})
return
}
}
@@ -700,19 +701,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: usedAccount,
Account: account,
Subscription: currentSubscription,
UserAgent: ua,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fcb,
ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -721,10 +720,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
zap.Int64("api_key_id", currentAPIKey.ID),
zap.Any("group_id", currentAPIKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID),
zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
})
reqLog.Debug("gateway.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
@@ -1508,3 +1507,17 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
return http.StatusForbidden, "billing_error", msg
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
task(ctx)
}

View File

@@ -11,7 +11,6 @@ import (
"net/http"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -519,22 +518,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
}
// 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
UserAgent: userAgent,
IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fcb,
ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -543,10 +539,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", modelName),
zap.Int64("account_id", usedAccount.ID),
zap.Int64("account_id", account.ID),
).Error("gemini.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, forceCacheBilling)
})
reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),

View File

@@ -26,6 +26,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
@@ -37,6 +38,7 @@ func NewOpenAIGatewayHandler(
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *OpenAIGatewayHandler {
@@ -52,6 +54,7 @@ func NewOpenAIGatewayHandler(
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
@@ -378,18 +381,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
@@ -398,10 +399,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID),
zap.Int64("account_id", account.ID),
).Error("openai.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP)
})
reqLog.Debug("openai.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
@@ -432,6 +433,20 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
}
}
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
task(ctx)
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",

View File

@@ -31,15 +31,16 @@ import (
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
soraTLSEnabled bool
soraMediaSigningKey string
soraMediaRoot string
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService
usageRecordWorkerPool *service.UsageRecordWorkerPool
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
soraTLSEnabled bool
soraMediaSigningKey string
soraMediaRoot string
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
@@ -48,6 +49,7 @@ func NewSoraGatewayHandler(
soraGatewayService *service.SoraGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
cfg *config.Config,
) *SoraGatewayHandler {
pingInterval := time.Duration(0)
@@ -71,15 +73,16 @@ func NewSoraGatewayHandler(
}
}
return &SoraGatewayHandler{
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService,
usageRecordWorkerPool: usageRecordWorkerPool,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot,
}
}
@@ -397,17 +400,16 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
UserAgent: userAgent,
IPAddress: clientIP,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
@@ -415,10 +417,10 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID),
zap.Int64("account_id", account.ID),
).Error("sora.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP)
})
reqLog.Debug("sora.request_completed",
zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID),
@@ -448,6 +450,20 @@ func generateOpenAISessionHash(c *gin.Context, body []byte) string {
return hex.EncodeToString(hash[:])
}
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
task(ctx)
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)

View File

@@ -432,7 +432,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, cfg)
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)

View File

@@ -0,0 +1,136 @@
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool {
t.Helper()
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
WorkerCount: 1,
QueueSize: 8,
TaskTimeout: time.Second,
OverflowPolicy: "drop",
OverflowSamplePercent: 0,
AutoScaleEnabled: false,
})
t.Cleanup(pool.Stop)
return pool
}
func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &GatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &GatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &GatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &OpenAIGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &OpenAIGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
called.Store(true)
})
require.True(t, called.Load())
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &SoraGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}