fix(网关): 防止连接池缓存失控

超限且无可淘汰条目时拒绝新建

规范化代理地址并更新失败时的访问时间

补充连接池上限与代理规范化测试
This commit is contained in:
yangjianbo
2025-12-31 12:01:31 +08:00
parent d1c9889609
commit 820bb16ca7
2 changed files with 109 additions and 33 deletions

View File

@@ -1,8 +1,10 @@
package repository
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
@@ -40,6 +42,8 @@ const (
defaultClientIdleTTLSeconds = 900
)
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
// poolSettings 连接池配置参数
// 封装 Transport 所需的各项连接池参数
type poolSettings struct {
@@ -116,13 +120,17 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
// 获取或创建对应的客户端,并标记请求占用
entry := s.acquireClient(proxyURL, accountID, accountConcurrency)
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
return nil, err
}
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
return nil, err
}
@@ -138,8 +146,8 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true)
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
}
// getOrCreateClient 获取或创建客户端
@@ -158,12 +166,14 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, false)
entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
return entry
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool) *upstreamClientEntry {
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
@@ -184,7 +194,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.RUnlock()
return entry
return entry, nil
}
s.mu.RUnlock()
@@ -197,11 +207,22 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.Unlock()
return entry
return entry, nil
}
s.removeClientLocked(cacheKey, entry)
}
// 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建
if enforceLimit && s.maxUpstreamClients() > 0 {
s.evictIdleLocked(now)
if len(s.clients) >= s.maxUpstreamClients() {
if !s.evictOldestIdleLocked() {
s.mu.Unlock()
return nil, errUpstreamClientLimitReached
}
}
}
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
@@ -220,7 +241,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
s.evictIdleLocked(now)
s.evictOverLimitLocked()
s.mu.Unlock()
return entry
return entry, nil
}
// shouldReuseEntry 判断缓存条目是否可复用
@@ -277,39 +298,50 @@ func (s *httpUpstreamService) evictIdleLocked(now time.Time) {
}
}
// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁)
func (s *httpUpstreamService) evictOldestIdleLocked() bool {
var (
oldestKey string
oldestEntry *upstreamClientEntry
oldestTime int64
)
// 查找最久未使用且无活跃请求的客户端
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
lastUsed := atomic.LoadInt64(&entry.lastUsed)
if oldestEntry == nil || lastUsed < oldestTime {
oldestKey = key
oldestEntry = entry
oldestTime = lastUsed
}
}
// 所有客户端都有活跃请求,无法淘汰
if oldestEntry == nil {
return false
}
s.removeClientLocked(oldestKey, oldestEntry)
return true
}
// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁)
// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端
func (s *httpUpstreamService) evictOverLimitLocked() {
func (s *httpUpstreamService) evictOverLimitLocked() bool {
maxClients := s.maxUpstreamClients()
if maxClients <= 0 {
return
return false
}
evicted := false
// 循环淘汰直到满足数量限制
for len(s.clients) > maxClients {
var (
oldestKey string
oldestEntry *upstreamClientEntry
oldestTime int64
)
// 查找最久未使用且无活跃请求的客户端
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
lastUsed := atomic.LoadInt64(&entry.lastUsed)
if oldestEntry == nil || lastUsed < oldestTime {
oldestKey = key
oldestEntry = entry
oldestTime = lastUsed
}
if !s.evictOldestIdleLocked() {
return evicted
}
// 所有客户端都有活跃请求,无法淘汰
if oldestEntry == nil {
return
}
s.removeClientLocked(oldestKey, oldestEntry)
evicted = true
}
return evicted
}
// getIsolationMode 获取连接池隔离模式
@@ -443,7 +475,26 @@ func normalizeProxyURL(raw string) (string, *url.URL) {
if err != nil {
return directProxyKey, nil
}
return proxyURL, parsed
parsed.Scheme = strings.ToLower(parsed.Scheme)
parsed.Host = strings.ToLower(parsed.Host)
parsed.Path = ""
parsed.RawPath = ""
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.ForceQuery = false
if hostname := parsed.Hostname(); hostname != "" {
port := parsed.Port()
if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") {
port = ""
}
hostname = strings.ToLower(hostname)
if port != "" {
parsed.Host = net.JoinHostPort(hostname, port)
} else {
parsed.Host = hostname
}
}
return parsed.String(), parsed
}
// defaultPoolSettings 获取默认连接池配置

View File

@@ -64,6 +64,31 @@ func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect()
require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
key1, _ := normalizeProxyURL("http://proxy.local:8080")
key2, _ := normalizeProxyURL("http://proxy.local:8080/")
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
// TestAcquireClient_OverLimitReturnsError 测试连接池缓存上限保护
// 验证超限且无可淘汰条目时返回错误
func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 1,
}
svc := s.newService()
entry1, err := svc.acquireClient("http://proxy-a:8080", 1, 1)
require.NoError(s.T(), err, "expected first acquire to succeed")
require.NotNil(s.T(), entry1, "expected entry")
entry2, err := svc.acquireClient("http://proxy-b:8080", 2, 1)
require.Error(s.T(), err, "expected error when cache limit reached")
require.Nil(s.T(), entry2, "expected nil entry when cache limit reached")
}
// TestDo_WithoutProxy_GoesDirect 测试无代理时直连
// 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {