1710 lines
54 KiB
Go
1710 lines
54 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"net/http"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
accountID := int64(10)
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
|
||
stale := newOpenAIWSConn("stale", accountID, nil, nil)
|
||
stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
|
||
stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
|
||
|
||
idleOld := newOpenAIWSConn("idle_old", accountID, nil, nil)
|
||
idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
|
||
|
||
idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil)
|
||
idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
|
||
|
||
ap.conns[stale.id] = stale
|
||
ap.conns[idleOld.id] = idleOld
|
||
ap.conns[idleNew.id] = idleNew
|
||
|
||
evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
|
||
closeOpenAIWSConns(evicted)
|
||
|
||
require.Nil(t, ap.conns["stale"], "stale connection should be rotated")
|
||
require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle")
|
||
require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) {
|
||
pool := newOpenAIWSConnPool(&config.Config{})
|
||
id1 := pool.nextConnID(42)
|
||
id2 := pool.nextConnID(42)
|
||
|
||
require.True(t, strings.HasPrefix(id1, "oa_ws_42_"))
|
||
require.True(t, strings.HasPrefix(id2, "oa_ws_42_"))
|
||
require.NotEqual(t, id1, id2)
|
||
require.Equal(t, "oa_ws_42_1", id1)
|
||
require.Equal(t, "oa_ws_42_2", id2)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) {
|
||
require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval)
|
||
require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker)
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) {
|
||
conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil)
|
||
lease := &openAIWSConnLease{conn: conn}
|
||
require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0))
|
||
|
||
var nilLease *openAIWSConnLease
|
||
err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
|
||
err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
}
|
||
|
||
func TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) {
|
||
probe := &openAIWSContextProbeConn{}
|
||
conn := newOpenAIWSConn("ctx_probe", 1, probe, nil)
|
||
require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0))
|
||
require.NotNil(t, probe.lastWriteCtx)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
|
||
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
ap := pool.getOrCreateAccountPool(88)
|
||
|
||
conn1 := newOpenAIWSConn("c1", 88, nil, nil)
|
||
conn2 := newOpenAIWSConn("c2", 88, nil, nil)
|
||
require.True(t, conn1.tryAcquire())
|
||
require.True(t, conn2.tryAcquire())
|
||
conn1.waiters.Store(1)
|
||
conn2.waiters.Store(1)
|
||
|
||
ap.conns[conn1.id] = conn1
|
||
ap.conns[conn2.id] = conn2
|
||
|
||
target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
|
||
require.Equal(t, 6, target, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限")
|
||
|
||
conn1.release()
|
||
conn2.release()
|
||
conn1.waiters.Store(0)
|
||
conn2.waiters.Store(0)
|
||
target = pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
|
||
require.Equal(t, 1, target, "低负载时应缩回到最小空闲连接")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
ap := pool.getOrCreateAccountPool(66)
|
||
|
||
target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
|
||
require.Equal(t, 0, target, "min_idle=0 且无负载时应允许缩容到 0")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
|
||
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
|
||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
pool.setClientDialerForTest(&openAIWSFakeDialer{})
|
||
|
||
accountID := int64(77)
|
||
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
ap.mu.Lock()
|
||
ap.lastAcquire = &openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
}
|
||
ap.mu.Unlock()
|
||
|
||
pool.ensureTargetIdleAsync(accountID)
|
||
|
||
require.Eventually(t, func() bool {
|
||
ap, ok := pool.getAccountPool(accountID)
|
||
if !ok || ap == nil {
|
||
return false
|
||
}
|
||
ap.mu.Lock()
|
||
defer ap.mu.Unlock()
|
||
return len(ap.conns) >= 2
|
||
}, 2*time.Second, 20*time.Millisecond)
|
||
|
||
metrics := pool.SnapshotMetrics()
|
||
require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2))
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
|
||
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
|
||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
|
||
cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
dialer := &openAIWSCountingDialer{}
|
||
pool.setClientDialerForTest(dialer)
|
||
|
||
accountID := int64(178)
|
||
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
ap.mu.Lock()
|
||
ap.lastAcquire = &openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
}
|
||
ap.mu.Unlock()
|
||
|
||
pool.ensureTargetIdleAsync(accountID)
|
||
require.Eventually(t, func() bool {
|
||
ap, ok := pool.getAccountPool(accountID)
|
||
if !ok || ap == nil {
|
||
return false
|
||
}
|
||
ap.mu.Lock()
|
||
defer ap.mu.Unlock()
|
||
return len(ap.conns) >= 2 && !ap.prewarmActive
|
||
}, 2*time.Second, 20*time.Millisecond)
|
||
firstDialCount := dialer.DialCount()
|
||
require.GreaterOrEqual(t, firstDialCount, 2)
|
||
|
||
// 人工制造缺口触发新一轮预热需求。
|
||
ap, ok := pool.getAccountPool(accountID)
|
||
require.True(t, ok)
|
||
require.NotNil(t, ap)
|
||
ap.mu.Lock()
|
||
for id := range ap.conns {
|
||
delete(ap.conns, id)
|
||
break
|
||
}
|
||
ap.mu.Unlock()
|
||
|
||
pool.ensureTargetIdleAsync(accountID)
|
||
time.Sleep(120 * time.Millisecond)
|
||
require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热")
|
||
|
||
time.Sleep(450 * time.Millisecond)
|
||
pool.ensureTargetIdleAsync(accountID)
|
||
require.Eventually(t, func() bool {
|
||
return dialer.DialCount() > firstDialCount
|
||
}, 2*time.Second, 20*time.Millisecond)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
|
||
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
|
||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
|
||
cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
dialer := &openAIWSAlwaysFailDialer{}
|
||
pool.setClientDialerForTest(dialer)
|
||
|
||
accountID := int64(279)
|
||
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
ap.mu.Lock()
|
||
ap.lastAcquire = &openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
}
|
||
ap.mu.Unlock()
|
||
|
||
pool.ensureTargetIdleAsync(accountID)
|
||
require.Eventually(t, func() bool {
|
||
ap, ok := pool.getAccountPool(accountID)
|
||
if !ok || ap == nil {
|
||
return false
|
||
}
|
||
ap.mu.Lock()
|
||
defer ap.mu.Unlock()
|
||
return !ap.prewarmActive
|
||
}, 2*time.Second, 20*time.Millisecond)
|
||
|
||
pool.ensureTargetIdleAsync(accountID)
|
||
require.Eventually(t, func() bool {
|
||
ap, ok := pool.getAccountPool(accountID)
|
||
if !ok || ap == nil {
|
||
return false
|
||
}
|
||
ap.mu.Lock()
|
||
defer ap.mu.Unlock()
|
||
return !ap.prewarmActive
|
||
}, 2*time.Second, 20*time.Millisecond)
|
||
require.Equal(t, 2, dialer.DialCount())
|
||
|
||
// 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。
|
||
pool.ensureTargetIdleAsync(accountID)
|
||
time.Sleep(120 * time.Millisecond)
|
||
require.Equal(t, 2, dialer.DialCount())
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
accountID := int64(99)
|
||
account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
conn := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil)
|
||
require.True(t, conn.tryAcquire()) // 占用连接,触发后续排队
|
||
|
||
ap := pool.ensureAccountPoolLocked(accountID)
|
||
ap.mu.Lock()
|
||
ap.conns[conn.id] = conn
|
||
ap.lastAcquire = &openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
}
|
||
ap.mu.Unlock()
|
||
|
||
go func() {
|
||
time.Sleep(60 * time.Millisecond)
|
||
conn.release()
|
||
}()
|
||
|
||
lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, lease)
|
||
require.True(t, lease.Reused())
|
||
require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond)
|
||
lease.Release()
|
||
|
||
metrics := pool.SnapshotMetrics()
|
||
require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1))
|
||
require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0))
|
||
require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
dialer := &openAIWSCountingDialer{}
|
||
pool.setClientDialerForTest(dialer)
|
||
|
||
account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
|
||
lease1, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, lease1)
|
||
lease1.Release()
|
||
|
||
lease2, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
ForceNewConn: true,
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, lease2)
|
||
lease2.Release()
|
||
|
||
require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
ap := pool.getOrCreateAccountPool(account.ID)
|
||
otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil)
|
||
ap.mu.Lock()
|
||
ap.conns[otherConn.id] = otherConn
|
||
ap.mu.Unlock()
|
||
|
||
_, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
ForcePreferredConn: true,
|
||
})
|
||
require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
|
||
|
||
_, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
PreferredConnID: "missing_conn",
|
||
ForcePreferredConn: true,
|
||
})
|
||
require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
ap := pool.getOrCreateAccountPool(account.ID)
|
||
preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil)
|
||
otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil)
|
||
require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取")
|
||
ap.mu.Lock()
|
||
ap.conns[preferredConn.id] = preferredConn
|
||
ap.conns[otherConn.id] = otherConn
|
||
ap.lastCleanupAt = time.Now()
|
||
ap.mu.Unlock()
|
||
|
||
go func() {
|
||
time.Sleep(60 * time.Millisecond)
|
||
preferredConn.release()
|
||
}()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||
defer cancel()
|
||
lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
PreferredConnID: preferredConn.id,
|
||
ForcePreferredConn: true,
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, lease)
|
||
require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移")
|
||
require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond)
|
||
lease.Release()
|
||
require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占")
|
||
otherConn.release()
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
ap := pool.getOrCreateAccountPool(account.ID)
|
||
preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
|
||
otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
|
||
ap.mu.Lock()
|
||
ap.conns[preferredConn.id] = preferredConn
|
||
ap.conns[otherConn.id] = otherConn
|
||
ap.lastCleanupAt = time.Now()
|
||
ap.mu.Unlock()
|
||
|
||
lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
PreferredConnID: preferredConn.id,
|
||
ForcePreferredConn: true,
|
||
})
|
||
require.NoError(t, err)
|
||
require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中")
|
||
lease.Release()
|
||
|
||
require.True(t, preferredConn.tryAcquire())
|
||
preferredConn.waiters.Store(1)
|
||
_, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
PreferredConnID: preferredConn.id,
|
||
ForcePreferredConn: true,
|
||
})
|
||
require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移")
|
||
preferredConn.waiters.Store(0)
|
||
preferredConn.release()
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
accountID := int64(126)
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil)
|
||
idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
|
||
ap.mu.Lock()
|
||
ap.conns[pinnedConn.id] = pinnedConn
|
||
ap.conns[idleConn.id] = idleConn
|
||
ap.mu.Unlock()
|
||
|
||
require.True(t, pool.PinConn(accountID, pinnedConn.id))
|
||
evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
|
||
closeOpenAIWSConns(evicted)
|
||
|
||
ap.mu.Lock()
|
||
_, pinnedExists := ap.conns[pinnedConn.id]
|
||
_, idleExists := ap.conns[idleConn.id]
|
||
ap.mu.Unlock()
|
||
require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收")
|
||
require.False(t, idleExists, "非绑定的空闲连接应被回收")
|
||
|
||
pool.UnpinConn(accountID, pinnedConn.id)
|
||
evicted = pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
|
||
closeOpenAIWSConns(evicted)
|
||
ap.mu.Lock()
|
||
_, pinnedExists = ap.conns[pinnedConn.id]
|
||
ap.mu.Unlock()
|
||
require.False(t, pinnedExists, "解绑后连接应可被正常回收")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) {
|
||
var nilPool *openAIWSConnPool
|
||
require.False(t, nilPool.PinConn(1, "x"))
|
||
nilPool.UnpinConn(1, "x")
|
||
|
||
cfg := &config.Config{}
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
accountID := int64(128)
|
||
ap := &openAIWSAccountPool{
|
||
conns: map[string]*openAIWSConn{},
|
||
}
|
||
pool.accounts.Store(accountID, ap)
|
||
|
||
require.False(t, pool.PinConn(0, "x"))
|
||
require.False(t, pool.PinConn(999, "x"))
|
||
require.False(t, pool.PinConn(accountID, ""))
|
||
require.False(t, pool.PinConn(accountID, "missing"))
|
||
|
||
conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil)
|
||
ap.mu.Lock()
|
||
ap.conns[conn.id] = conn
|
||
ap.mu.Unlock()
|
||
require.True(t, pool.PinConn(accountID, conn.id))
|
||
require.True(t, pool.PinConn(accountID, conn.id))
|
||
|
||
ap.mu.Lock()
|
||
require.Equal(t, 2, ap.pinnedConns[conn.id])
|
||
ap.mu.Unlock()
|
||
|
||
pool.UnpinConn(accountID, conn.id)
|
||
ap.mu.Lock()
|
||
require.Equal(t, 1, ap.pinnedConns[conn.id])
|
||
ap.mu.Unlock()
|
||
|
||
pool.UnpinConn(accountID, conn.id)
|
||
ap.mu.Lock()
|
||
_, exists := ap.pinnedConns[conn.id]
|
||
ap.mu.Unlock()
|
||
require.False(t, exists)
|
||
|
||
pool.UnpinConn(accountID, conn.id)
|
||
pool.UnpinConn(accountID, "")
|
||
pool.UnpinConn(0, conn.id)
|
||
pool.UnpinConn(999, conn.id)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
|
||
cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
|
||
cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0
|
||
cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
oauthHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10}
|
||
require.Equal(t, 8, pool.effectiveMaxConnsByAccount(oauthHigh), "应受全局硬上限约束")
|
||
|
||
oauthLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3}
|
||
require.Equal(t, 3, pool.effectiveMaxConnsByAccount(oauthLow))
|
||
|
||
apiKeyHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10}
|
||
require.Equal(t, 6, pool.effectiveMaxConnsByAccount(apiKeyHigh), "API Key 应按系数缩放")
|
||
|
||
apiKeyLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1}
|
||
require.Equal(t, 1, pool.effectiveMaxConnsByAccount(apiKeyLow), "最小值应保持为 1")
|
||
|
||
unlimited := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}
|
||
require.Equal(t, 8, pool.effectiveMaxConnsByAccount(unlimited), "无限并发应回退到全局硬上限")
|
||
|
||
require.Equal(t, 8, pool.effectiveMaxConnsByAccount(nil), "缺少账号上下文应回退到全局硬上限")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
|
||
cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false
|
||
cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0
|
||
cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 1.0
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2}
|
||
require.Equal(t, 8, pool.effectiveMaxConnsByAccount(account), "关闭动态模式后应保持旧行为")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
|
||
cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
|
||
cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0.3
|
||
cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
high := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20}
|
||
require.Equal(t, 20, pool.effectiveMaxConnsByAccount(high), "v2 路径应直接使用账号并发数作为池上限")
|
||
|
||
nonPositive := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0}
|
||
require.Equal(t, 0, pool.effectiveMaxConnsByAccount(nonPositive), "并发数<=0 时应不可调度")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
account := &Account{ID: 901, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}
|
||
_, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
})
|
||
require.ErrorIs(t, err, errOpenAIWSConnQueueFull)
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) {
|
||
conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil)
|
||
lease := &openAIWSConnLease{conn: conn}
|
||
|
||
_, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond)
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||
|
||
payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond)
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(payload), "response.completed")
|
||
|
||
parentCtx, cancel := context.WithCancel(context.Background())
|
||
cancel()
|
||
_, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond)
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, context.Canceled)
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) {
|
||
conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil)
|
||
lease := &openAIWSConnLease{conn: conn}
|
||
|
||
parentCtx, cancel := context.WithCancel(context.Background())
|
||
go func() {
|
||
time.Sleep(20 * time.Millisecond)
|
||
cancel()
|
||
}()
|
||
|
||
start := time.Now()
|
||
err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute)
|
||
elapsed := time.Since(start)
|
||
|
||
require.Error(t, err)
|
||
require.ErrorIs(t, err, context.Canceled)
|
||
require.Less(t, elapsed, 200*time.Millisecond)
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) {
|
||
conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil)
|
||
lease := &openAIWSConnLease{conn: conn}
|
||
require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
|
||
|
||
var nilLease *openAIWSConnLease
|
||
err := nilLease.PingWithTimeout(50 * time.Millisecond)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
}
|
||
|
||
func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) {
|
||
conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil)
|
||
|
||
readDone := make(chan error, 1)
|
||
go func() {
|
||
_, err := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond)
|
||
readDone <- err
|
||
}()
|
||
|
||
// 让读取先占用 readMu。
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
start := time.Now()
|
||
err := conn.pingWithTimeout(50 * time.Millisecond)
|
||
elapsed := time.Since(start)
|
||
|
||
require.NoError(t, err)
|
||
require.Less(t, elapsed, 80*time.Millisecond, "写路径不应被读锁长期阻塞")
|
||
require.NoError(t, <-readDone)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
accountID := int64(301)
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
conn := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil)
|
||
ap.mu.Lock()
|
||
ap.conns[conn.id] = conn
|
||
ap.mu.Unlock()
|
||
|
||
pool.runBackgroundPingSweep()
|
||
|
||
ap.mu.Lock()
|
||
_, exists := ap.conns[conn.id]
|
||
ap.mu.Unlock()
|
||
require.False(t, exists, "后台 ping 失败的空闲连接应被回收")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
accountID := int64(302)
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
stale := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil)
|
||
stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
|
||
stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
|
||
ap.mu.Lock()
|
||
ap.conns[stale.id] = stale
|
||
ap.mu.Unlock()
|
||
|
||
pool.runBackgroundCleanupSweep(time.Now())
|
||
|
||
ap.mu.Lock()
|
||
_, exists := ap.conns[stale.id]
|
||
ap.mu.Unlock()
|
||
require.False(t, exists, "后台清理应在无新 acquire 时也回收过期连接")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) {
|
||
var nilPool *openAIWSConnPool
|
||
require.NotPanics(t, func() {
|
||
nilPool.startBackgroundWorkers()
|
||
nilPool.runBackgroundPingWorker()
|
||
nilPool.runBackgroundPingSweep()
|
||
_ = nilPool.snapshotIdleConnsForPing()
|
||
nilPool.runBackgroundCleanupWorker()
|
||
nilPool.runBackgroundCleanupSweep(time.Now())
|
||
})
|
||
|
||
poolNoStop := &openAIWSConnPool{}
|
||
require.NotPanics(t, func() {
|
||
poolNoStop.startBackgroundWorkers()
|
||
})
|
||
|
||
poolStopPing := &openAIWSConnPool{workerStopCh: make(chan struct{})}
|
||
pingDone := make(chan struct{})
|
||
go func() {
|
||
poolStopPing.runBackgroundPingWorker()
|
||
close(pingDone)
|
||
}()
|
||
close(poolStopPing.workerStopCh)
|
||
select {
|
||
case <-pingDone:
|
||
case <-time.After(500 * time.Millisecond):
|
||
t.Fatal("runBackgroundPingWorker 未在 stop 信号后退出")
|
||
}
|
||
|
||
poolStopCleanup := &openAIWSConnPool{workerStopCh: make(chan struct{})}
|
||
cleanupDone := make(chan struct{})
|
||
go func() {
|
||
poolStopCleanup.runBackgroundCleanupWorker()
|
||
close(cleanupDone)
|
||
}()
|
||
close(poolStopCleanup.workerStopCh)
|
||
select {
|
||
case <-cleanupDone:
|
||
case <-time.After(500 * time.Millisecond):
|
||
t.Fatal("runBackgroundCleanupWorker 未在 stop 信号后退出")
|
||
}
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) {
|
||
pool := &openAIWSConnPool{}
|
||
pool.accounts.Store("invalid-key", &openAIWSAccountPool{})
|
||
pool.accounts.Store(int64(123), "invalid-value")
|
||
|
||
accountID := int64(123)
|
||
ap := &openAIWSAccountPool{
|
||
conns: make(map[string]*openAIWSConn),
|
||
}
|
||
ap.conns["nil_conn"] = nil
|
||
|
||
leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
|
||
require.True(t, leased.tryAcquire())
|
||
ap.conns[leased.id] = leased
|
||
|
||
waiting := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil)
|
||
waiting.waiters.Store(1)
|
||
ap.conns[waiting.id] = waiting
|
||
|
||
idle := newOpenAIWSConn("idle", accountID, &openAIWSFakeConn{}, nil)
|
||
ap.conns[idle.id] = idle
|
||
|
||
pool.accounts.Store(accountID, ap)
|
||
candidates := pool.snapshotIdleConnsForPing()
|
||
require.Len(t, candidates, 1)
|
||
require.Equal(t, idle.id, candidates[0].conn.id)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
|
||
cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
|
||
|
||
pool := &openAIWSConnPool{cfg: cfg}
|
||
pool.accounts.Store("bad-key", "bad-value")
|
||
|
||
accountID := int64(2026)
|
||
ap := &openAIWSAccountPool{
|
||
conns: make(map[string]*openAIWSConn),
|
||
}
|
||
ap.conns["nil_conn"] = nil
|
||
stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil)
|
||
stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
|
||
stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
|
||
ap.conns[stale.id] = stale
|
||
ap.lastAcquire = &openAIWSAcquireRequest{
|
||
Account: &Account{
|
||
ID: accountID,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Concurrency: 1,
|
||
},
|
||
}
|
||
pool.accounts.Store(accountID, ap)
|
||
|
||
now := time.Now()
|
||
require.NotPanics(t, func() {
|
||
pool.runBackgroundCleanupSweep(now)
|
||
})
|
||
|
||
ap.mu.Lock()
|
||
_, nilConnExists := ap.conns["nil_conn"]
|
||
_, exists := ap.conns[stale.id]
|
||
lastCleanupAt := ap.lastCleanupAt
|
||
ap.mu.Unlock()
|
||
|
||
require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目")
|
||
require.False(t, exists, "后台清理应清理过期连接")
|
||
require.Equal(t, now, lastCleanupAt)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) {
|
||
var nilPool *openAIWSConnPool
|
||
require.Equal(t, 256, nilPool.queueLimitPerConn())
|
||
|
||
pool := &openAIWSConnPool{cfg: &config.Config{}}
|
||
require.Equal(t, 256, pool.queueLimitPerConn())
|
||
|
||
pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9
|
||
require.Equal(t, 9, pool.queueLimitPerConn())
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_Close(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
// Close 应该可以安全调用
|
||
pool.Close()
|
||
|
||
// workerStopCh 应已关闭
|
||
select {
|
||
case <-pool.workerStopCh:
|
||
// 预期:channel 已关闭
|
||
default:
|
||
t.Fatal("Close 后 workerStopCh 应已关闭")
|
||
}
|
||
|
||
// 多次调用 Close 不应 panic
|
||
pool.Close()
|
||
|
||
// nil pool 调用 Close 不应 panic
|
||
var nilPool *openAIWSConnPool
|
||
nilPool.Close()
|
||
}
|
||
|
||
func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) {
|
||
baseErr := errors.New("boom")
|
||
dialErr := &openAIWSDialError{StatusCode: 502, Err: baseErr}
|
||
require.Contains(t, dialErr.Error(), "status=502")
|
||
require.ErrorIs(t, dialErr.Unwrap(), baseErr)
|
||
|
||
noStatus := &openAIWSDialError{Err: baseErr}
|
||
require.Contains(t, noStatus.Error(), "boom")
|
||
|
||
var nilDialErr *openAIWSDialError
|
||
require.Equal(t, "", nilDialErr.Error())
|
||
require.NoError(t, nilDialErr.Unwrap())
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) {
|
||
conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, http.Header{
|
||
"X-Test": []string{" value "},
|
||
})
|
||
lease := &openAIWSConnLease{conn: conn}
|
||
|
||
require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"}))
|
||
payload, err := lease.ReadMessage(100 * time.Millisecond)
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(payload), "response.completed")
|
||
|
||
payload, err = lease.ReadMessageContext(context.Background())
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(payload), "response.completed")
|
||
|
||
payload, err = conn.readMessageWithTimeout(100 * time.Millisecond)
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(payload), "response.completed")
|
||
|
||
require.Equal(t, "value", conn.handshakeHeader(" X-Test "))
|
||
require.NotZero(t, conn.createdAt())
|
||
require.NotZero(t, conn.lastUsedAt())
|
||
require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0))
|
||
require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0))
|
||
require.False(t, conn.isLeased())
|
||
|
||
// 覆盖空上下文路径
|
||
_, err = conn.readMessage(context.Background())
|
||
require.NoError(t, err)
|
||
|
||
// 覆盖 nil 保护分支
|
||
var nilConn *openAIWSConn
|
||
require.ErrorIs(t, nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second), errOpenAIWSConnClosed)
|
||
_, err = nilConn.readMessageWithTimeout(10 * time.Millisecond)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
_, err = nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) {
|
||
pool := &openAIWSConnPool{}
|
||
accountID := int64(404)
|
||
ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
|
||
|
||
idleOld := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil)
|
||
idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
|
||
idleNew := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil)
|
||
idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
|
||
leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
|
||
require.True(t, leased.tryAcquire())
|
||
leased.waiters.Store(2)
|
||
|
||
ap.conns[idleOld.id] = idleOld
|
||
ap.conns[idleNew.id] = idleNew
|
||
ap.conns[leased.id] = leased
|
||
|
||
oldest := pool.pickOldestIdleConnLocked(ap)
|
||
require.NotNil(t, oldest)
|
||
require.Equal(t, idleOld.id, oldest.id)
|
||
|
||
inflight, waiters := accountPoolLoadLocked(ap)
|
||
require.Equal(t, 1, inflight)
|
||
require.Equal(t, 2, waiters)
|
||
|
||
pool.accounts.Store(accountID, ap)
|
||
loadInflight, loadWaiters, conns := pool.AccountPoolLoad(accountID)
|
||
require.Equal(t, 1, loadInflight)
|
||
require.Equal(t, 2, loadWaiters)
|
||
require.Equal(t, 3, conns)
|
||
|
||
zeroInflight, zeroWaiters, zeroConns := pool.AccountPoolLoad(0)
|
||
require.Equal(t, 0, zeroInflight)
|
||
require.Equal(t, 0, zeroWaiters)
|
||
require.Equal(t, 0, zeroConns)
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) {
|
||
pool := &openAIWSConnPool{}
|
||
release := make(chan struct{})
|
||
pool.workerWg.Add(1)
|
||
go func() {
|
||
defer pool.workerWg.Done()
|
||
<-release
|
||
}()
|
||
|
||
closed := make(chan struct{})
|
||
go func() {
|
||
pool.Close()
|
||
close(closed)
|
||
}()
|
||
|
||
select {
|
||
case <-closed:
|
||
t.Fatal("Close 不应在 WaitGroup 未完成时提前返回")
|
||
case <-time.After(30 * time.Millisecond):
|
||
}
|
||
|
||
close(release)
|
||
select {
|
||
case <-closed:
|
||
case <-time.After(time.Second):
|
||
t.Fatal("Close 未等待 workerWg 完成")
|
||
}
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) {
|
||
pool := &openAIWSConnPool{
|
||
workerStopCh: make(chan struct{}),
|
||
}
|
||
|
||
accountID := int64(606)
|
||
ap := &openAIWSAccountPool{
|
||
conns: map[string]*openAIWSConn{},
|
||
}
|
||
idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
|
||
leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil)
|
||
require.True(t, leased.tryAcquire())
|
||
|
||
ap.conns[idle.id] = idle
|
||
ap.conns[leased.id] = leased
|
||
pool.accounts.Store(accountID, ap)
|
||
pool.accounts.Store("invalid-key", "invalid-value")
|
||
|
||
pool.Close()
|
||
|
||
select {
|
||
case <-idle.closedCh:
|
||
// idle should be closed
|
||
default:
|
||
t.Fatal("空闲连接应在 Close 时被关闭")
|
||
}
|
||
|
||
select {
|
||
case <-leased.closedCh:
|
||
t.Fatal("已租赁连接不应在 Close 时被关闭")
|
||
default:
|
||
}
|
||
|
||
leased.release()
|
||
pool.Close()
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
accountID := int64(505)
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
|
||
var current atomic.Int32
|
||
var maxConcurrent atomic.Int32
|
||
release := make(chan struct{})
|
||
for i := 0; i < 25; i++ {
|
||
conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{
|
||
current: ¤t,
|
||
maxConcurrent: &maxConcurrent,
|
||
release: release,
|
||
}, nil)
|
||
ap.mu.Lock()
|
||
ap.conns[conn.id] = conn
|
||
ap.mu.Unlock()
|
||
}
|
||
|
||
done := make(chan struct{})
|
||
go func() {
|
||
pool.runBackgroundPingSweep()
|
||
close(done)
|
||
}()
|
||
|
||
require.Eventually(t, func() bool {
|
||
return maxConcurrent.Load() >= 10
|
||
}, time.Second, 10*time.Millisecond)
|
||
|
||
close(release)
|
||
select {
|
||
case <-done:
|
||
case <-time.After(2 * time.Second):
|
||
t.Fatal("runBackgroundPingSweep 未在释放后完成")
|
||
}
|
||
|
||
require.LessOrEqual(t, maxConcurrent.Load(), int32(10))
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) {
|
||
var nilLease *openAIWSConnLease
|
||
require.Equal(t, "", nilLease.ConnID())
|
||
require.Equal(t, time.Duration(0), nilLease.QueueWaitDuration())
|
||
require.Equal(t, time.Duration(0), nilLease.ConnPickDuration())
|
||
require.False(t, nilLease.Reused())
|
||
require.Equal(t, "", nilLease.HandshakeHeader("x-test"))
|
||
require.False(t, nilLease.IsPrewarmed())
|
||
nilLease.MarkPrewarmed()
|
||
nilLease.Release()
|
||
|
||
conn := newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, http.Header{"X-Test": []string{"ok"}})
|
||
lease := &openAIWSConnLease{
|
||
conn: conn,
|
||
queueWait: 3 * time.Millisecond,
|
||
connPick: 4 * time.Millisecond,
|
||
reused: true,
|
||
}
|
||
require.Equal(t, "getter_conn", lease.ConnID())
|
||
require.Equal(t, 3*time.Millisecond, lease.QueueWaitDuration())
|
||
require.Equal(t, 4*time.Millisecond, lease.ConnPickDuration())
|
||
require.True(t, lease.Reused())
|
||
require.Equal(t, "ok", lease.HandshakeHeader("x-test"))
|
||
require.False(t, lease.IsPrewarmed())
|
||
lease.MarkPrewarmed()
|
||
require.True(t, lease.IsPrewarmed())
|
||
lease.Release()
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) {
|
||
var nilPool *openAIWSConnPool
|
||
require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics())
|
||
require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics())
|
||
|
||
pool := &openAIWSConnPool{cfg: &config.Config{}}
|
||
pool.metrics.acquireTotal.Store(7)
|
||
pool.metrics.acquireReuseTotal.Store(3)
|
||
metrics := pool.SnapshotMetrics()
|
||
require.Equal(t, int64(7), metrics.AcquireTotal)
|
||
require.Equal(t, int64(3), metrics.AcquireReuseTotal)
|
||
|
||
// 非 transport metrics dialer 路径
|
||
pool.clientDialer = &openAIWSFakeDialer{}
|
||
require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics())
|
||
pool.setClientDialerForTest(nil)
|
||
require.NotNil(t, pool.clientDialer)
|
||
|
||
require.Equal(t, 8, nilPool.maxConnsHardCap())
|
||
require.False(t, nilPool.dynamicMaxConnsEnabled())
|
||
require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil))
|
||
require.Equal(t, 0, nilPool.minIdlePerAccount())
|
||
require.Equal(t, 4, nilPool.maxIdlePerAccount())
|
||
require.Equal(t, 256, nilPool.queueLimitPerConn())
|
||
require.Equal(t, 0.7, nilPool.targetUtilization())
|
||
require.Equal(t, time.Duration(0), nilPool.prewarmCooldown())
|
||
require.Equal(t, 10*time.Second, nilPool.dialTimeout())
|
||
|
||
// shouldSuppressPrewarmLocked 覆盖 3 条分支
|
||
now := time.Now()
|
||
apNilFail := &openAIWSAccountPool{prewarmFails: 1}
|
||
require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now))
|
||
apZeroTime := &openAIWSAccountPool{prewarmFails: 2}
|
||
require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now))
|
||
require.Equal(t, 0, apZeroTime.prewarmFails)
|
||
apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)}
|
||
require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now))
|
||
apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now}
|
||
require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now))
|
||
|
||
// recordConnPickDuration 的保护分支
|
||
nilPool.recordConnPickDuration(10 * time.Millisecond)
|
||
pool.recordConnPickDuration(-10 * time.Millisecond)
|
||
require.Equal(t, int64(1), pool.metrics.connPickTotal.Load())
|
||
|
||
// account pool 读写分支
|
||
require.Nil(t, nilPool.getOrCreateAccountPool(1))
|
||
require.Nil(t, pool.getOrCreateAccountPool(0))
|
||
pool.accounts.Store(int64(7), "invalid")
|
||
ap := pool.getOrCreateAccountPool(7)
|
||
require.NotNil(t, ap)
|
||
_, ok := pool.getAccountPool(0)
|
||
require.False(t, ok)
|
||
_, ok = pool.getAccountPool(12345)
|
||
require.False(t, ok)
|
||
pool.accounts.Store(int64(8), "bad-type")
|
||
_, ok = pool.getAccountPool(8)
|
||
require.False(t, ok)
|
||
|
||
// health check 条件
|
||
require.False(t, pool.shouldHealthCheckConn(nil))
|
||
conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil)
|
||
conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano())
|
||
require.True(t, pool.shouldHealthCheckConn(conn))
|
||
}
|
||
|
||
func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) {
|
||
var nilConn *openAIWSConn
|
||
nilConn.touch()
|
||
require.Equal(t, time.Time{}, nilConn.createdAt())
|
||
require.Equal(t, time.Time{}, nilConn.lastUsedAt())
|
||
require.Equal(t, time.Duration(0), nilConn.idleDuration(time.Now()))
|
||
require.Equal(t, time.Duration(0), nilConn.age(time.Now()))
|
||
require.False(t, nilConn.isLeased())
|
||
require.False(t, nilConn.isPrewarmed())
|
||
nilConn.markPrewarmed()
|
||
|
||
conn := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil)
|
||
require.True(t, conn.tryAcquire())
|
||
require.True(t, conn.isLeased())
|
||
conn.release()
|
||
require.False(t, conn.isLeased())
|
||
conn.close()
|
||
require.False(t, conn.tryAcquire())
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
cancel()
|
||
err := conn.acquire(ctx)
|
||
require.Error(t, err)
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) {
|
||
lease := &openAIWSConnLease{}
|
||
require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
|
||
require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
|
||
_, err := lease.ReadMessage(10 * time.Millisecond)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
_, err = lease.ReadMessageContext(context.Background())
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
_, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) {
|
||
conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil)
|
||
lease := &openAIWSConnLease{conn: conn}
|
||
|
||
require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
|
||
|
||
lease.Release()
|
||
lease.Release() // idempotent
|
||
|
||
require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
|
||
require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
|
||
require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
|
||
|
||
_, err := lease.ReadMessage(10 * time.Millisecond)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
_, err = lease.ReadMessageContext(context.Background())
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
_, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
|
||
require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed)
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) {
|
||
conn := newOpenAIWSConn("released_markbroken", 7, &openAIWSFakeConn{}, nil)
|
||
ap := &openAIWSAccountPool{
|
||
conns: map[string]*openAIWSConn{
|
||
conn.id: conn,
|
||
},
|
||
}
|
||
pool := &openAIWSConnPool{}
|
||
pool.accounts.Store(int64(7), ap)
|
||
|
||
lease := &openAIWSConnLease{
|
||
pool: pool,
|
||
accountID: 7,
|
||
conn: conn,
|
||
}
|
||
|
||
lease.Release()
|
||
lease.MarkBroken()
|
||
|
||
ap.mu.Lock()
|
||
_, exists := ap.conns[conn.id]
|
||
ap.mu.Unlock()
|
||
require.True(t, exists, "released lease should not evict active pool connection")
|
||
}
|
||
|
||
func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) {
|
||
var nilConn *openAIWSConn
|
||
require.False(t, nilConn.tryAcquire())
|
||
require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed)
|
||
nilConn.release()
|
||
nilConn.close()
|
||
require.Equal(t, "", nilConn.handshakeHeader("x-test"))
|
||
|
||
connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil)
|
||
require.True(t, connBusy.tryAcquire())
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
cancel()
|
||
require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled)
|
||
connBusy.release()
|
||
|
||
connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil)
|
||
connClosed.close()
|
||
require.ErrorIs(
|
||
t,
|
||
connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, time.Second),
|
||
errOpenAIWSConnClosed,
|
||
)
|
||
_, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second)
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
|
||
|
||
connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil)
|
||
require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed)
|
||
_, err = connNoWS.readMessage(context.Background())
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
|
||
require.Equal(t, "", connNoWS.handshakeHeader("x-test"))
|
||
|
||
connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil)
|
||
require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil))
|
||
_, err = connOK.readMessageWithContextTimeout(context.Background(), 0)
|
||
require.NoError(t, err)
|
||
require.NoError(t, connOK.pingWithTimeout(0))
|
||
|
||
connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil)
|
||
connZero.createdAtNano.Store(0)
|
||
connZero.lastUsedNano.Store(0)
|
||
require.True(t, connZero.createdAt().IsZero())
|
||
require.True(t, connZero.lastUsedAt().IsZero())
|
||
require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now()))
|
||
require.Equal(t, time.Duration(0), connZero.age(time.Now()))
|
||
|
||
require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil))
|
||
copied := cloneHeader(http.Header{
|
||
"X-Empty": []string{},
|
||
"X-Test": []string{"v1"},
|
||
})
|
||
require.Contains(t, copied, "X-Empty")
|
||
require.Nil(t, copied["X-Empty"])
|
||
require.Equal(t, "v1", copied.Get("X-Test"))
|
||
|
||
closeOpenAIWSConns([]*openAIWSConn{nil, connOK})
|
||
}
|
||
|
||
func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) {
|
||
pool := newOpenAIWSConnPool(&config.Config{})
|
||
accountID := int64(5001)
|
||
conn := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil)
|
||
ap := pool.getOrCreateAccountPool(accountID)
|
||
ap.mu.Lock()
|
||
ap.conns[conn.id] = conn
|
||
ap.mu.Unlock()
|
||
|
||
lease := &openAIWSConnLease{
|
||
pool: pool,
|
||
accountID: accountID,
|
||
conn: conn,
|
||
}
|
||
lease.MarkBroken()
|
||
|
||
ap.mu.Lock()
|
||
_, exists := ap.conns[conn.id]
|
||
ap.mu.Unlock()
|
||
require.False(t, exists)
|
||
require.False(t, conn.tryAcquire(), "被标记为 broken 的连接应被关闭")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
require.Equal(t, 0, pool.targetConnCountLocked(nil, 1))
|
||
ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
|
||
require.Equal(t, 0, pool.targetConnCountLocked(ap, 0))
|
||
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3
|
||
require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断")
|
||
|
||
// 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支
|
||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||
cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9
|
||
busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil)
|
||
require.True(t, busy.tryAcquire())
|
||
busy.waiters.Store(1)
|
||
ap.conns[busy.id] = busy
|
||
target := pool.targetConnCountLocked(ap, 4)
|
||
require.GreaterOrEqual(t, target, len(ap.conns)+1)
|
||
|
||
// prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回
|
||
req := openAIWSAcquireRequest{
|
||
Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
|
||
WSURL: "wss://example.com/v1/responses",
|
||
}
|
||
pool.prewarmConns(999, req, 1)
|
||
|
||
// prewarm: 拨号失败分支(prewarmFails 累加)
|
||
accountID := int64(1000)
|
||
failPool := newOpenAIWSConnPool(cfg)
|
||
failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{})
|
||
apFail := failPool.getOrCreateAccountPool(accountID)
|
||
apFail.mu.Lock()
|
||
apFail.creating = 1
|
||
apFail.mu.Unlock()
|
||
req.Account.ID = accountID
|
||
failPool.prewarmConns(accountID, req, 1)
|
||
apFail.mu.Lock()
|
||
require.GreaterOrEqual(t, apFail.prewarmFails, 1)
|
||
apFail.mu.Unlock()
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) {
|
||
var nilPool *openAIWSConnPool
|
||
_, err := nilPool.Acquire(context.Background(), openAIWSAcquireRequest{})
|
||
require.Error(t, err)
|
||
|
||
pool := newOpenAIWSConnPool(&config.Config{})
|
||
_, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: &Account{ID: 1},
|
||
WSURL: " ",
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "ws url is empty")
|
||
|
||
// target=nil 分支:池满且仅有 nil 连接
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
|
||
fullPool := newOpenAIWSConnPool(cfg)
|
||
account := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
ap := fullPool.getOrCreateAccountPool(account.ID)
|
||
ap.mu.Lock()
|
||
ap.conns["nil"] = nil
|
||
ap.lastCleanupAt = time.Now()
|
||
ap.mu.Unlock()
|
||
_, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
})
|
||
require.ErrorIs(t, err, errOpenAIWSConnClosed)
|
||
|
||
// queue full 分支:waiters 达上限
|
||
account2 := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
ap2 := fullPool.getOrCreateAccountPool(account2.ID)
|
||
conn := newOpenAIWSConn("queue_full", account2.ID, &openAIWSFakeConn{}, nil)
|
||
require.True(t, conn.tryAcquire())
|
||
conn.waiters.Store(1)
|
||
ap2.mu.Lock()
|
||
ap2.conns[conn.id] = conn
|
||
ap2.lastCleanupAt = time.Now()
|
||
ap2.mu.Unlock()
|
||
_, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account2,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
})
|
||
require.ErrorIs(t, err, errOpenAIWSConnQueueFull)
|
||
}
|
||
|
||
type openAIWSFakeDialer struct{}
|
||
|
||
func (d *openAIWSFakeDialer) Dial(
|
||
ctx context.Context,
|
||
wsURL string,
|
||
headers http.Header,
|
||
proxyURL string,
|
||
) (openAIWSClientConn, int, http.Header, error) {
|
||
_ = ctx
|
||
_ = wsURL
|
||
_ = headers
|
||
_ = proxyURL
|
||
return &openAIWSFakeConn{}, 0, nil, nil
|
||
}
|
||
|
||
// openAIWSCountingDialer is a test dialer that counts how many times Dial is
// invoked.
type openAIWSCountingDialer struct {
	mu        sync.Mutex // guards dialCount
	dialCount int        // number of Dial calls so far
}
|
||
|
||
// openAIWSAlwaysFailDialer is a test dialer for failure paths; it also counts
// how many times Dial is invoked.
type openAIWSAlwaysFailDialer struct {
	mu        sync.Mutex // guards dialCount
	dialCount int        // number of Dial calls so far
}
|
||
|
||
// openAIWSPingBlockingConn is a test connection whose Ping blocks until the
// release channel yields or the context is done, while tracking how many
// pings run concurrently.
type openAIWSPingBlockingConn struct {
	current       *atomic.Int32   // number of pings currently in flight
	maxConcurrent *atomic.Int32   // highest value current has reached
	release       <-chan struct{} // unblocks pending pings, typically by being closed
}

// WriteJSON is a no-op that always succeeds.
func (c *openAIWSPingBlockingConn) WriteJSON(context.Context, any) error {
	return nil
}

// ReadMessage returns a fixed response.completed payload.
func (c *openAIWSPingBlockingConn) ReadMessage(context.Context) ([]byte, error) {
	return []byte(`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`), nil
}

// Ping registers itself as in flight, raises the recorded concurrency peak if
// needed, and then blocks until release yields (returns nil) or ctx is done
// (returns ctx.Err()). It returns nil immediately when either counter is nil.
func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error {
	if c.current == nil || c.maxConcurrent == nil {
		return nil
	}

	inflight := c.current.Add(1)
	defer c.current.Add(-1)

	// Raise the peak with a CAS loop so a concurrent ping that stored a
	// higher value in the meantime is never overwritten.
	for {
		peak := c.maxConcurrent.Load()
		if inflight <= peak || c.maxConcurrent.CompareAndSwap(peak, inflight) {
			break
		}
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-c.release:
		return nil
	}
}

// Close is a no-op that always succeeds.
func (c *openAIWSPingBlockingConn) Close() error {
	return nil
}
|
||
|
||
func (d *openAIWSCountingDialer) Dial(
|
||
ctx context.Context,
|
||
wsURL string,
|
||
headers http.Header,
|
||
proxyURL string,
|
||
) (openAIWSClientConn, int, http.Header, error) {
|
||
_ = ctx
|
||
_ = wsURL
|
||
_ = headers
|
||
_ = proxyURL
|
||
d.mu.Lock()
|
||
d.dialCount++
|
||
d.mu.Unlock()
|
||
return &openAIWSFakeConn{}, 0, nil, nil
|
||
}
|
||
|
||
func (d *openAIWSCountingDialer) DialCount() int {
|
||
d.mu.Lock()
|
||
defer d.mu.Unlock()
|
||
return d.dialCount
|
||
}
|
||
|
||
func (d *openAIWSAlwaysFailDialer) Dial(
|
||
ctx context.Context,
|
||
wsURL string,
|
||
headers http.Header,
|
||
proxyURL string,
|
||
) (openAIWSClientConn, int, http.Header, error) {
|
||
_ = ctx
|
||
_ = wsURL
|
||
_ = headers
|
||
_ = proxyURL
|
||
d.mu.Lock()
|
||
d.dialCount++
|
||
d.mu.Unlock()
|
||
return nil, 503, nil, errors.New("dial failed")
|
||
}
|
||
|
||
func (d *openAIWSAlwaysFailDialer) DialCount() int {
|
||
d.mu.Lock()
|
||
defer d.mu.Unlock()
|
||
return d.dialCount
|
||
}
|
||
|
||
// openAIWSFakeConn is an in-memory test connection. Writes are recorded,
// reads return a fixed payload, and both fail once the connection is closed.
type openAIWSFakeConn struct {
	mu      sync.Mutex // guards closed and payload
	closed  bool       // set by Close; makes WriteJSON and ReadMessage fail
	payload [][]byte   // one "ok" marker appended per successful WriteJSON
}

// WriteJSON appends an "ok" marker to payload and returns nil, or returns a
// "closed" error if Close has been called. The context and value are ignored.
func (c *openAIWSFakeConn) WriteJSON(_ context.Context, _ any) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return errors.New("closed")
	}
	c.payload = append(c.payload, []byte("ok"))
	return nil
}

// ReadMessage returns a fixed response.completed payload, or a "closed" error
// if Close has been called. The context is ignored.
func (c *openAIWSFakeConn) ReadMessage(_ context.Context) ([]byte, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return nil, errors.New("closed")
	}
	return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil
}

// Ping always succeeds, even after Close.
func (c *openAIWSFakeConn) Ping(_ context.Context) error {
	return nil
}

// Close marks the connection closed and returns nil.
func (c *openAIWSFakeConn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.closed = true
	return nil
}
|
||
|
||
// openAIWSBlockingConn is a test connection whose ReadMessage waits for
// readDelay before returning, unless the context is done first.
type openAIWSBlockingConn struct {
	readDelay time.Duration // wait before a read returns; <= 0 means 10ms
}

// WriteJSON is a no-op that always succeeds.
func (c *openAIWSBlockingConn) WriteJSON(_ context.Context, _ any) error {
	return nil
}

// ReadMessage waits for readDelay (10ms when readDelay <= 0) and then returns
// a fixed response.completed payload. If ctx is done first it returns
// ctx.Err().
func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) {
	delay := c.readDelay
	if delay <= 0 {
		delay = 10 * time.Millisecond
	}
	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timer.C:
		return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil
	}
}

// Ping always succeeds.
func (c *openAIWSBlockingConn) Ping(_ context.Context) error {
	return nil
}

// Close is a no-op that always succeeds.
func (c *openAIWSBlockingConn) Close() error {
	return nil
}
|
||
|
||
// openAIWSWriteBlockingConn is a test connection whose WriteJSON blocks until
// the supplied context is done.
type openAIWSWriteBlockingConn struct{}

// WriteJSON blocks until ctx is done and then returns ctx.Err().
func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error {
	<-ctx.Done()
	return ctx.Err()
}

// ReadMessage returns a fixed response.completed payload.
func (c *openAIWSWriteBlockingConn) ReadMessage(context.Context) ([]byte, error) {
	return []byte(`{"type":"response.completed","response":{"id":"resp_write_block"}}`), nil
}

// Ping always succeeds.
func (c *openAIWSWriteBlockingConn) Ping(context.Context) error {
	return nil
}

// Close is a no-op that always succeeds.
func (c *openAIWSWriteBlockingConn) Close() error {
	return nil
}
|
||
|
||
// openAIWSPingFailConn is a test connection whose Ping always fails while
// reads and writes succeed.
type openAIWSPingFailConn struct{}

// WriteJSON is a no-op that always succeeds.
func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error {
	return nil
}

// ReadMessage returns a fixed response.completed payload.
func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) {
	return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil
}

// Ping always returns a "ping failed" error.
func (c *openAIWSPingFailConn) Ping(context.Context) error {
	return errors.New("ping failed")
}

// Close is a no-op that always succeeds.
func (c *openAIWSPingFailConn) Close() error {
	return nil
}
|
||
|
||
// openAIWSContextProbeConn is a test connection that records the context
// passed to the most recent WriteJSON call so tests can inspect it.
type openAIWSContextProbeConn struct {
	lastWriteCtx context.Context // context seen by the latest WriteJSON; written without synchronization
}

// WriteJSON stores ctx in lastWriteCtx and returns nil.
func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error {
	c.lastWriteCtx = ctx
	return nil
}

// ReadMessage returns a fixed response.completed payload.
func (c *openAIWSContextProbeConn) ReadMessage(context.Context) ([]byte, error) {
	return []byte(`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`), nil
}

// Ping always succeeds.
func (c *openAIWSContextProbeConn) Ping(context.Context) error {
	return nil
}

// Close is a no-op that always succeeds.
func (c *openAIWSContextProbeConn) Close() error {
	return nil
}
|
||
|
||
type openAIWSNilConnDialer struct{}
|
||
|
||
func (d *openAIWSNilConnDialer) Dial(
|
||
ctx context.Context,
|
||
wsURL string,
|
||
headers http.Header,
|
||
proxyURL string,
|
||
) (openAIWSClientConn, int, http.Header, error) {
|
||
_ = ctx
|
||
_ = wsURL
|
||
_ = headers
|
||
_ = proxyURL
|
||
return nil, 200, nil, nil
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
|
||
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
pool.setClientDialerForTest(&openAIWSNilConnDialer{})
|
||
account := &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||
|
||
_, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
|
||
Account: account,
|
||
WSURL: "wss://example.com/v1/responses",
|
||
})
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "nil connection")
|
||
}
|
||
|
||
func TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) {
|
||
cfg := &config.Config{}
|
||
pool := newOpenAIWSConnPool(cfg)
|
||
|
||
dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer)
|
||
require.True(t, ok)
|
||
|
||
_, err := dialer.proxyHTTPClient("http://127.0.0.1:28080")
|
||
require.NoError(t, err)
|
||
_, err = dialer.proxyHTTPClient("http://127.0.0.1:28080")
|
||
require.NoError(t, err)
|
||
_, err = dialer.proxyHTTPClient("http://127.0.0.1:28081")
|
||
require.NoError(t, err)
|
||
|
||
snapshot := pool.SnapshotTransportMetrics()
|
||
require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
|
||
require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
|
||
require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
|
||
}
|