Files
sub2api/backend/internal/service/openai_ws_pool_test.go
2026-02-28 15:01:20 +08:00

1710 lines
54 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIWSConnPool_CleanupStaleAndTrimIdle seeds one account pool with a
// connection created and last used two hours ago plus two idle connections
// (last used 10 minutes and 1 minute ago). With MaxIdlePerAccount=1 it expects
// a cleanup pass to leave only the most recently used idle connection.
func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) {
	const accountID = int64(10)

	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	pool := newOpenAIWSConnPool(cfg)
	ap := pool.getOrCreateAccountPool(accountID)

	// Both the creation and last-use stamps of the stale connection are pushed
	// two hours into the past.
	stale := newOpenAIWSConn("stale", accountID, nil, nil)
	stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
	stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())

	olderIdle := newOpenAIWSConn("idle_old", accountID, nil, nil)
	olderIdle.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())

	newerIdle := newOpenAIWSConn("idle_new", accountID, nil, nil)
	newerIdle.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())

	ap.conns[stale.id], ap.conns[olderIdle.id], ap.conns[newerIdle.id] = stale, olderIdle, newerIdle

	// NOTE(review): the "Locked" suffix suggests the caller should hold ap.mu;
	// this test calls it without the lock — confirm that is intended.
	closeOpenAIWSConns(pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()))

	require.Nil(t, ap.conns["stale"], "stale connection should be rotated")
	require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle")
	require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept")
}
// TestOpenAIWSConnPool_NextConnIDFormat checks that two consecutive IDs issued
// for account 42 share the "oa_ws_42_" prefix, differ from each other, and end
// in the sequence numbers 1 and 2.
func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) {
	const prefix = "oa_ws_42_"

	pool := newOpenAIWSConnPool(&config.Config{})
	first, second := pool.nextConnID(42), pool.nextConnID(42)

	for _, id := range []string{first, second} {
		require.True(t, strings.HasPrefix(id, prefix))
	}
	require.NotEqual(t, first, second)
	require.Equal(t, prefix+"1", first)
	require.Equal(t, prefix+"2", second)
}
func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) {
require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval)
require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker)
}
func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) {
conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil)
lease := &openAIWSConnLease{conn: conn}
require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0))
var nilLease *openAIWSConnLease
err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
}
// TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground writes
// through a probe connection with a zero timeout and checks that the probe
// recorded a non-nil write context.
//
// NOTE(review): despite the test name, the parent context passed here is
// context.Background(), not nil — confirm whether the nil path is still meant
// to be covered.
func TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) {
	probe := &openAIWSContextProbeConn{}
	conn := newOpenAIWSConn("ctx_probe", 1, probe, nil)

	err := conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0)

	require.NoError(t, err)
	require.NotNil(t, probe.lastWriteCtx)
}
// TestOpenAIWSConnPool_TargetConnCountAdaptive checks the target connection
// count under two load levels: with two leased connections that each have one
// waiter the target must reach the configured cap of 6, and once both are
// released with no waiters it must fall back to MinIdlePerAccount (1).
func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) {
	const accountID = int64(88)

	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 6
	wsCfg.MinIdlePerAccount = 1
	wsCfg.PoolTargetUtilization = 0.5
	pool := newOpenAIWSConnPool(cfg)
	ap := pool.getOrCreateAccountPool(accountID)

	first := newOpenAIWSConn("c1", accountID, nil, nil)
	second := newOpenAIWSConn("c2", accountID, nil, nil)

	// High load: both connections leased, one waiter queued on each.
	require.True(t, first.tryAcquire())
	require.True(t, second.tryAcquire())
	first.waiters.Store(1)
	second.waiters.Store(1)
	ap.conns[first.id] = first
	ap.conns[second.id] = second

	busyTarget := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
	require.Equal(t, 6, busyTarget, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限")

	// Low load: nothing leased, nobody waiting.
	first.release()
	second.release()
	first.waiters.Store(0)
	second.waiters.Store(0)

	quietTarget := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
	require.Equal(t, 1, quietTarget, "低负载时应缩回到最小空闲连接")
}
// TestOpenAIWSConnPool_TargetConnCountMinIdleZero expects a target of zero
// connections for an empty account pool when MinIdlePerAccount is 0.
func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 4
	wsCfg.MinIdlePerAccount = 0
	wsCfg.PoolTargetUtilization = 0.8

	pool := newOpenAIWSConnPool(cfg)
	emptyPool := pool.getOrCreateAccountPool(66)

	got := pool.targetConnCountLocked(emptyPool, pool.maxConnsHardCap())
	require.Equal(t, 0, got, "min_idle=0 且无负载时应允许缩容到 0")
}
// TestOpenAIWSConnPool_EnsureTargetIdleAsync triggers an asynchronous prewarm
// for an account with MinIdlePerAccount=2 and waits until the account pool
// holds at least two connections, then checks that the scale-up metric
// recorded at least two events.
func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
	cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
	pool := newOpenAIWSConnPool(cfg)
	// Shut the pool down when the test finishes so that whatever it still owns
	// (prewarmed connections, any background workers) does not outlive the test.
	t.Cleanup(func() { pool.Close() })
	pool.setClientDialerForTest(&openAIWSFakeDialer{})

	accountID := int64(77)
	account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// Record a previous acquire request on the account pool before triggering
	// the prewarm.
	ap := pool.getOrCreateAccountPool(accountID)
	ap.mu.Lock()
	ap.lastAcquire = &openAIWSAcquireRequest{
		Account: account,
		WSURL:   "wss://example.com/v1/responses",
	}
	ap.mu.Unlock()

	pool.ensureTargetIdleAsync(accountID)

	require.Eventually(t, func() bool {
		ap, ok := pool.getAccountPool(accountID)
		if !ok || ap == nil {
			return false
		}
		ap.mu.Lock()
		defer ap.mu.Unlock()
		return len(ap.conns) >= 2
	}, 2*time.Second, 20*time.Millisecond)

	metrics := pool.SnapshotMetrics()
	require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2))
}
// TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown checks the prewarm
// cooldown (PrewarmCooldownMS=500): after a first prewarm has dialed at least
// two connections, a second trigger issued right away must not dial again,
// while a trigger issued after the sleeps below (120ms + 450ms) must.
func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500
pool := newOpenAIWSConnPool(cfg)
// The counting dialer lets the test observe how many dials were attempted.
dialer := &openAIWSCountingDialer{}
pool.setClientDialerForTest(dialer)
accountID := int64(178)
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
ap := pool.getOrCreateAccountPool(accountID)
ap.mu.Lock()
ap.lastAcquire = &openAIWSAcquireRequest{
Account: account,
WSURL: "wss://example.com/v1/responses",
}
ap.mu.Unlock()
pool.ensureTargetIdleAsync(accountID)
// Wait for the first prewarm round to finish: two connections present and
// the prewarmActive flag cleared.
require.Eventually(t, func() bool {
ap, ok := pool.getAccountPool(accountID)
if !ok || ap == nil {
return false
}
ap.mu.Lock()
defer ap.mu.Unlock()
return len(ap.conns) >= 2 && !ap.prewarmActive
}, 2*time.Second, 20*time.Millisecond)
firstDialCount := dialer.DialCount()
require.GreaterOrEqual(t, firstDialCount, 2)
// Artificially open a gap so that a new prewarm round would be needed.
ap, ok := pool.getAccountPool(accountID)
require.True(t, ok)
require.NotNil(t, ap)
ap.mu.Lock()
// Remove exactly one (arbitrary) connection from the map.
for id := range ap.conns {
delete(ap.conns, id)
break
}
ap.mu.Unlock()
pool.ensureTargetIdleAsync(accountID)
// Still inside the 500ms cooldown window: the dial count must not change.
time.Sleep(120 * time.Millisecond)
require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热")
// 120ms + 450ms = 570ms have now elapsed since the second trigger, which
// exceeds the configured cooldown.
time.Sleep(450 * time.Millisecond)
pool.ensureTargetIdleAsync(accountID)
require.Eventually(t, func() bool {
return dialer.DialCount() > firstDialCount
}, 2*time.Second, 20*time.Millisecond)
}
// TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress drives two prewarm
// rounds against a dialer that always fails, expects exactly two dial attempts
// in total, and then checks that a third trigger does not dial again.
func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 2
	wsCfg.MinIdlePerAccount = 1
	wsCfg.PoolTargetUtilization = 0.8
	wsCfg.DialTimeoutSeconds = 1
	wsCfg.PrewarmCooldownMS = 0
	pool := newOpenAIWSConnPool(cfg)
	failing := &openAIWSAlwaysFailDialer{}
	pool.setClientDialerForTest(failing)

	accountID := int64(279)
	account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	seeded := pool.getOrCreateAccountPool(accountID)
	seeded.mu.Lock()
	seeded.lastAcquire = &openAIWSAcquireRequest{
		Account: account,
		WSURL:   "wss://example.com/v1/responses",
	}
	seeded.mu.Unlock()

	// prewarmSettled reports whether the account pool exists and its
	// prewarmActive flag is cleared.
	prewarmSettled := func() bool {
		current, found := pool.getAccountPool(accountID)
		if !found || current == nil {
			return false
		}
		current.mu.Lock()
		defer current.mu.Unlock()
		return !current.prewarmActive
	}

	// Two rounds, each waited to completion.
	for round := 0; round < 2; round++ {
		pool.ensureTargetIdleAsync(accountID)
		require.Eventually(t, prewarmSettled, 2*time.Second, 20*time.Millisecond)
	}
	require.Equal(t, 2, failing.DialCount())

	// Once consecutive failures hit the threshold, further prewarm triggers
	// should be suppressed and no more dialing should happen.
	pool.ensureTargetIdleAsync(accountID)
	time.Sleep(120 * time.Millisecond)
	require.Equal(t, 2, failing.DialCount())
}
// TestOpenAIWSConnPool_AcquireQueueWaitMetrics makes Acquire queue behind the
// only connection of an account (released after 60ms by a helper goroutine)
// and checks the lease's reported wait time as well as the pool's queue-wait
// and conn-pick metrics.
func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) {
	const wsURL = "wss://example.com/v1/responses"

	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 1
	wsCfg.MinIdlePerAccount = 0
	wsCfg.QueueLimitPerConn = 4
	pool := newOpenAIWSConnPool(cfg)

	accountID := int64(99)
	account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// Lease the connection up front so the upcoming Acquire has to queue.
	busy := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil)
	require.True(t, busy.tryAcquire())

	ap := pool.ensureAccountPoolLocked(accountID)
	ap.mu.Lock()
	ap.conns[busy.id] = busy
	ap.lastAcquire = &openAIWSAcquireRequest{
		Account: account,
		WSURL:   wsURL,
	}
	ap.mu.Unlock()

	// Free the connection after a short delay while Acquire is waiting.
	go func() {
		time.Sleep(60 * time.Millisecond)
		busy.release()
	}()

	lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
		Account: account,
		WSURL:   wsURL,
	})
	require.NoError(t, err)
	require.NotNil(t, lease)
	require.True(t, lease.Reused())
	require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond)
	lease.Release()

	snapshot := pool.SnapshotMetrics()
	require.GreaterOrEqual(t, snapshot.AcquireQueueWaitTotal, int64(1))
	require.Greater(t, snapshot.AcquireQueueWaitMsTotal, int64(0))
	require.GreaterOrEqual(t, snapshot.ConnPickTotal, int64(1))
}
// TestOpenAIWSConnPool_ForceNewConnSkipsReuse acquires and releases a lease,
// then acquires again with ForceNewConn=true and expects the counting dialer
// to have been used twice.
func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 2
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 2
	pool := newOpenAIWSConnPool(cfg)
	counter := &openAIWSCountingDialer{}
	pool.setClientDialerForTest(counter)

	req := openAIWSAcquireRequest{
		Account: &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
		WSURL:   "wss://example.com/v1/responses",
	}

	// First acquire: plain request, released immediately afterwards.
	firstLease, err := pool.Acquire(context.Background(), req)
	require.NoError(t, err)
	require.NotNil(t, firstLease)
	firstLease.Release()

	// Second acquire: same request, but demanding a brand-new connection.
	req.ForceNewConn = true
	secondLease, err := pool.Acquire(context.Background(), req)
	require.NoError(t, err)
	require.NotNil(t, secondLease)
	secondLease.Release()

	require.Equal(t, 2, counter.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接")
}
// TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable expects Acquire
// with ForcePreferredConn=true to fail with
// errOpenAIWSPreferredConnUnavailable both when no preferred ID is given and
// when the preferred ID is not in the pool, even though another connection is
// registered for the account.
func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 2
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 2
	pool := newOpenAIWSConnPool(cfg)

	account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	ap := pool.getOrCreateAccountPool(account.ID)
	bystander := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil)
	ap.mu.Lock()
	ap.conns[bystander.id] = bystander
	ap.mu.Unlock()

	// First an empty preferred ID, then one that does not exist in the pool.
	for _, preferredID := range []string{"", "missing_conn"} {
		_, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
			Account:            account,
			WSURL:              "wss://example.com/v1/responses",
			PreferredConnID:    preferredID,
			ForcePreferredConn: true,
		})
		require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
	}
}
// TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly leases
// the preferred connection up front, leaves a second connection idle, and
// expects a strict-preferred Acquire to wait for the preferred one (released
// after 60ms) without touching the idle one.
func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 2
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 2
	wsCfg.QueueLimitPerConn = 4
	pool := newOpenAIWSConnPool(cfg)

	account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	ap := pool.getOrCreateAccountPool(account.ID)
	wanted := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil)
	spare := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil)
	require.True(t, wanted.tryAcquire(), "先占用 preferred 连接,触发排队获取")

	ap.mu.Lock()
	ap.conns[wanted.id] = wanted
	ap.conns[spare.id] = spare
	ap.lastCleanupAt = time.Now()
	ap.mu.Unlock()

	// Hand the preferred connection back while Acquire is queued on it.
	go func() {
		time.Sleep(60 * time.Millisecond)
		wanted.release()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	req := openAIWSAcquireRequest{
		Account:            account,
		WSURL:              "wss://example.com/v1/responses",
		PreferredConnID:    wanted.id,
		ForcePreferredConn: true,
	}
	lease, err := pool.Acquire(ctx, req)
	require.NoError(t, err)
	require.NotNil(t, lease)
	require.Equal(t, wanted.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移")
	require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond)
	lease.Release()

	// The spare connection must still be free to lease.
	require.True(t, spare.tryAcquire(), "other 连接不应被严格模式抢占")
	spare.release()
}
// TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull covers two
// strict-preferred outcomes with QueueLimitPerConn=1: an idle preferred
// connection is returned directly, and a leased preferred connection that
// already has one waiter yields errOpenAIWSConnQueueFull.
func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 2
	wsCfg.MinIdlePerAccount = 0
	wsCfg.MaxIdlePerAccount = 2
	wsCfg.QueueLimitPerConn = 1
	pool := newOpenAIWSConnPool(cfg)

	account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	ap := pool.getOrCreateAccountPool(account.ID)
	wanted := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
	spare := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
	ap.mu.Lock()
	ap.conns[wanted.id] = wanted
	ap.conns[spare.id] = spare
	ap.lastCleanupAt = time.Now()
	ap.mu.Unlock()

	strictReq := openAIWSAcquireRequest{
		Account:            account,
		WSURL:              "wss://example.com/v1/responses",
		PreferredConnID:    wanted.id,
		ForcePreferredConn: true,
	}

	// Case 1: the preferred connection is idle.
	lease, err := pool.Acquire(context.Background(), strictReq)
	require.NoError(t, err)
	require.Equal(t, wanted.id, lease.ConnID(), "preferred 空闲时应直接命中")
	lease.Release()

	// Case 2: the preferred connection is leased and its single queue slot is
	// already taken.
	require.True(t, wanted.tryAcquire())
	wanted.waiters.Store(1)
	_, err = pool.Acquire(context.Background(), strictReq)
	require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移")

	wanted.waiters.Store(0)
	wanted.release()
}
// TestOpenAIWSConnPool_CleanupSkipsPinnedConn runs cleanup with
// MaxIdlePerAccount=0 and expects a pinned connection to survive while an
// unpinned idle one is removed; after unpinning, a second cleanup must remove
// the formerly pinned connection too.
func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
	pool := newOpenAIWSConnPool(cfg)

	accountID := int64(126)
	ap := pool.getOrCreateAccountPool(accountID)
	pinned := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil)
	idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
	ap.mu.Lock()
	ap.conns[pinned.id] = pinned
	ap.conns[idle.id] = idle
	ap.mu.Unlock()

	// tracked reports, under the account lock, whether connID is still in the map.
	tracked := func(connID string) bool {
		ap.mu.Lock()
		defer ap.mu.Unlock()
		_, present := ap.conns[connID]
		return present
	}
	// sweep runs one cleanup pass and closes whatever it evicted.
	sweep := func() {
		closeOpenAIWSConns(pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()))
	}

	require.True(t, pool.PinConn(accountID, pinned.id))
	sweep()
	pinnedKept, idleKept := tracked(pinned.id), tracked(idle.id)
	require.True(t, pinnedKept, "被 active ingress 绑定的连接不应被 cleanup 回收")
	require.False(t, idleKept, "非绑定的空闲连接应被回收")

	pool.UnpinConn(accountID, pinned.id)
	sweep()
	require.False(t, tracked(pinned.id), "解绑后连接应可被正常回收")
}
// TestOpenAIWSConnPool_PinUnpinConnBranches exercises PinConn/UnpinConn guard
// paths (nil pool, zero or unknown account, empty or unknown connection ID) and
// checks that pins on one connection are reference-counted: two pins give a
// count of 2, one unpin gives 1, and a second unpin removes the entry.
func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) {
	var nilPool *openAIWSConnPool
	require.False(t, nilPool.PinConn(1, "x"))
	nilPool.UnpinConn(1, "x")

	pool := newOpenAIWSConnPool(&config.Config{})
	accountID := int64(128)
	ap := &openAIWSAccountPool{
		conns: map[string]*openAIWSConn{},
	}
	pool.accounts.Store(accountID, ap)

	// Every one of these pin attempts must be rejected.
	rejected := []struct {
		account int64
		connID  string
	}{
		{0, "x"},
		{999, "x"},
		{accountID, ""},
		{accountID, "missing"},
	}
	for _, tc := range rejected {
		require.False(t, pool.PinConn(tc.account, tc.connID))
	}

	conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil)
	ap.mu.Lock()
	ap.conns[conn.id] = conn
	ap.mu.Unlock()

	// pinState reads the pin entry for conn under the account lock.
	pinState := func() (int, bool) {
		ap.mu.Lock()
		defer ap.mu.Unlock()
		count, present := ap.pinnedConns[conn.id]
		return count, present
	}

	require.True(t, pool.PinConn(accountID, conn.id))
	require.True(t, pool.PinConn(accountID, conn.id))
	count, _ := pinState()
	require.Equal(t, 2, count)

	pool.UnpinConn(accountID, conn.id)
	count, _ = pinState()
	require.Equal(t, 1, count)

	pool.UnpinConn(accountID, conn.id)
	_, stillPinned := pinState()
	require.False(t, stillPinned)

	// Extra unpins and invalid arguments are simply invoked; nothing is
	// asserted about them.
	pool.UnpinConn(accountID, conn.id)
	pool.UnpinConn(accountID, "")
	pool.UnpinConn(0, conn.id)
	pool.UnpinConn(999, conn.id)
}
// TestOpenAIWSConnPool_EffectiveMaxConnsByAccount checks the per-account
// connection cap with dynamic sizing enabled, a hard cap of 8, an OAuth factor
// of 1.0 and an API-key factor of 0.6, across several account shapes including
// a nil account.
func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 8
	wsCfg.DynamicMaxConnsByAccountConcurrencyEnabled = true
	wsCfg.OAuthMaxConnsFactor = 1.0
	wsCfg.APIKeyMaxConnsFactor = 0.6
	pool := newOpenAIWSConnPool(cfg)

	cases := []struct {
		account *Account
		want    int
		msg     []any
	}{
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10}, 8, []any{"应受全局硬上限约束"}},
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3}, 3, nil},
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10}, 6, []any{"API Key 应按系数缩放"}},
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1}, 1, []any{"最小值应保持为 1"}},
		{&Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}, 8, []any{"无限并发应回退到全局硬上限"}},
		{nil, 8, []any{"缺少账号上下文应回退到全局硬上限"}},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, pool.effectiveMaxConnsByAccount(tc.account), tc.msg...)
	}
}
// TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap expects the
// configured hard cap (8) for an account with Concurrency=2 when dynamic sizing
// is switched off.
func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.MaxConnsPerAccount = 8
	wsCfg.DynamicMaxConnsByAccountConcurrencyEnabled = false
	wsCfg.OAuthMaxConnsFactor = 1.0
	wsCfg.APIKeyMaxConnsFactor = 1.0
	pool := newOpenAIWSConnPool(cfg)

	lowConcurrency := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2}
	got := pool.effectiveMaxConnsByAccount(lowConcurrency)
	require.Equal(t, 8, got, "关闭动态模式后应保持旧行为")
}
// TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency
// checks that with ModeRouterV2Enabled the cap equals the account's
// Concurrency (20 here, above the configured 8) and is 0 for an account whose
// Concurrency is 0.
func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) {
	cfg := &config.Config{}
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.ModeRouterV2Enabled = true
	wsCfg.MaxConnsPerAccount = 8
	wsCfg.DynamicMaxConnsByAccountConcurrencyEnabled = true
	wsCfg.OAuthMaxConnsFactor = 0.3
	wsCfg.APIKeyMaxConnsFactor = 0.6
	pool := newOpenAIWSConnPool(cfg)

	busyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20}
	gotBusy := pool.effectiveMaxConnsByAccount(busyAccount)
	require.Equal(t, 20, gotBusy, "v2 路径应直接使用账号并发数作为池上限")

	zeroAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0}
	gotZero := pool.effectiveMaxConnsByAccount(zeroAccount)
	require.Equal(t, 0, gotZero, "并发数<=0 时应不可调度")
}
// TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero expects
// Acquire to return errOpenAIWSConnQueueFull for an account with Concurrency=0
// when ModeRouterV2Enabled is set.
func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
	pool := newOpenAIWSConnPool(cfg)

	req := openAIWSAcquireRequest{
		Account: &Account{ID: 901, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0},
		WSURL:   "wss://example.com/v1/responses",
	}
	_, acquireErr := pool.Acquire(context.Background(), req)
	require.ErrorIs(t, acquireErr, errOpenAIWSConnQueueFull)
}
func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) {
conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil)
lease := &openAIWSConnLease{conn: conn}
_, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond)
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond)
require.NoError(t, err)
require.Contains(t, string(payload), "response.completed")
parentCtx, cancel := context.WithCancel(context.Background())
cancel()
_, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}
func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) {
conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil)
lease := &openAIWSConnLease{conn: conn}
parentCtx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(20 * time.Millisecond)
cancel()
}()
start := time.Now()
err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute)
elapsed := time.Since(start)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
require.Less(t, elapsed, 200*time.Millisecond)
}
func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) {
conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil)
lease := &openAIWSConnLease{conn: conn}
require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
var nilLease *openAIWSConnLease
err := nilLease.PingWithTimeout(50 * time.Millisecond)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
}
// TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently starts a read that takes
// about 120ms and, while it is in flight, issues a ping that must finish in
// under 80ms; the read itself must also finish without error.
func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) {
	conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil)

	readResult := make(chan error, 1)
	go func() {
		_, readErr := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond)
		readResult <- readErr
	}()

	// Give the reader a head start so it is already holding readMu.
	time.Sleep(20 * time.Millisecond)

	begin := time.Now()
	pingErr := conn.pingWithTimeout(50 * time.Millisecond)
	took := time.Since(begin)

	require.NoError(t, pingErr)
	require.Less(t, took, 80*time.Millisecond, "写路径不应被读锁长期阻塞")
	require.NoError(t, <-readResult)
}
// TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn registers an idle
// connection whose ping always fails and expects one background ping sweep to
// remove it from the account pool.
func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	pool := newOpenAIWSConnPool(cfg)

	accountID := int64(301)
	ap := pool.getOrCreateAccountPool(accountID)
	dead := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil)
	ap.mu.Lock()
	ap.conns[dead.id] = dead
	ap.mu.Unlock()

	pool.runBackgroundPingSweep()

	ap.mu.Lock()
	_, stillTracked := ap.conns[dead.id]
	ap.mu.Unlock()
	require.False(t, stillTracked, "后台 ping 失败的空闲连接应被回收")
}
// TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire registers a
// connection created and last used two hours ago and expects a background
// cleanup sweep to remove it without any Acquire call being made.
func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
	pool := newOpenAIWSConnPool(cfg)

	accountID := int64(302)
	ap := pool.getOrCreateAccountPool(accountID)

	expired := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil)
	expired.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
	expired.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())

	ap.mu.Lock()
	ap.conns[expired.id] = expired
	ap.mu.Unlock()

	pool.runBackgroundCleanupSweep(time.Now())

	ap.mu.Lock()
	_, stillTracked := ap.conns[expired.id]
	ap.mu.Unlock()
	require.False(t, stillTracked, "后台清理应在无新 acquire 时也回收过期连接")
}
// TestOpenAIWSConnPool_BackgroundWorkerGuardBranches checks that the background
// worker entry points do not panic on a nil pool or on a pool without a stop
// channel, and that both worker loops return within 500ms once the stop
// channel is closed.
func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) {
	var nilPool *openAIWSConnPool
	require.NotPanics(t, func() {
		nilPool.startBackgroundWorkers()
		nilPool.runBackgroundPingWorker()
		nilPool.runBackgroundPingSweep()
		_ = nilPool.snapshotIdleConnsForPing()
		nilPool.runBackgroundCleanupWorker()
		nilPool.runBackgroundCleanupSweep(time.Now())
	})

	withoutStopCh := &openAIWSConnPool{}
	require.NotPanics(t, func() {
		withoutStopCh.startBackgroundWorkers()
	})

	// assertStopsOnSignal runs worker on a fresh pool in its own goroutine,
	// closes the stop channel, and fails the test with failMsg if the worker
	// has not returned within 500ms.
	assertStopsOnSignal := func(worker func(*openAIWSConnPool), failMsg string) {
		p := &openAIWSConnPool{workerStopCh: make(chan struct{})}
		finished := make(chan struct{})
		go func() {
			worker(p)
			close(finished)
		}()
		close(p.workerStopCh)
		select {
		case <-finished:
		case <-time.After(500 * time.Millisecond):
			t.Fatal(failMsg)
		}
	}

	assertStopsOnSignal((*openAIWSConnPool).runBackgroundPingWorker, "runBackgroundPingWorker 未在 stop 信号后退出")
	assertStopsOnSignal((*openAIWSConnPool).runBackgroundCleanupWorker, "runBackgroundCleanupWorker 未在 stop 信号后退出")
}
// TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries fills the
// accounts map with a wrongly typed key and a wrongly typed value, plus an
// account pool containing a nil entry, a leased connection, a connection with
// a waiter, and one idle connection. Only the idle one may be returned as a
// ping candidate.
func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) {
	pool := &openAIWSConnPool{}
	pool.accounts.Store("invalid-key", &openAIWSAccountPool{})
	pool.accounts.Store(int64(123), "invalid-value")

	accountID := int64(123)
	ap := &openAIWSAccountPool{
		conns: make(map[string]*openAIWSConn),
	}
	ap.conns["nil_conn"] = nil

	inUse := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
	require.True(t, inUse.tryAcquire())
	ap.conns[inUse.id] = inUse

	queued := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil)
	queued.waiters.Store(1)
	ap.conns[queued.id] = queued

	free := newOpenAIWSConn("idle", accountID, &openAIWSFakeConn{}, nil)
	ap.conns[free.id] = free

	// Storing under the same key replaces the invalid value stored earlier.
	pool.accounts.Store(accountID, ap)

	got := pool.snapshotIdleConnsForPing()
	require.Len(t, got, 1)
	require.Equal(t, free.id, got[0].conn.id)
}
// TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap
// runs a background cleanup sweep over an accounts map that holds an invalid
// key/value pair and an account pool containing a nil entry plus an expired
// connection. The sweep must not panic, must drop both entries, and must
// record the sweep time in lastCleanupAt.
func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
	cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
	pool := &openAIWSConnPool{cfg: cfg}
	pool.accounts.Store("bad-key", "bad-value")

	accountID := int64(2026)
	ap := &openAIWSAccountPool{
		conns: make(map[string]*openAIWSConn),
	}
	ap.conns["nil_conn"] = nil
	stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil)
	stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
	stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
	ap.conns[stale.id] = stale
	ap.lastAcquire = &openAIWSAcquireRequest{
		Account: &Account{
			ID:          accountID,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Concurrency: 1,
		},
	}
	pool.accounts.Store(accountID, ap)

	now := time.Now()
	require.NotPanics(t, func() {
		pool.runBackgroundCleanupSweep(now)
	})

	ap.mu.Lock()
	_, nilConnExists := ap.conns["nil_conn"]
	_, exists := ap.conns[stale.id]
	lastCleanupAt := ap.lastCleanupAt
	ap.mu.Unlock()

	require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目")
	require.False(t, exists, "后台清理应清理过期连接")
	// Compare instants rather than time.Time struct representation: a zero
	// tolerance still demands the exact same moment, but does not depend on
	// the monotonic reading or location pointer being carried over unchanged.
	require.WithinDuration(t, now, lastCleanupAt, 0)
}
// TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured expects a queue
// limit of 256 for a nil pool and for an unconfigured pool, and the configured
// value once QueueLimitPerConn is set.
func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) {
	const defaultLimit = 256

	var nilPool *openAIWSConnPool
	require.Equal(t, defaultLimit, nilPool.queueLimitPerConn())

	pool := &openAIWSConnPool{cfg: &config.Config{}}
	require.Equal(t, defaultLimit, pool.queueLimitPerConn())

	pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9
	require.Equal(t, 9, pool.queueLimitPerConn())
}
func TestOpenAIWSConnPool_Close(t *testing.T) {
cfg := &config.Config{}
pool := newOpenAIWSConnPool(cfg)
// Close 应该可以安全调用
pool.Close()
// workerStopCh 应已关闭
select {
case <-pool.workerStopCh:
// 预期channel 已关闭
default:
t.Fatal("Close 后 workerStopCh 应已关闭")
}
// 多次调用 Close 不应 panic
pool.Close()
// nil pool 调用 Close 不应 panic
var nilPool *openAIWSConnPool
nilPool.Close()
}
// TestOpenAIWSDialError_ErrorAndUnwrap checks the error text and Unwrap result
// of openAIWSDialError with a status code, without one, and for a nil
// receiver.
func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) {
	cause := errors.New("boom")

	withStatus := &openAIWSDialError{StatusCode: 502, Err: cause}
	require.Contains(t, withStatus.Error(), "status=502")
	require.ErrorIs(t, withStatus.Unwrap(), cause)

	withoutStatus := &openAIWSDialError{Err: cause}
	require.Contains(t, withoutStatus.Error(), "boom")

	// A nil receiver yields an empty message and a nil wrapped error.
	var missing *openAIWSDialError
	require.Empty(t, missing.Error())
	require.NoError(t, missing.Unwrap())
}
// TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats walks through the lease
// and connection read/write helpers against a fake connection, checks the
// handshake-header lookup and the timestamp/age accessors, and finally checks
// that the nil-connection guards return errOpenAIWSConnClosed.
func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) {
	handshake := http.Header{
		"X-Test": []string{" value "},
	}
	conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, handshake)
	lease := &openAIWSConnLease{conn: conn}

	require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"}))

	// Each read helper is expected to yield a "response.completed" payload.
	msg, err := lease.ReadMessage(100 * time.Millisecond)
	require.NoError(t, err)
	require.Contains(t, string(msg), "response.completed")

	msg, err = lease.ReadMessageContext(context.Background())
	require.NoError(t, err)
	require.Contains(t, string(msg), "response.completed")

	msg, err = conn.readMessageWithTimeout(100 * time.Millisecond)
	require.NoError(t, err)
	require.Contains(t, string(msg), "response.completed")

	// The lookup is expected to return "value" even though both the queried
	// name and the stored value carry surrounding spaces.
	require.Equal(t, "value", conn.handshakeHeader(" X-Test "))
	require.NotZero(t, conn.createdAt())
	require.NotZero(t, conn.lastUsedAt())
	require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0))
	require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0))
	require.False(t, conn.isLeased())

	// Cover the plain-context read path.
	_, err = conn.readMessage(context.Background())
	require.NoError(t, err)

	// Cover the nil-receiver guard branches.
	var nilConn *openAIWSConn
	writeErr := nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second)
	require.ErrorIs(t, writeErr, errOpenAIWSConnClosed)
	_, err = nilConn.readMessageWithTimeout(10 * time.Millisecond)
	require.ErrorIs(t, err, errOpenAIWSConnClosed)
	_, err = nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
	require.ErrorIs(t, err, errOpenAIWSConnClosed)
}
// TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad builds an account pool
// with two idle connections (last used 10 minutes and 1 minute ago) and one
// leased connection with two waiters. It expects the oldest idle connection to
// be picked, and the load helpers to report 1 in-flight, 2 waiters and 3
// connections (all zeros for account 0).
func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) {
	pool := &openAIWSConnPool{}
	accountID := int64(404)
	ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}

	older := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil)
	older.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
	newer := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil)
	newer.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
	inUse := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
	require.True(t, inUse.tryAcquire())
	inUse.waiters.Store(2)

	ap.conns[older.id], ap.conns[newer.id], ap.conns[inUse.id] = older, newer, inUse

	picked := pool.pickOldestIdleConnLocked(ap)
	require.NotNil(t, picked)
	require.Equal(t, older.id, picked.id)

	// Load computed directly from the account pool.
	inflight, waiters := accountPoolLoadLocked(ap)
	require.Equal(t, 1, inflight)
	require.Equal(t, 2, waiters)

	// Load looked up through the pool by account ID.
	pool.accounts.Store(accountID, ap)
	gotInflight, gotWaiters, gotConns := pool.AccountPoolLoad(accountID)
	require.Equal(t, 1, gotInflight)
	require.Equal(t, 2, gotWaiters)
	require.Equal(t, 3, gotConns)

	// Account ID 0 reports no load at all.
	noInflight, noWaiters, noConns := pool.AccountPoolLoad(0)
	require.Equal(t, 0, noInflight)
	require.Equal(t, 0, noWaiters)
	require.Equal(t, 0, noConns)
}
// TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel uses a pool with
// no stop channel and one outstanding workerWg entry. Close must still be
// blocked after 30ms while the entry is pending and must return within one
// second once it is done.
func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) {
	pool := &openAIWSConnPool{}

	// Simulated worker: holds the WaitGroup until unblock is closed.
	unblock := make(chan struct{})
	pool.workerWg.Add(1)
	go func() {
		defer pool.workerWg.Done()
		<-unblock
	}()

	closeReturned := make(chan struct{})
	go func() {
		pool.Close()
		close(closeReturned)
	}()

	// While the worker is pending, Close must not have returned.
	select {
	case <-closeReturned:
		t.Fatal("Close 不应在 WaitGroup 未完成时提前返回")
	case <-time.After(30 * time.Millisecond):
	}

	// Once the worker finishes, Close must return promptly.
	close(unblock)
	select {
	case <-closeReturned:
	case <-time.After(time.Second):
		t.Fatal("Close 未等待 workerWg 完成")
	}
}
// TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections calls Close on a pool
// holding one idle and one leased connection (plus an invalid accounts entry)
// and expects only the idle connection's closedCh to be closed. The leased
// connection is then released and Close is called once more.
func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) {
	pool := &openAIWSConnPool{
		workerStopCh: make(chan struct{}),
	}
	accountID := int64(606)
	ap := &openAIWSAccountPool{
		conns: map[string]*openAIWSConn{},
	}

	free := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
	inUse := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil)
	require.True(t, inUse.tryAcquire())
	ap.conns[free.id], ap.conns[inUse.id] = free, inUse

	pool.accounts.Store(accountID, ap)
	pool.accounts.Store("invalid-key", "invalid-value")

	pool.Close()

	// The idle connection must have been closed.
	select {
	case <-free.closedCh:
	default:
		t.Fatal("空闲连接应在 Close 时被关闭")
	}

	// The leased connection must still be open.
	select {
	case <-inUse.closedCh:
		t.Fatal("已租赁连接不应在 Close 时被关闭")
	default:
	}

	inUse.release()
	pool.Close()
}
// TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit fills one
// account with connections whose Ping blocks, runs a background ping sweep,
// and asserts that the number of simultaneously running pings reaches, but
// never exceeds, pingLimit.
func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) {
	const (
		connCount = 25 // more connections than the expected concurrency cap
		pingLimit = 10 // concurrency the sweep is expected to saturate at
	)

	cfg := &config.Config{}
	pool := newOpenAIWSConnPool(cfg)
	accountID := int64(505)
	ap := pool.getOrCreateAccountPool(accountID)

	var current atomic.Int32
	var maxConcurrent atomic.Int32
	release := make(chan struct{})

	// The release channel is closed on the happy path and from the cleanup
	// hook, so guard it against a double close.
	var releaseOnce sync.Once
	releasePings := func() { releaseOnce.Do(func() { close(release) }) }
	// If an assertion aborts the test early, still unblock every pending
	// Ping so the sweep goroutine does not outlive the test.
	t.Cleanup(releasePings)

	for i := 0; i < connCount; i++ {
		conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{
			current:       &current,
			maxConcurrent: &maxConcurrent,
			release:       release,
		}, nil)
		ap.mu.Lock()
		ap.conns[conn.id] = conn
		ap.mu.Unlock()
	}

	done := make(chan struct{})
	go func() {
		pool.runBackgroundPingSweep()
		close(done)
	}()

	// Wait until the sweep has saturated its concurrency budget.
	require.Eventually(t, func() bool {
		return maxConcurrent.Load() >= pingLimit
	}, time.Second, 10*time.Millisecond)

	releasePings()
	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("runBackgroundPingSweep 未在释放后完成")
	}
	require.LessOrEqual(t, maxConcurrent.Load(), int32(pingLimit))
}
// TestOpenAIWSConnLease_BasicGetterBranches exercises the lease accessors on
// a nil lease (getters return zero values, mutators are callable) and on a
// populated lease.
func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) {
	// Nil receiver: every accessor and mutator must be safe to call.
	var missing *openAIWSConnLease
	require.Equal(t, "", missing.ConnID())
	require.Equal(t, time.Duration(0), missing.QueueWaitDuration())
	require.Equal(t, time.Duration(0), missing.ConnPickDuration())
	require.False(t, missing.Reused())
	require.Equal(t, "", missing.HandshakeHeader("x-test"))
	require.False(t, missing.IsPrewarmed())
	missing.MarkPrewarmed()
	missing.Release()

	// Populated lease: getters must reflect the values it was built with.
	const (
		wantQueueWait = 3 * time.Millisecond
		wantConnPick  = 4 * time.Millisecond
	)
	handshake := http.Header{"X-Test": []string{"ok"}}
	populated := &openAIWSConnLease{
		conn:      newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, handshake),
		queueWait: wantQueueWait,
		connPick:  wantConnPick,
		reused:    true,
	}
	require.Equal(t, "getter_conn", populated.ConnID())
	require.Equal(t, wantQueueWait, populated.QueueWaitDuration())
	require.Equal(t, wantConnPick, populated.ConnPickDuration())
	require.True(t, populated.Reused())
	require.Equal(t, "ok", populated.HandshakeHeader("x-test"))

	// The prewarmed flag starts false and flips after MarkPrewarmed.
	require.False(t, populated.IsPrewarmed())
	populated.MarkPrewarmed()
	require.True(t, populated.IsPrewarmed())
	populated.Release()
}
// TestOpenAIWSConnPool_UtilityBranches covers assorted small pool helpers:
// metric snapshots, the defaults returned for a nil pool, prewarm
// suppression, conn-pick metric guards, account-pool lookup and the
// health-check predicate.
func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) {
	var unset *openAIWSConnPool
	require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, unset.SnapshotMetrics())
	require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, unset.SnapshotTransportMetrics())

	p := &openAIWSConnPool{cfg: &config.Config{}}
	p.metrics.acquireTotal.Store(7)
	p.metrics.acquireReuseTotal.Store(3)
	snap := p.SnapshotMetrics()
	require.Equal(t, int64(7), snap.AcquireTotal)
	require.Equal(t, int64(3), snap.AcquireReuseTotal)

	// Path where the dialer does not expose transport metrics.
	p.clientDialer = &openAIWSFakeDialer{}
	require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, p.SnapshotTransportMetrics())
	p.setClientDialerForTest(nil)
	require.NotNil(t, p.clientDialer)

	// Values reported by the config accessors on a nil pool.
	require.Equal(t, 8, unset.maxConnsHardCap())
	require.False(t, unset.dynamicMaxConnsEnabled())
	require.Equal(t, 1.0, unset.maxConnsFactorByAccount(nil))
	require.Equal(t, 0, unset.minIdlePerAccount())
	require.Equal(t, 4, unset.maxIdlePerAccount())
	require.Equal(t, 256, unset.queueLimitPerConn())
	require.Equal(t, 0.7, unset.targetUtilization())
	require.Equal(t, time.Duration(0), unset.prewarmCooldown())
	require.Equal(t, 10*time.Second, unset.dialTimeout())

	// Branches of shouldSuppressPrewarmLocked.
	ts := time.Now()
	singleFail := &openAIWSAccountPool{prewarmFails: 1}
	require.False(t, p.shouldSuppressPrewarmLocked(singleFail, ts))
	noTimestamp := &openAIWSAccountPool{prewarmFails: 2}
	require.False(t, p.shouldSuppressPrewarmLocked(noTimestamp, ts))
	require.Equal(t, 0, noTimestamp.prewarmFails)
	staleFail := &openAIWSAccountPool{
		prewarmFails:  2,
		prewarmFailAt: ts.Add(-openAIWSPrewarmFailureWindow - time.Second),
	}
	require.False(t, p.shouldSuppressPrewarmLocked(staleFail, ts))
	freshFail := &openAIWSAccountPool{
		prewarmFails:  openAIWSPrewarmFailureSuppress,
		prewarmFailAt: ts,
	}
	require.True(t, p.shouldSuppressPrewarmLocked(freshFail, ts))

	// Guard branches of recordConnPickDuration.
	unset.recordConnPickDuration(10 * time.Millisecond)
	p.recordConnPickDuration(-10 * time.Millisecond)
	require.Equal(t, int64(1), p.metrics.connPickTotal.Load())

	// Account-pool read/write branches.
	require.Nil(t, unset.getOrCreateAccountPool(1))
	require.Nil(t, p.getOrCreateAccountPool(0))
	p.accounts.Store(int64(7), "invalid")
	created := p.getOrCreateAccountPool(7)
	require.NotNil(t, created)
	_, found := p.getAccountPool(0)
	require.False(t, found)
	_, found = p.getAccountPool(12345)
	require.False(t, found)
	p.accounts.Store(int64(8), "bad-type")
	_, found = p.getAccountPool(8)
	require.False(t, found)

	// Health-check predicate.
	require.False(t, p.shouldHealthCheckConn(nil))
	idleConn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil)
	longAgo := time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second)
	idleConn.lastUsedNano.Store(longAgo.UnixNano())
	require.True(t, p.shouldHealthCheckConn(idleConn))
}
// TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches checks that the
// time and lease helpers are safe on a nil connection, that the leased flag
// follows tryAcquire/release, and that a closed connection refuses leases.
func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) {
	// Nil receiver: helpers return zero values and mutators do not panic.
	var absent *openAIWSConn
	absent.touch()
	require.Equal(t, time.Time{}, absent.createdAt())
	require.Equal(t, time.Time{}, absent.lastUsedAt())
	require.Equal(t, time.Duration(0), absent.idleDuration(time.Now()))
	require.Equal(t, time.Duration(0), absent.age(time.Now()))
	require.False(t, absent.isLeased())
	require.False(t, absent.isPrewarmed())
	absent.markPrewarmed()

	// Lease state toggles with tryAcquire / release.
	c := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil)
	require.True(t, c.tryAcquire())
	require.True(t, c.isLeased())
	c.release()
	require.False(t, c.isLeased())

	// After close, neither acquisition path may succeed.
	c.close()
	require.False(t, c.tryAcquire())
	cancelledCtx, stop := context.WithCancel(context.Background())
	stop()
	acquireErr := c.acquire(cancelledCtx)
	require.Error(t, acquireErr)
}
// TestOpenAIWSConnLease_ReadWriteNilConnBranches checks that every read and
// write entry point of a lease with no underlying connection fails with
// errOpenAIWSConnClosed.
func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) {
	bg := context.Background()
	// A fresh payload per call, as the write helpers receive in production use.
	msg := func() map[string]any { return map[string]any{"k": "v"} }
	empty := &openAIWSConnLease{}

	// Writes.
	writeErr := empty.WriteJSON(msg(), time.Second)
	require.ErrorIs(t, writeErr, errOpenAIWSConnClosed)
	writeErr = empty.WriteJSONContext(bg, msg())
	require.ErrorIs(t, writeErr, errOpenAIWSConnClosed)

	// Reads.
	_, readErr := empty.ReadMessage(10 * time.Millisecond)
	require.ErrorIs(t, readErr, errOpenAIWSConnClosed)
	_, readErr = empty.ReadMessageContext(bg)
	require.ErrorIs(t, readErr, errOpenAIWSConnClosed)
	_, readErr = empty.ReadMessageWithContextTimeout(bg, 10*time.Millisecond)
	require.ErrorIs(t, readErr, errOpenAIWSConnClosed)
}
func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) {
conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil)
lease := &openAIWSConnLease{conn: conn}
require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
lease.Release()
lease.Release() // idempotent
require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
_, err := lease.ReadMessage(10 * time.Millisecond)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
_, err = lease.ReadMessageContext(context.Background())
require.ErrorIs(t, err, errOpenAIWSConnClosed)
_, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
require.ErrorIs(t, err, errOpenAIWSConnClosed)
require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed)
}
// TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction checks that calling
// MarkBroken on a lease that was already released leaves the connection in
// its account pool.
func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) {
	const acctID = int64(7)

	pooled := newOpenAIWSConn("released_markbroken", acctID, &openAIWSFakeConn{}, nil)
	acctPool := &openAIWSAccountPool{
		conns: map[string]*openAIWSConn{pooled.id: pooled},
	}
	p := &openAIWSConnPool{}
	p.accounts.Store(acctID, acctPool)

	stale := &openAIWSConnLease{pool: p, accountID: acctID, conn: pooled}
	stale.Release()
	stale.MarkBroken()

	acctPool.mu.Lock()
	_, stillPooled := acctPool.conns[pooled.id]
	acctPool.mu.Unlock()
	require.True(t, stillPooled, "released lease should not evict active pool connection")
}
// TestOpenAIWSConn_AdditionalGuardBranches sweeps the remaining guard paths of
// openAIWSConn: nil receiver, busy connection with a cancelled context,
// closed connection, connection with no websocket, zero timeouts, zero
// timestamps, and the clone/close helpers.
func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) {
	bg := context.Background()
	msg := func() map[string]any { return map[string]any{"k": "v"} }

	// Nil receiver.
	var absent *openAIWSConn
	require.False(t, absent.tryAcquire())
	require.ErrorIs(t, absent.acquire(bg), errOpenAIWSConnClosed)
	absent.release()
	absent.close()
	require.Equal(t, "", absent.handshakeHeader("x-test"))

	// Busy connection: acquire with an already-cancelled context.
	busy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil)
	require.True(t, busy.tryAcquire())
	cancelledCtx, stop := context.WithCancel(bg)
	stop()
	require.ErrorIs(t, busy.acquire(cancelledCtx), context.Canceled)
	busy.release()

	// Closed connection: every I/O helper reports errOpenAIWSConnClosed.
	shut := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil)
	shut.close()
	writeErr := shut.writeJSONWithTimeout(bg, msg(), time.Second)
	require.ErrorIs(t, writeErr, errOpenAIWSConnClosed)
	_, ioErr := shut.readMessageWithContextTimeout(bg, time.Second)
	require.ErrorIs(t, ioErr, errOpenAIWSConnClosed)
	require.ErrorIs(t, shut.pingWithTimeout(time.Second), errOpenAIWSConnClosed)

	// Connection constructed without an underlying websocket.
	bare := newOpenAIWSConn("no_ws", 1, nil, nil)
	require.ErrorIs(t, bare.writeJSON(msg(), bg), errOpenAIWSConnClosed)
	_, ioErr = bare.readMessage(bg)
	require.ErrorIs(t, ioErr, errOpenAIWSConnClosed)
	require.ErrorIs(t, bare.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
	require.Equal(t, "", bare.handshakeHeader("x-test"))

	// Healthy connection: a nil context and zero timeouts are accepted.
	healthy := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil)
	require.NoError(t, healthy.writeJSON(msg(), nil))
	_, ioErr = healthy.readMessageWithContextTimeout(bg, 0)
	require.NoError(t, ioErr)
	require.NoError(t, healthy.pingWithTimeout(0))

	// Zeroed timestamps yield zero times and zero durations.
	epochless := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil)
	epochless.createdAtNano.Store(0)
	epochless.lastUsedNano.Store(0)
	require.True(t, epochless.createdAt().IsZero())
	require.True(t, epochless.lastUsedAt().IsZero())
	require.Equal(t, time.Duration(0), epochless.idleDuration(time.Now()))
	require.Equal(t, time.Duration(0), epochless.age(time.Now()))

	// Clone helpers: nil in -> nil out; empty header values become nil.
	require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil))
	original := http.Header{
		"X-Empty": []string{},
		"X-Test":  []string{"v1"},
	}
	dup := cloneHeader(original)
	require.Contains(t, dup, "X-Empty")
	require.Nil(t, dup["X-Empty"])
	require.Equal(t, "v1", dup.Get("X-Test"))

	// closeOpenAIWSConns must tolerate nil entries in its input.
	closeOpenAIWSConns([]*openAIWSConn{nil, healthy})
}
// TestOpenAIWSConnLease_MarkBrokenEvictsConn checks that MarkBroken on a live
// lease removes the connection from its account pool and leaves it
// un-acquirable.
func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) {
	const acctID = int64(5001)

	p := newOpenAIWSConnPool(&config.Config{})
	victim := newOpenAIWSConn("broken_me", acctID, &openAIWSFakeConn{}, nil)

	acctPool := p.getOrCreateAccountPool(acctID)
	acctPool.mu.Lock()
	acctPool.conns[victim.id] = victim
	acctPool.mu.Unlock()

	live := &openAIWSConnLease{pool: p, accountID: acctID, conn: victim}
	live.MarkBroken()

	acctPool.mu.Lock()
	_, stillPooled := acctPool.conns[victim.id]
	acctPool.mu.Unlock()
	require.False(t, stillPooled)
	require.False(t, victim.tryAcquire(), "被标记为 broken 的连接应被关闭")
}
// TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches covers the edge
// cases of targetConnCountLocked and two prewarmConns paths: a missing
// account pool, and a dial failure that bumps prewarmFails.
func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) {
	settings := &config.Config{}
	settings.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	p := newOpenAIWSConnPool(settings)

	// Nil account pool and a zero max both yield a target of zero.
	require.Equal(t, 0, p.targetConnCountLocked(nil, 1))
	acctPool := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
	require.Equal(t, 0, p.targetConnCountLocked(acctPool, 0))

	settings.Gateway.OpenAIWS.MinIdlePerAccount = 3
	require.Equal(t, 1, p.targetConnCountLocked(acctPool, 1), "minIdle 应被 maxConns 截断")

	// Branch where waiters > 0 and the target must be at least len(conns)+1.
	settings.Gateway.OpenAIWS.MinIdlePerAccount = 0
	settings.Gateway.OpenAIWS.PoolTargetUtilization = 0.9
	loaded := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil)
	require.True(t, loaded.tryAcquire())
	loaded.waiters.Store(1)
	acctPool.conns[loaded.id] = loaded
	want := p.targetConnCountLocked(acctPool, 4)
	require.GreaterOrEqual(t, want, len(acctPool.conns)+1)

	// Prewarm: when the account pool is missing, the dialed connection should
	// be closed and the call should return early.
	warmReq := openAIWSAcquireRequest{
		Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
		WSURL:   "wss://example.com/v1/responses",
	}
	p.prewarmConns(999, warmReq, 1)

	// Prewarm: dial-failure branch (prewarmFails is incremented).
	const failingAcct = int64(1000)
	brokenPool := newOpenAIWSConnPool(settings)
	brokenPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{})
	failingAP := brokenPool.getOrCreateAccountPool(failingAcct)
	failingAP.mu.Lock()
	failingAP.creating = 1
	failingAP.mu.Unlock()
	warmReq.Account.ID = failingAcct
	brokenPool.prewarmConns(failingAcct, warmReq, 1)
	failingAP.mu.Lock()
	require.GreaterOrEqual(t, failingAP.prewarmFails, 1)
	failingAP.mu.Unlock()
}
// TestOpenAIWSConnPool_Acquire_ErrorBranches walks the failure paths of
// Acquire: nil pool, blank websocket URL, a full pool that holds only a nil
// connection, and a connection whose waiter queue is already full.
func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) {
	const endpoint = "wss://example.com/v1/responses"
	bg := context.Background()

	// Nil pool.
	var unset *openAIWSConnPool
	_, acquireErr := unset.Acquire(bg, openAIWSAcquireRequest{})
	require.Error(t, acquireErr)

	// Whitespace-only URL.
	plain := newOpenAIWSConnPool(&config.Config{})
	_, acquireErr = plain.Acquire(bg, openAIWSAcquireRequest{
		Account: &Account{ID: 1},
		WSURL:   " ",
	})
	require.Error(t, acquireErr)
	require.Contains(t, acquireErr.Error(), "ws url is empty")

	// target == nil branch: the pool is full and holds only a nil connection.
	tight := &config.Config{}
	tight.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	tight.Gateway.OpenAIWS.QueueLimitPerConn = 1
	cappedPool := newOpenAIWSConnPool(tight)

	nilConnAcct := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	nilConnAP := cappedPool.getOrCreateAccountPool(nilConnAcct.ID)
	nilConnAP.mu.Lock()
	nilConnAP.conns["nil"] = nil
	nilConnAP.lastCleanupAt = time.Now()
	nilConnAP.mu.Unlock()
	_, acquireErr = cappedPool.Acquire(bg, openAIWSAcquireRequest{
		Account: nilConnAcct,
		WSURL:   endpoint,
	})
	require.ErrorIs(t, acquireErr, errOpenAIWSConnClosed)

	// Queue-full branch: waiters have reached the per-connection limit.
	queuedAcct := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	queuedAP := cappedPool.getOrCreateAccountPool(queuedAcct.ID)
	saturated := newOpenAIWSConn("queue_full", queuedAcct.ID, &openAIWSFakeConn{}, nil)
	require.True(t, saturated.tryAcquire())
	saturated.waiters.Store(1)
	queuedAP.mu.Lock()
	queuedAP.conns[saturated.id] = saturated
	queuedAP.lastCleanupAt = time.Now()
	queuedAP.mu.Unlock()
	_, acquireErr = cappedPool.Acquire(bg, openAIWSAcquireRequest{
		Account: queuedAcct,
		WSURL:   endpoint,
	})
	require.ErrorIs(t, acquireErr, errOpenAIWSConnQueueFull)
}
// openAIWSFakeDialer is a test dialer whose Dial always succeeds and hands
// back a new in-memory openAIWSFakeConn.
type openAIWSFakeDialer struct{}

// Dial ignores all of its arguments and returns a new fake connection with a
// zero status code, nil response headers and no error.
func (d *openAIWSFakeDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	var (
		conn       openAIWSClientConn = &openAIWSFakeConn{}
		status     int
		respHeader http.Header
	)
	return conn, status, respHeader, nil
}
// openAIWSCountingDialer is a test dialer that counts how many times its Dial
// method has been called.
type openAIWSCountingDialer struct {
// mu guards dialCount.
mu sync.Mutex
// dialCount is the number of Dial calls made so far.
dialCount int
}
// openAIWSAlwaysFailDialer is a test dialer whose Dial method always returns
// an error, while counting how many times it has been called.
type openAIWSAlwaysFailDialer struct {
// mu guards dialCount.
mu sync.Mutex
// dialCount is the number of Dial calls made so far.
dialCount int
}
// openAIWSPingBlockingConn is a test connection whose Ping blocks until the
// shared release channel is readable (or the context ends), while tracking
// how many pings are in flight and the highest such count observed.
type openAIWSPingBlockingConn struct {
	current       *atomic.Int32   // pings currently blocked inside Ping
	maxConcurrent *atomic.Int32   // high-water mark of current
	release       <-chan struct{} // unblocks pending pings once readable
}

// WriteJSON accepts and discards any value.
func (c *openAIWSPingBlockingConn) WriteJSON(_ context.Context, _ any) error {
	return nil
}

// ReadMessage always returns the same canned "response.completed" event.
func (c *openAIWSPingBlockingConn) ReadMessage(_ context.Context) ([]byte, error) {
	const canned = `{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`
	return []byte(canned), nil
}

// Ping records itself as in flight, raises the high-water mark if needed, and
// then blocks until either ctx is done (returning ctx.Err()) or the release
// channel is readable (returning nil). With either counter unset it returns
// nil immediately.
func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error {
	if c.current == nil || c.maxConcurrent == nil {
		return nil
	}

	inFlight := c.current.Add(1)
	defer c.current.Add(-1)

	// Raise the high-water mark to inFlight unless another ping has already
	// recorded an equal or higher value; retry when the CAS loses a race.
	for peak := c.maxConcurrent.Load(); inFlight > peak; peak = c.maxConcurrent.Load() {
		if c.maxConcurrent.CompareAndSwap(peak, inFlight) {
			break
		}
	}

	select {
	case <-c.release:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// Close is a no-op.
func (c *openAIWSPingBlockingConn) Close() error {
	return nil
}
// Dial ignores its arguments, bumps the call counter under the mutex, and
// returns a new fake connection with a zero status code and no error.
func (d *openAIWSCountingDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.dialCount++
	return &openAIWSFakeConn{}, 0, nil, nil
}

// DialCount returns the number of Dial calls made so far.
func (d *openAIWSCountingDialer) DialCount() int {
	d.mu.Lock()
	n := d.dialCount
	d.mu.Unlock()
	return n
}
// Dial ignores its arguments, bumps the call counter under the mutex, and
// always fails with status 503 and a "dial failed" error.
func (d *openAIWSAlwaysFailDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.dialCount++
	return nil, http.StatusServiceUnavailable, nil, errors.New("dial failed")
}

// DialCount returns the number of Dial calls made so far.
func (d *openAIWSAlwaysFailDialer) DialCount() int {
	d.mu.Lock()
	n := d.dialCount
	d.mu.Unlock()
	return n
}
// openAIWSFakeConn is an in-memory test connection. Writes are recorded as
// placeholder payloads, reads return a canned completion event, and both fail
// once the connection has been closed.
type openAIWSFakeConn struct {
	mu      sync.Mutex // guards closed and payload
	closed  bool       // set by Close; makes WriteJSON and ReadMessage fail
	payload [][]byte   // one "ok" entry per successful WriteJSON
}

// WriteJSON ignores ctx and value. On an open connection it appends an "ok"
// marker to payload; on a closed one it returns a "closed" error.
func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		c.payload = append(c.payload, []byte("ok"))
		return nil
	}
	return errors.New("closed")
}

