perf(后端): 完成性能优化与连接池配置
新增 DB/Redis 连接池配置与校验,并补充单测 网关请求体大小限制与 413 处理 HTTP/req 客户端池化并调整上游连接池默认值 并发槽位改为 ZSET+Lua 与指数退避 用量统计改 SQL 聚合并新增索引迁移 计费缓存写入改工作池并补测试/基准 测试: 在 backend/ 下运行 go test ./...
This commit is contained in:
@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
@@ -76,15 +80,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求获取模型名和stream
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
@@ -106,7 +108,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
@@ -124,7 +126,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算粘性会话hash
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||
platform := ""
|
||||
@@ -141,7 +143,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -153,16 +155,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
@@ -172,7 +174,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
}
|
||||
@@ -223,7 +225,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -235,16 +237,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
@@ -256,7 +258,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@@ -496,6 +498,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
@@ -505,11 +511,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求获取模型名
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
@@ -525,17 +528,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算粘性会话 hash
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求(不记录使用量)
|
||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
|
||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
|
||||
log.Printf("Forward count_tokens request failed: %v", err)
|
||||
// 错误响应已在 ForwardCountTokens 中处理
|
||||
return
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -11,11 +12,28 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题:
|
||||
// 1. 高并发时频繁轮询增加 Redis 压力
|
||||
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
|
||||
//
|
||||
// 新实现使用指数退避 + 抖动算法:
|
||||
// 1. 初始退避 100ms,每次乘以 1.5,最大 2s
|
||||
// 2. 添加 ±20% 的随机抖动,分散重试时间点
|
||||
// 3. 减少 Redis 压力,避免惊群效应
|
||||
const (
|
||||
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
|
||||
// maxConcurrencyWait 等待并发槽位的最大时间
|
||||
maxConcurrencyWait = 30 * time.Second
|
||||
// pingInterval is the interval for sending ping events during slot wait
|
||||
// pingInterval 流式响应等待时发送 ping 的间隔
|
||||
pingInterval = 15 * time.Second
|
||||
// initialBackoff 初始退避时间
|
||||
initialBackoff = 100 * time.Millisecond
|
||||
// backoffMultiplier 退避时间乘数(指数退避)
|
||||
backoffMultiplier = 1.5
|
||||
// maxBackoff 最大退避时间
|
||||
maxBackoff = 2 * time.Second
|
||||
)
|
||||
|
||||
// SSEPingFormat defines the format of SSE ping events for different platforms
|
||||
@@ -131,8 +149,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
backoff := initialBackoff
|
||||
timer := time.NewTimer(backoff)
|
||||
defer timer.Stop()
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -156,7 +176,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-pollTicker.C:
|
||||
case <-timer.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
@@ -174,6 +194,35 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
backoff = nextBackoff(backoff, rng)
|
||||
timer.Reset(backoff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
// rng: 随机数生成器(可为 nil,此时不添加抖动)
|
||||
// 返回值:下一次退避时间(100ms ~ 2s 之间)
|
||||
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
|
||||
// 指数退避:当前时间 * 1.5
|
||||
next := time.Duration(float64(current) * backoffMultiplier)
|
||||
if next > maxBackoff {
|
||||
next = maxBackoff
|
||||
}
|
||||
if rng == nil {
|
||||
return next
|
||||
}
|
||||
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
|
||||
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
|
||||
jitter := 0.8 + rng.Float64()*0.4
|
||||
jittered := time.Duration(float64(next) * jitter)
|
||||
if jittered < initialBackoff {
|
||||
return initialBackoff
|
||||
}
|
||||
if jittered > maxBackoff {
|
||||
return maxBackoff
|
||||
}
|
||||
return jittered
|
||||
}
|
||||
|
||||
@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusBadRequest, "Failed to read request body")
|
||||
return
|
||||
}
|
||||
@@ -191,7 +195,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
|
||||
@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
27
backend/internal/handler/request_body_limit.go
Normal file
27
backend/internal/handler/request_body_limit.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
|
||||
var maxErr *http.MaxBytesError
|
||||
if errors.As(err, &maxErr) {
|
||||
return maxErr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func formatBodyLimit(limit int64) string {
|
||||
const mb = 1024 * 1024
|
||||
if limit >= mb {
|
||||
return fmt.Sprintf("%dMB", limit/mb)
|
||||
}
|
||||
return fmt.Sprintf("%dB", limit)
|
||||
}
|
||||
|
||||
func buildBodyTooLargeMessage(limit int64) string {
|
||||
return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
|
||||
}
|
||||
45
backend/internal/handler/request_body_limit_test.go
Normal file
45
backend/internal/handler/request_body_limit_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRequestBodyLimitTooLarge(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := int64(16)
|
||||
router := gin.New()
|
||||
router.Use(middleware.RequestBodyLimit(limit))
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
_, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": buildBodyTooLargeMessage(maxErr.Limit),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "read_failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
payload := bytes.Repeat([]byte("a"), int(limit+1))
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
|
||||
require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
|
||||
}
|
||||
Reference in New Issue
Block a user