// ReadMessage ignores ctx and returns a canned "response.completed" event, or
// a "closed" error once Close has been called.
func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil
	}
	return nil, errors.New("closed")
}

// Ping always succeeds, whether or not the connection is closed.
func (c *openAIWSFakeConn) Ping(ctx context.Context) error {
	return nil
}

// Close marks the connection closed; it never fails.
func (c *openAIWSFakeConn) Close() error {
	c.mu.Lock()
	c.closed = true
	c.mu.Unlock()
	return nil
}
// openAIWSBlockingConn is a test connection whose ReadMessage waits for a
// configurable delay before answering.
type openAIWSBlockingConn struct {
	readDelay time.Duration // how long ReadMessage waits; <= 0 means 10ms
}

// WriteJSON accepts and discards any value.
func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error {
	return nil
}

// ReadMessage waits for readDelay (10ms when readDelay is not positive) and
// then returns a canned "response.completed" event. If ctx finishes first it
// returns ctx.Err() instead.
func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) {
	const fallbackDelay = 10 * time.Millisecond

	wait := fallbackDelay
	if c.readDelay > 0 {
		wait = c.readDelay
	}

	timer := time.NewTimer(wait)
	defer timer.Stop()

	select {
	case <-timer.C:
		return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// Ping always succeeds.
func (c *openAIWSBlockingConn) Ping(ctx context.Context) error {
	return nil
}

// Close is a no-op.
func (c *openAIWSBlockingConn) Close() error {
	return nil
}
// openAIWSWriteBlockingConn is a test connection whose WriteJSON never
// completes on its own: it blocks until the supplied context is done.
type openAIWSWriteBlockingConn struct{}

// WriteJSON blocks until ctx is done and then returns ctx.Err().
func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error {
	done := ctx.Done()
	<-done
	return ctx.Err()
}

// ReadMessage always returns the same canned "response.completed" event.
func (c *openAIWSWriteBlockingConn) ReadMessage(_ context.Context) ([]byte, error) {
	const canned = `{"type":"response.completed","response":{"id":"resp_write_block"}}`
	return []byte(canned), nil
}

// Ping always succeeds.
func (c *openAIWSWriteBlockingConn) Ping(_ context.Context) error {
	return nil
}

// Close is a no-op.
func (c *openAIWSWriteBlockingConn) Close() error {
	return nil
}
// openAIWSPingFailConn is a test connection that reads and writes normally
// but whose Ping always fails.
type openAIWSPingFailConn struct{}

// WriteJSON accepts and discards any value.
func (c *openAIWSPingFailConn) WriteJSON(_ context.Context, _ any) error {
	return nil
}

// ReadMessage always returns the same canned "response.completed" event.
func (c *openAIWSPingFailConn) ReadMessage(_ context.Context) ([]byte, error) {
	const canned = `{"type":"response.completed","response":{"id":"resp_ping_fail"}}`
	return []byte(canned), nil
}

// Ping always returns a "ping failed" error.
func (c *openAIWSPingFailConn) Ping(_ context.Context) error {
	pingErr := errors.New("ping failed")
	return pingErr
}

// Close is a no-op.
func (c *openAIWSPingFailConn) Close() error {
	return nil
}
// openAIWSContextProbeConn is a test connection that remembers the context
// handed to its most recent WriteJSON call so a test can inspect it.
type openAIWSContextProbeConn struct {
	lastWriteCtx context.Context // context from the latest WriteJSON call
}

// WriteJSON records ctx in lastWriteCtx, discards the value, and succeeds.
func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error {
	c.lastWriteCtx = ctx
	return nil
}

// ReadMessage always returns the same canned "response.completed" event.
func (c *openAIWSContextProbeConn) ReadMessage(_ context.Context) ([]byte, error) {
	const canned = `{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`
	return []byte(canned), nil
}

// Ping always succeeds.
func (c *openAIWSContextProbeConn) Ping(_ context.Context) error {
	return nil
}

// Close is a no-op.
func (c *openAIWSContextProbeConn) Close() error {
	return nil
}
// openAIWSNilConnDialer is a test dialer that reports success (status 200,
// nil error) but returns a nil connection.
type openAIWSNilConnDialer struct{}

// Dial ignores its arguments and returns a nil connection with status 200,
// nil response headers and no error.
func (d *openAIWSNilConnDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	return nil, http.StatusOK, nil, nil
}
// TestOpenAIWSConnPool_DialConnNilConnection checks that Acquire returns an
// error mentioning "nil connection" when the dialer reports success but
// returns no connection.
func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) {
	settings := &config.Config{}
	settings.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	settings.Gateway.OpenAIWS.DialTimeoutSeconds = 1

	p := newOpenAIWSConnPool(settings)
	p.setClientDialerForTest(&openAIWSNilConnDialer{})

	acquireReq := openAIWSAcquireRequest{
		Account: &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
		WSURL:   "wss://example.com/v1/responses",
	}
	_, acquireErr := p.Acquire(context.Background(), acquireReq)
	require.Error(t, acquireErr)
	require.Contains(t, acquireErr.Error(), "nil connection")
}
func TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) {
cfg := &config.Config{}
pool := newOpenAIWSConnPool(cfg)
dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer)
require.True(t, ok)
_, err := dialer.proxyHTTPClient("http://127.0.0.1:28080")
require.NoError(t, err)
_, err = dialer.proxyHTTPClient("http://127.0.0.1:28080")
require.NoError(t, err)
_, err = dialer.proxyHTTPClient("http://127.0.0.1:28081")
require.NoError(t, err)
snapshot := pool.SnapshotTransportMetrics()
require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
}