Files
sub2api/backend/internal/service/openai_ws_pool.go
2026-02-28 15:01:20 +08:00

1707 lines
40 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"golang.org/x/sync/errgroup"
)
// Tunables for the OpenAI WebSocket connection pool.
const (
	// Default connection age limit. NOTE(review): presumably consumed by
	// maxConnAge, which is defined outside this part of the file - confirm.
	openAIWSConnMaxAge = 60 * time.Minute
	// A connection idle for at least this long is pinged before it is handed
	// out (see shouldHealthCheckConn).
	openAIWSConnHealthCheckIdle = 90 * time.Second
	// Timeout for health-check pings; also the default used by
	// pingWithTimeout when the caller passes a non-positive timeout.
	openAIWSConnHealthCheckTO = 2 * time.Second
	// Extra time added on top of dialTimeout() for prewarm dials.
	openAIWSConnPrewarmExtraDelay = 2 * time.Second
	// Minimum interval between cleanup passes triggered from acquire.
	openAIWSAcquireCleanupInterval = 3 * time.Second
	// Period of the background ping worker.
	openAIWSBackgroundPingInterval = 30 * time.Second
	// Period of the background cleanup worker.
	openAIWSBackgroundSweepTicker = 30 * time.Second
	// NOTE(review): the two values below are not referenced in this part of
	// the file; presumably they drive shouldSuppressPrewarmLocked - confirm.
	openAIWSPrewarmFailureWindow = 30 * time.Second
	openAIWSPrewarmFailureSuppress = 2
)
// Sentinel errors returned by the connection pool; compare with errors.Is.
var (
	// errOpenAIWSConnClosed: the connection (or lease) is closed, released or missing.
	errOpenAIWSConnClosed = errors.New("openai ws connection closed")
	// errOpenAIWSConnQueueFull: no connection can be created and the wait queue is at its limit.
	errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full")
	// errOpenAIWSPreferredConnUnavailable: ForcePreferredConn was set but the preferred connection is not in the pool.
	errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable")
)
// openAIWSDialError is returned by dialConn when the WebSocket dial fails or
// yields no connection.
type openAIWSDialError struct {
	// StatusCode is the status reported by the dialer; Error only prints it when > 0.
	StatusCode int
	// ResponseHeaders is a copy of the handshake response headers reported by the dialer.
	ResponseHeaders http.Header
	// Err is the underlying error, exposed through Unwrap.
	Err error
}
// Error formats the dial failure. A nil receiver yields the empty string; the
// status code is included only when it is positive.
func (e *openAIWSDialError) Error() string {
	switch {
	case e == nil:
		return ""
	case e.StatusCode > 0:
		return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err)
	default:
		return fmt.Sprintf("openai ws dial failed: %v", e.Err)
	}
}
// Unwrap exposes the underlying error for errors.Is / errors.As. A nil
// receiver unwraps to nil.
func (e *openAIWSDialError) Unwrap() error {
	if e != nil {
		return e.Err
	}
	return nil
}
// openAIWSAcquireRequest carries everything Acquire needs to pick or dial a
// connection for one account.
type openAIWSAcquireRequest struct {
	// Account owns the connection; it must be non-nil with a positive ID.
	Account *Account
	// WSURL is the dial target; it must not be blank.
	WSURL string
	// Headers and ProxyURL are passed to the dialer unchanged.
	Headers http.Header
	ProxyURL string
	// PreferredConnID names the connection to try first, if any.
	PreferredConnID string
	// ForceNewConn: force this acquire to obtain a new connection (avoids
	// reuse cross-contaminating the continuation state kept on a connection).
	ForceNewConn bool
	// ForcePreferredConn: force this acquire to use only PreferredConnID and
	// never drift to another connection. Ignored when ForceNewConn is set.
	ForcePreferredConn bool
}
// openAIWSConnLease is the exclusive handle returned by Acquire. All I/O goes
// through it until Release is called.
type openAIWSConnLease struct {
	// pool and accountID locate the connection for eviction (MarkBroken).
	pool *openAIWSConnPool
	accountID int64
	// conn is the leased connection.
	conn *openAIWSConn
	// queueWait is the time spent queued behind another lease holder (zero when not queued).
	queueWait time.Duration
	// connPick is the time spent selecting the connection under the account lock.
	connPick time.Duration
	// reused is true when an existing connection was handed out rather than a freshly dialed one.
	reused bool
	// released flips to true on the first Release; later I/O calls fail with errOpenAIWSConnClosed.
	released atomic.Bool
}
// activeConn returns the leased connection, or errOpenAIWSConnClosed when the
// lease is nil, has no connection, or has already been released.
func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) {
	usable := l != nil && l.conn != nil && !l.released.Load()
	if !usable {
		return nil, errOpenAIWSConnClosed
	}
	return l.conn, nil
}
// ConnID returns the ID of the leased connection, or "" when there is none.
func (l *openAIWSConnLease) ConnID() string {
	if l != nil && l.conn != nil {
		return l.conn.id
	}
	return ""
}
// QueueWaitDuration reports how long the acquire waited in a connection's
// queue; zero for a nil lease.
func (l *openAIWSConnLease) QueueWaitDuration() time.Duration {
	if l != nil {
		return l.queueWait
	}
	return 0
}
// ConnPickDuration reports how long connection selection took; zero for a nil
// lease.
func (l *openAIWSConnLease) ConnPickDuration() time.Duration {
	if l != nil {
		return l.connPick
	}
	return 0
}
// Reused reports whether the lease was served from an existing connection.
// A nil lease reports false.
func (l *openAIWSConnLease) Reused() bool {
	return l != nil && l.reused
}
// HandshakeHeader returns the named handshake response header of the leased
// connection, or "" when the lease has no connection.
func (l *openAIWSConnLease) HandshakeHeader(name string) string {
	if l != nil && l.conn != nil {
		return l.conn.handshakeHeader(name)
	}
	return ""
}
// IsPrewarmed reports whether the leased connection carries the prewarmed
// flag. A lease without a connection reports false.
func (l *openAIWSConnLease) IsPrewarmed() bool {
	return l != nil && l.conn != nil && l.conn.isPrewarmed()
}
// MarkPrewarmed sets the prewarmed flag on the leased connection. It is a
// no-op for a lease without a connection.
func (l *openAIWSConnLease) MarkPrewarmed() {
	if l != nil && l.conn != nil {
		l.conn.markPrewarmed()
	}
}
// WriteJSON writes value on the leased connection using a background context
// bounded by timeout (a non-positive timeout means no deadline). It fails with
// errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.writeJSONWithTimeout(context.Background(), value, timeout)
	}
	return leaseErr
}
// WriteJSONWithContextTimeout writes value on the leased connection under ctx,
// additionally bounded by timeout when it is positive. It fails with
// errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.writeJSONWithTimeout(ctx, value, timeout)
	}
	return leaseErr
}
// WriteJSONContext writes value on the leased connection under ctx with no
// extra timeout. It fails with errOpenAIWSConnClosed when the lease is no
// longer active.
func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.writeJSON(value, ctx)
	}
	return leaseErr
}
// ReadMessage reads one message from the leased connection, bounded by
// timeout when it is positive. It fails with errOpenAIWSConnClosed when the
// lease is no longer active.
func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.readMessageWithTimeout(timeout)
	}
	return nil, leaseErr
}
// ReadMessageContext reads one message from the leased connection under ctx.
// It fails with errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.readMessage(ctx)
	}
	return nil, leaseErr
}
// ReadMessageWithContextTimeout reads one message from the leased connection
// under ctx, additionally bounded by timeout when it is positive. It fails
// with errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.readMessageWithContextTimeout(ctx, timeout)
	}
	return nil, leaseErr
}
// PingWithTimeout pings the leased connection. It fails with
// errOpenAIWSConnClosed when the lease is no longer active.
func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.pingWithTimeout(timeout)
	}
	return leaseErr
}
// MarkBroken evicts the leased connection from its pool. It does nothing when
// the lease is incomplete or has already been released.
func (l *openAIWSConnLease) MarkBroken() {
	if l == nil || l.pool == nil || l.conn == nil {
		return
	}
	if l.released.Load() {
		return
	}
	l.pool.evictConn(l.accountID, l.conn.id)
}
// Release returns the connection's lease token. Only the first call has any
// effect; later calls are no-ops.
func (l *openAIWSConnLease) Release() {
	if l == nil || l.conn == nil {
		return
	}
	// The compare-and-swap makes sure the token is returned exactly once.
	if l.released.CompareAndSwap(false, true) {
		l.conn.release()
	}
}
// openAIWSConn is one pooled WebSocket connection together with the state
// used to lease it to a single user at a time.
type openAIWSConn struct {
	// id is the pool-wide connection ID produced by nextConnID.
	id string
	// ws is the underlying client connection.
	ws openAIWSClientConn
	// handshakeHeaders is a copy of the handshake response headers.
	handshakeHeaders http.Header
	// leaseCh has capacity 1 and holds a token while the connection is free;
	// taking the token is what leases the connection.
	leaseCh chan struct{}
	// closedCh is closed exactly once by close().
	closedCh chan struct{}
	closeOnce sync.Once
	// readMu and writeMu serialize reads and writes (including pings) on ws.
	readMu sync.Mutex
	writeMu sync.Mutex
	// waiters counts goroutines queued for this connection's lease.
	waiters atomic.Int32
	// createdAtNano and lastUsedNano are Unix-nanosecond timestamps.
	createdAtNano atomic.Int64
	lastUsedNano atomic.Int64
	// prewarmed is set through markPrewarmed.
	prewarmed atomic.Bool
}
// newOpenAIWSConn wraps ws as a free (unleased) pooled connection. The
// handshake headers are copied; the second parameter is unused.
func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn {
	createdNano := time.Now().UnixNano()

	// Seed the lease token so the first tryAcquire succeeds.
	lease := make(chan struct{}, 1)
	lease <- struct{}{}

	c := &openAIWSConn{
		id:               id,
		ws:               ws,
		handshakeHeaders: cloneHeader(handshakeHeaders),
		leaseCh:          lease,
		closedCh:         make(chan struct{}),
	}
	c.createdAtNano.Store(createdNano)
	c.lastUsedNano.Store(createdNano)
	return c
}
// tryAcquire takes the lease token without blocking. It returns false when
// the connection is nil, closed, or already leased.
func (c *openAIWSConn) tryAcquire() bool {
	if c == nil {
		return false
	}
	closed := func() bool {
		select {
		case <-c.closedCh:
			return true
		default:
			return false
		}
	}
	if closed() {
		return false
	}
	select {
	case <-c.leaseCh:
	default:
		return false
	}
	// The connection may have been closed between the two checks; hand the
	// token back so nothing is left looking leased.
	if closed() {
		c.release()
		return false
	}
	return true
}
// acquire blocks until the lease token is available, the connection is
// closed (errOpenAIWSConnClosed), or ctx is done (ctx.Err()).
func (c *openAIWSConn) acquire(ctx context.Context) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	case <-c.leaseCh:
	}
	// Token taken; re-check for a close that raced with the receive.
	select {
	case <-c.closedCh:
		c.release()
		return errOpenAIWSConnClosed
	default:
		return nil
	}
}
// release puts the lease token back (without blocking if one is already
// there) and refreshes the last-used timestamp.
func (c *openAIWSConn) release() {
	if c != nil {
		select {
		case c.leaseCh <- struct{}{}:
		default:
			// A token is already present; nothing to return.
		}
		c.touch()
	}
}
// close shuts the connection down exactly once: it closes closedCh, closes
// the underlying socket (ignoring its error), and offers a lease token.
func (c *openAIWSConn) close() {
	if c == nil {
		return
	}
	shutdown := func() {
		close(c.closedCh)
		if c.ws != nil {
			_ = c.ws.Close()
		}
		// Non-blocking send: make a token available if none is buffered.
		select {
		case c.leaseCh <- struct{}{}:
		default:
		}
	}
	c.closeOnce.Do(shutdown)
}
// writeJSONWithTimeout writes value under parent (background when nil),
// additionally bounded by timeout when it is positive. A nil or closed
// connection yields errOpenAIWSConnClosed.
func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	default:
	}
	ctx := parent
	if ctx == nil {
		ctx = context.Background()
	}
	if timeout > 0 {
		bounded, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		ctx = bounded
	}
	return c.writeJSON(value, ctx)
}
// writeJSON serializes writers on writeMu and writes value to the socket
// under writeCtx (background when nil). A successful write refreshes the
// last-used timestamp.
//
// A nil receiver or a connection without a socket yields
// errOpenAIWSConnClosed; any other error comes from the underlying WriteJSON.
func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error {
	// Guard the receiver like every other method on this type; without it a
	// nil connection would panic on the mutex below.
	if c == nil {
		return errOpenAIWSConnClosed
	}
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if c.ws == nil {
		return errOpenAIWSConnClosed
	}
	if writeCtx == nil {
		writeCtx = context.Background()
	}
	if err := c.ws.WriteJSON(writeCtx, value); err != nil {
		return err
	}
	c.touch()
	return nil
}
// readMessageWithTimeout reads one message under a background context bounded
// by timeout (no deadline when timeout is not positive).
func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) {
	background := context.Background()
	return c.readMessageWithContextTimeout(background, timeout)
}
// readMessageWithContextTimeout reads one message under parent (background
// when nil), additionally bounded by timeout when it is positive. A nil or
// closed connection yields errOpenAIWSConnClosed.
func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) {
	if c == nil {
		return nil, errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return nil, errOpenAIWSConnClosed
	default:
	}
	ctx := parent
	if ctx == nil {
		ctx = context.Background()
	}
	if timeout > 0 {
		bounded, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		ctx = bounded
	}
	return c.readMessage(ctx)
}
// readMessage serializes readers on readMu and reads one message from the
// socket under readCtx (background when nil). A successful read refreshes the
// last-used timestamp.
//
// A nil receiver or a connection without a socket yields
// errOpenAIWSConnClosed; any other error comes from the underlying
// ReadMessage.
func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) {
	// Guard the receiver like every other method on this type; without it a
	// nil connection would panic on the mutex below.
	if c == nil {
		return nil, errOpenAIWSConnClosed
	}
	c.readMu.Lock()
	defer c.readMu.Unlock()
	if c.ws == nil {
		return nil, errOpenAIWSConnClosed
	}
	if readCtx == nil {
		readCtx = context.Background()
	}
	payload, err := c.ws.ReadMessage(readCtx)
	if err != nil {
		return nil, err
	}
	c.touch()
	return payload, nil
}
// pingWithTimeout sends a ping while holding the write lock. A non-positive
// timeout falls back to openAIWSConnHealthCheckTO. A nil, closed, or
// socket-less connection yields errOpenAIWSConnClosed. The last-used
// timestamp is not refreshed.
func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	default:
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if c.ws == nil {
		return errOpenAIWSConnClosed
	}

	limit := timeout
	if limit <= 0 {
		limit = openAIWSConnHealthCheckTO
	}
	pingCtx, cancel := context.WithTimeout(context.Background(), limit)
	defer cancel()
	return c.ws.Ping(pingCtx)
}
// touch records the current time as the connection's last-used timestamp.
func (c *openAIWSConn) touch() {
	if c != nil {
		c.lastUsedNano.Store(time.Now().UnixNano())
	}
}
// createdAt returns the creation time, or the zero time when the connection
// is nil or the stored timestamp is not positive.
func (c *openAIWSConn) createdAt() time.Time {
	if c == nil {
		return time.Time{}
	}
	if nano := c.createdAtNano.Load(); nano > 0 {
		return time.Unix(0, nano)
	}
	return time.Time{}
}
// lastUsedAt returns the last-used time, or the zero time when the connection
// is nil or the stored timestamp is not positive.
func (c *openAIWSConn) lastUsedAt() time.Time {
	if c == nil {
		return time.Time{}
	}
	if nano := c.lastUsedNano.Load(); nano > 0 {
		return time.Unix(0, nano)
	}
	return time.Time{}
}
// idleDuration returns how long before now the connection was last used, or
// zero when that is unknown.
func (c *openAIWSConn) idleDuration(now time.Time) time.Duration {
	if c == nil {
		return 0
	}
	if last := c.lastUsedAt(); !last.IsZero() {
		return now.Sub(last)
	}
	return 0
}
// age returns how long before now the connection was created, or zero when
// that is unknown.
func (c *openAIWSConn) age(now time.Time) time.Duration {
	if c == nil {
		return 0
	}
	if born := c.createdAt(); !born.IsZero() {
		return now.Sub(born)
	}
	return 0
}
// isLeased reports whether the lease token is currently taken, i.e. the
// lease channel is empty. A nil connection is never leased.
func (c *openAIWSConn) isLeased() bool {
	return c != nil && len(c.leaseCh) == 0
}
// handshakeHeader returns the whitespace-trimmed value of the named handshake
// response header (the name itself is trimmed first), or "" when unavailable.
func (c *openAIWSConn) handshakeHeader(name string) string {
	if c == nil || c.handshakeHeaders == nil {
		return ""
	}
	key := strings.TrimSpace(name)
	value := c.handshakeHeaders.Get(key)
	return strings.TrimSpace(value)
}
// isPrewarmed reports whether the prewarmed flag is set. A nil connection
// reports false.
func (c *openAIWSConn) isPrewarmed() bool {
	return c != nil && c.prewarmed.Load()
}
// markPrewarmed sets the prewarmed flag. It is a no-op on a nil connection.
func (c *openAIWSConn) markPrewarmed() {
	if c != nil {
		c.prewarmed.Store(true)
	}
}
// openAIWSAccountPool holds the connections and bookkeeping for one account.
type openAIWSAccountPool struct {
	// mu guards every field below.
	mu sync.Mutex
	// conns maps connection ID to connection.
	conns map[string]*openAIWSConn
	// pinnedConns maps connection ID to its pin count; pinned connections are
	// skipped by cleanup and oldest-idle selection.
	pinnedConns map[string]int
	// creating counts dials in progress (including slots reserved for
	// prewarming); it is added to len(conns) when checking the limit.
	creating int
	// lastCleanupAt is when cleanupAccountLocked last ran for this account.
	lastCleanupAt time.Time
	// lastAcquire is a copy of the most recent acquire request; prewarming
	// reuses it to dial.
	lastAcquire *openAIWSAcquireRequest
	// prewarmActive is true while a prewarm goroutine is running.
	prewarmActive bool
	// prewarmUntil is the end of the prewarm cooldown; no new prewarm starts before it.
	prewarmUntil time.Time
	// prewarmFails / prewarmFailAt track dial failures; they are reset on a
	// successful dial.
	prewarmFails int
	prewarmFailAt time.Time
}
// OpenAIWSPoolMetricsSnapshot is a point-in-time copy of the pool counters,
// as returned by SnapshotMetrics.
type OpenAIWSPoolMetricsSnapshot struct {
	// AcquireTotal counts Acquire calls.
	AcquireTotal int64
	// AcquireReuseTotal counts acquires served by an existing connection.
	AcquireReuseTotal int64
	// AcquireCreateTotal counts acquires served by a freshly dialed connection.
	AcquireCreateTotal int64
	// AcquireQueueWaitTotal counts acquires that queued; AcquireQueueWaitMsTotal
	// is the summed wait in milliseconds of those that then obtained a lease.
	AcquireQueueWaitTotal int64
	AcquireQueueWaitMsTotal int64
	// ConnPickTotal counts recorded connection picks; ConnPickMsTotal is their
	// summed duration in milliseconds.
	ConnPickTotal int64
	ConnPickMsTotal int64
	// ScaleUpTotal counts connections requested by prewarming.
	ScaleUpTotal int64
	// ScaleDownTotal counts idle connections evicted to shrink the pool.
	ScaleDownTotal int64
}
// openAIWSPoolMetrics holds the live atomic counters behind
// OpenAIWSPoolMetricsSnapshot; SnapshotMetrics copies each field into the
// snapshot field of the matching name.
type openAIWSPoolMetrics struct {
	acquireTotal atomic.Int64
	acquireReuseTotal atomic.Int64
	acquireCreateTotal atomic.Int64
	acquireQueueWaitTotal atomic.Int64
	// acquireQueueWaitMs is exported as AcquireQueueWaitMsTotal.
	acquireQueueWaitMs atomic.Int64
	connPickTotal atomic.Int64
	// connPickMs is exported as ConnPickMsTotal.
	connPickMs atomic.Int64
	scaleUpTotal atomic.Int64
	scaleDownTotal atomic.Int64
}
// openAIWSConnPool is the per-process pool of OpenAI WebSocket connections,
// partitioned by account.
type openAIWSConnPool struct {
	cfg *config.Config
	// The underlying WS client implementation is decoupled behind an
	// interface; coder/websocket is used by default.
	clientDialer openAIWSClientDialer
	accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool
	// seq feeds nextConnID.
	seq atomic.Uint64
	metrics openAIWSPoolMetrics
	// workerStopCh is closed by Close to stop the background workers, which
	// are tracked by workerWg.
	workerStopCh chan struct{}
	workerWg sync.WaitGroup
	// closeOnce makes Close idempotent.
	closeOnce sync.Once
}
// newOpenAIWSConnPool builds a pool around cfg with the default client dialer
// and starts its background workers. Call Close to stop them.
func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool {
	p := new(openAIWSConnPool)
	p.cfg = cfg
	p.clientDialer = newDefaultOpenAIWSClientDialer()
	p.workerStopCh = make(chan struct{})
	p.startBackgroundWorkers()
	return p
}
// SnapshotMetrics returns the current value of every pool counter. Each
// counter is loaded individually, so the snapshot is not a single atomic view.
// A nil pool yields the zero snapshot.
func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot {
	var snap OpenAIWSPoolMetricsSnapshot
	if p == nil {
		return snap
	}
	m := &p.metrics
	snap.AcquireTotal = m.acquireTotal.Load()
	snap.AcquireReuseTotal = m.acquireReuseTotal.Load()
	snap.AcquireCreateTotal = m.acquireCreateTotal.Load()
	snap.AcquireQueueWaitTotal = m.acquireQueueWaitTotal.Load()
	snap.AcquireQueueWaitMsTotal = m.acquireQueueWaitMs.Load()
	snap.ConnPickTotal = m.connPickTotal.Load()
	snap.ConnPickMsTotal = m.connPickMs.Load()
	snap.ScaleUpTotal = m.scaleUpTotal.Load()
	snap.ScaleDownTotal = m.scaleDownTotal.Load()
	return snap
}
// SnapshotTransportMetrics returns the dialer's transport metrics when the
// dialer implements openAIWSTransportMetricsDialer, and the zero snapshot
// otherwise (including for a nil pool).
func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
	if p != nil {
		if reporter, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok {
			return reporter.SnapshotTransportMetrics()
		}
	}
	return OpenAIWSTransportMetricsSnapshot{}
}
// setClientDialerForTest swaps the client dialer. A nil pool or nil dialer
// leaves everything unchanged.
func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) {
	if p != nil && dialer != nil {
		p.clientDialer = dialer
	}
}
// Close stops the background workers and closes every idle connection; call
// it during graceful shutdown. Connections that are currently leased are left
// open. Only the first call does anything.
func (p *openAIWSConnPool) Close() {
	if p == nil {
		return
	}
	// closeIdle closes the unleased connections of one account pool.
	closeIdle := func(_, value any) bool {
		ap, ok := value.(*openAIWSAccountPool)
		if !ok || ap == nil {
			return true
		}
		ap.mu.Lock()
		defer ap.mu.Unlock()
		for _, conn := range ap.conns {
			if conn == nil || conn.isLeased() {
				continue
			}
			conn.close()
		}
		return true
	}
	p.closeOnce.Do(func() {
		if p.workerStopCh != nil {
			close(p.workerStopCh)
		}
		// Wait for the workers first so no sweep runs while connections are
		// being closed below.
		p.workerWg.Wait()
		p.accounts.Range(closeIdle)
	})
}
// startBackgroundWorkers launches the ping worker and the cleanup worker,
// both tracked by workerWg. It does nothing without a stop channel.
func (p *openAIWSConnPool) startBackgroundWorkers() {
	if p == nil || p.workerStopCh == nil {
		return
	}
	workers := []func(){
		p.runBackgroundPingWorker,
		p.runBackgroundCleanupWorker,
	}
	p.workerWg.Add(len(workers))
	for _, run := range workers {
		go func(run func()) {
			defer p.workerWg.Done()
			run()
		}(run)
	}
}
// openAIWSIdlePingCandidate pairs an idle connection with its owning account
// so the ping sweep can evict it if the ping fails.
type openAIWSIdlePingCandidate struct {
	accountID int64
	conn *openAIWSConn
}
// runBackgroundPingWorker runs a ping sweep every
// openAIWSBackgroundPingInterval until workerStopCh is closed.
func (p *openAIWSConnPool) runBackgroundPingWorker() {
	if p == nil {
		return
	}
	tick := time.NewTicker(openAIWSBackgroundPingInterval)
	defer tick.Stop()
	for {
		select {
		case <-p.workerStopCh:
			return
		case <-tick.C:
			p.runBackgroundPingSweep()
		}
	}
}
// runBackgroundPingSweep pings every idle connection (at most 10
// concurrently) and evicts the ones whose ping fails. It returns once all
// pings have finished.
func (p *openAIWSConnPool) runBackgroundPingSweep() {
	if p == nil {
		return
	}
	var g errgroup.Group
	g.SetLimit(10)
	for _, candidate := range p.snapshotIdleConnsForPing() {
		conn, accountID := candidate.conn, candidate.accountID
		if conn == nil {
			continue
		}
		// Re-check: the connection may have been leased or queued on since
		// the snapshot was taken.
		if conn.isLeased() || conn.waiters.Load() > 0 {
			continue
		}
		g.Go(func() error {
			if conn.pingWithTimeout(openAIWSConnHealthCheckTO) != nil {
				p.evictConn(accountID, conn.id)
			}
			return nil
		})
	}
	_ = g.Wait()
}
// snapshotIdleConnsForPing collects, across all accounts, the connections
// that are neither leased nor waited on. Each account is inspected under its
// own lock. The result is non-nil (possibly empty) for a non-nil pool.
func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate {
	if p == nil {
		return nil
	}
	idle := []openAIWSIdlePingCandidate{}
	p.accounts.Range(func(key, value any) bool {
		id, isID := key.(int64)
		if !isID || id <= 0 {
			return true
		}
		ap, isPool := value.(*openAIWSAccountPool)
		if !isPool || ap == nil {
			return true
		}
		ap.mu.Lock()
		defer ap.mu.Unlock()
		for _, c := range ap.conns {
			if c != nil && !c.isLeased() && c.waiters.Load() <= 0 {
				idle = append(idle, openAIWSIdlePingCandidate{accountID: id, conn: c})
			}
		}
		return true
	})
	return idle
}
// runBackgroundCleanupWorker runs a cleanup sweep every
// openAIWSBackgroundSweepTicker until workerStopCh is closed.
func (p *openAIWSConnPool) runBackgroundCleanupWorker() {
	if p == nil {
		return
	}
	tick := time.NewTicker(openAIWSBackgroundSweepTicker)
	defer tick.Stop()
	for {
		select {
		case <-p.workerStopCh:
			return
		case <-tick.C:
			p.runBackgroundCleanupSweep(time.Now())
		}
	}
}
// runBackgroundCleanupSweep runs cleanupAccountLocked for every account and
// then closes the evicted connections after all account locks have been
// released. The per-account limit comes from the last acquire's account when
// one is known, otherwise from the hard cap.
func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) {
	if p == nil {
		return
	}
	var batches [][]*openAIWSConn
	p.accounts.Range(func(_, value any) bool {
		ap, ok := value.(*openAIWSAccountPool)
		if !ok || ap == nil {
			return true
		}
		limit := p.maxConnsHardCap()
		ap.mu.Lock()
		if last := ap.lastAcquire; last != nil && last.Account != nil {
			limit = p.effectiveMaxConnsByAccount(last.Account)
		}
		evicted := p.cleanupAccountLocked(ap, now, limit)
		ap.lastCleanupAt = now
		ap.mu.Unlock()
		if len(evicted) > 0 {
			batches = append(batches, evicted)
		}
		return true
	})
	for _, batch := range batches {
		closeOpenAIWSConns(batch)
	}
}
// Acquire leases a connection for req.Account. It counts the attempt and
// hands a copy of req to acquire with a retry count of zero.
func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) {
	if p != nil {
		p.metrics.acquireTotal.Add(1)
	}
	cloned := cloneOpenAIWSAcquireRequest(req)
	return p.acquire(ctx, cloned, 0)
}
// acquire is the implementation behind Acquire. After validating the request
// it works under the account pool's lock and tries, in order:
//
//  1. ForcePreferredConn (ignored when ForceNewConn is set): lease exactly
//     PreferredConnID, queueing on it when it is busy; fail with
//     errOpenAIWSPreferredConnUnavailable when it is not in the pool.
//  2. Reuse (unless ForceNewConn): the preferred connection if free, else
//     the least busy connection if free, else any free connection.
//  3. Dial a new connection while len(conns)+creating is below the account
//     limit. With ForceNewConn, the oldest idle connection is evicted first
//     to make room.
//  4. Otherwise queue on the least busy connection. ForceNewConn never
//     queues and fails with errOpenAIWSConnQueueFull instead.
//
// retry is the number of restarts already performed; a failed health check
// or a connection that closes while queued restarts the whole procedure at
// most once.
//
// Errors: a plain error for an invalid request, errOpenAIWSConnQueueFull,
// errOpenAIWSPreferredConnUnavailable, errOpenAIWSConnClosed, the dial error
// from dialConn, a ping error, or ctx.Err() while queued.
func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) {
	if p == nil || req.Account == nil || req.Account.ID <= 0 {
		return nil, errors.New("invalid ws acquire request")
	}
	if stringsTrim(req.WSURL) == "" {
		return nil, errors.New("ws url is empty")
	}
	accountID := req.Account.ID
	effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account)
	if effectiveMaxConns <= 0 {
		return nil, errOpenAIWSConnQueueFull
	}

	var evicted []*openAIWSConn
	ap := p.getOrCreateAccountPool(accountID)
	ap.mu.Lock()
	ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req)
	now := time.Now()
	if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval {
		evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns)
		ap.lastCleanupAt = now
	}

	pickStartedAt := time.Now()
	allowReuse := !req.ForceNewConn
	preferredConnID := stringsTrim(req.PreferredConnID)

	// fail ends the attempt while ap.mu is still held: it records the pick
	// duration, releases the lock and closes whatever cleanup evicted.
	fail := func(err error) (*openAIWSConnLease, error) {
		p.recordConnPickDuration(time.Since(pickStartedAt))
		ap.mu.Unlock()
		closeOpenAIWSConns(evicted)
		return nil, err
	}

	if allowReuse && req.ForcePreferredConn {
		var preferredConn *openAIWSConn
		if preferredConnID != "" {
			preferredConn = ap.conns[preferredConnID]
		}
		if preferredConn == nil {
			return fail(errOpenAIWSPreferredConnUnavailable)
		}
		if preferredConn.tryAcquire() {
			return p.leaseReusedConnLocked(ctx, req, retry, ap, preferredConn, evicted, pickStartedAt)
		}
		return p.waitForConnLeaseLocked(ctx, req, retry, ap, preferredConn, evicted, pickStartedAt)
	}

	if allowReuse {
		if preferredConnID != "" {
			if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() {
				return p.leaseReusedConnLocked(ctx, req, retry, ap, conn, evicted, pickStartedAt)
			}
		}
		best := p.pickLeastBusyConnLocked(ap, "")
		if best != nil && best.tryAcquire() {
			return p.leaseReusedConnLocked(ctx, req, retry, ap, best, evicted, pickStartedAt)
		}
		for _, conn := range ap.conns {
			if conn == nil || conn == best {
				continue
			}
			if conn.tryAcquire() {
				return p.leaseReusedConnLocked(ctx, req, retry, ap, conn, evicted, pickStartedAt)
			}
		}
	}

	// A forced new connection may displace the oldest idle one when the
	// account is already at its limit.
	if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns {
		if idle := p.pickOldestIdleConnLocked(ap); idle != nil {
			delete(ap.conns, idle.id)
			evicted = append(evicted, idle)
			p.metrics.scaleDownTotal.Add(1)
		}
	}

	if len(ap.conns)+ap.creating < effectiveMaxConns {
		connPick := time.Since(pickStartedAt)
		p.recordConnPickDuration(connPick)
		ap.creating++ // reserve the slot while dialing without the lock
		ap.mu.Unlock()
		closeOpenAIWSConns(evicted)

		conn, dialErr := p.dialConn(ctx, req)

		ap = p.getOrCreateAccountPool(accountID)
		ap.mu.Lock()
		ap.creating--
		if dialErr != nil {
			ap.prewarmFails++
			ap.prewarmFailAt = time.Now()
			ap.mu.Unlock()
			return nil, dialErr
		}
		// Take the lease before the connection becomes visible in the map.
		// Publishing first would let a concurrent acquire grab it, leaving
		// this caller to wait on the connection it just dialed.
		leased := conn.tryAcquire()
		ap.conns[conn.id] = conn
		ap.prewarmFails = 0
		ap.prewarmFailAt = time.Time{}
		ap.mu.Unlock()
		p.metrics.acquireCreateTotal.Add(1)
		if !leased {
			if err := conn.acquire(ctx); err != nil {
				conn.close()
				p.evictConn(accountID, conn.id)
				return nil, err
			}
		}
		lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick}
		p.ensureTargetIdleAsync(accountID)
		return lease, nil
	}

	if req.ForceNewConn {
		return fail(errOpenAIWSConnQueueFull)
	}

	target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID)
	if target == nil {
		return fail(errOpenAIWSConnClosed)
	}
	return p.waitForConnLeaseLocked(ctx, req, retry, ap, target, evicted, pickStartedAt)
}

// leaseReusedConnLocked builds the lease for conn, whose token the caller has
// already taken with tryAcquire. ap.mu must be held on entry and is released
// here. A failed health check restarts acquire once.
func (p *openAIWSConnPool) leaseReusedConnLocked(ctx context.Context, req openAIWSAcquireRequest, retry int, ap *openAIWSAccountPool, conn *openAIWSConn, evicted []*openAIWSConn, pickStartedAt time.Time) (*openAIWSConnLease, error) {
	accountID := req.Account.ID
	connPick := time.Since(pickStartedAt)
	p.recordConnPickDuration(connPick)
	ap.mu.Unlock()
	closeOpenAIWSConns(evicted)

	if err := p.evictIfUnhealthy(accountID, conn); err != nil {
		if retry < 1 {
			return p.acquire(ctx, req, retry+1)
		}
		return nil, err
	}
	lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
	p.metrics.acquireReuseTotal.Add(1)
	p.ensureTargetIdleAsync(accountID)
	return lease, nil
}

// waitForConnLeaseLocked queues on target until its lease becomes available.
// ap.mu must be held on entry and is released here. It fails with
// errOpenAIWSConnQueueFull when target's queue is at its limit. A connection
// that closes while queued, or fails its health check, restarts acquire once.
func (p *openAIWSConnPool) waitForConnLeaseLocked(ctx context.Context, req openAIWSAcquireRequest, retry int, ap *openAIWSAccountPool, target *openAIWSConn, evicted []*openAIWSConn, pickStartedAt time.Time) (*openAIWSConnLease, error) {
	accountID := req.Account.ID
	connPick := time.Since(pickStartedAt)
	p.recordConnPickDuration(connPick)
	if int(target.waiters.Load()) >= p.queueLimitPerConn() {
		ap.mu.Unlock()
		closeOpenAIWSConns(evicted)
		return nil, errOpenAIWSConnQueueFull
	}
	// Register as a waiter while still holding the lock so cleanup does not
	// evict the connection from under the queue.
	target.waiters.Add(1)
	ap.mu.Unlock()
	closeOpenAIWSConns(evicted)
	defer target.waiters.Add(-1)

	waitStart := time.Now()
	p.metrics.acquireQueueWaitTotal.Add(1)
	if err := target.acquire(ctx); err != nil {
		if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
			return p.acquire(ctx, req, retry+1)
		}
		return nil, err
	}
	// On failure the connection is closed directly. close() marks it closed
	// before waking other waiters, so they fail fast. Returning the token
	// first could hand a lease to a waiter on a dying connection.
	if err := p.evictIfUnhealthy(accountID, target); err != nil {
		if retry < 1 {
			return p.acquire(ctx, req, retry+1)
		}
		return nil, err
	}
	queueWait := time.Since(waitStart)
	p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
	lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true}
	p.metrics.acquireReuseTotal.Add(1)
	p.ensureTargetIdleAsync(accountID)
	return lease, nil
}

// evictIfUnhealthy pings conn when it has been idle long enough to warrant a
// health check. On ping failure it closes and evicts the connection and
// returns the ping error; otherwise it returns nil.
func (p *openAIWSConnPool) evictIfUnhealthy(accountID int64, conn *openAIWSConn) error {
	if !p.shouldHealthCheckConn(conn) {
		return nil
	}
	err := conn.pingWithTimeout(openAIWSConnHealthCheckTO)
	if err != nil {
		conn.close()
		p.evictConn(accountID, conn.id)
	}
	return err
}
// recordConnPickDuration counts one connection pick and adds its duration in
// whole milliseconds; negative durations count as zero.
func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) {
	if p == nil {
		return
	}
	var ms int64
	if duration > 0 {
		ms = duration.Milliseconds()
	}
	p.metrics.connPickTotal.Add(1)
	p.metrics.connPickMs.Add(ms)
}
// pickOldestIdleConnLocked returns the least recently used connection that is
// not leased, not waited on and not pinned, or nil when there is none. The
// caller must hold ap.mu.
func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn {
	if ap == nil || len(ap.conns) == 0 {
		return nil
	}
	var oldest *openAIWSConn
	for _, c := range ap.conns {
		busy := c == nil || c.isLeased() || c.waiters.Load() > 0 || p.isConnPinnedLocked(ap, c.id)
		if busy {
			continue
		}
		if oldest != nil && !c.lastUsedAt().Before(oldest.lastUsedAt()) {
			continue
		}
		oldest = c
	}
	return oldest
}
// getOrCreateAccountPool returns the account's pool, creating and storing an
// empty one when none exists yet. It returns nil for a nil pool or a
// non-positive account ID.
func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool {
	if p == nil || accountID <= 0 {
		return nil
	}
	if found, ok := p.accounts.Load(accountID); ok {
		if ap, _ := found.(*openAIWSAccountPool); ap != nil {
			return ap
		}
	}
	fresh := &openAIWSAccountPool{
		conns:       make(map[string]*openAIWSConn),
		pinnedConns: make(map[string]int),
	}
	// LoadOrStore keeps whichever pool won a concurrent creation race.
	stored, _ := p.accounts.LoadOrStore(accountID, fresh)
	if ap, _ := stored.(*openAIWSAccountPool); ap != nil {
		return ap
	}
	return fresh
}
// ensureAccountPoolLocked is kept for compatibility with older call sites; it
// simply delegates to getOrCreateAccountPool.
func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool {
	ap := p.getOrCreateAccountPool(accountID)
	return ap
}
// getAccountPool looks up the account's pool without creating it. The boolean
// is true only when a non-nil pool was found.
func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) {
	if p == nil || accountID <= 0 {
		return nil, false
	}
	stored, found := p.accounts.Load(accountID)
	if !found || stored == nil {
		return nil, false
	}
	if ap, ok := stored.(*openAIWSAccountPool); ok && ap != nil {
		return ap, true
	}
	return nil, false
}
// isConnPinnedLocked reports whether connID has a positive pin count. The
// caller must hold ap.mu.
func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool {
	if ap == nil || connID == "" {
		return false
	}
	// Reading a missing key (or a nil map) yields 0, i.e. not pinned.
	return ap.pinnedConns[connID] > 0
}
// cleanupAccountLocked removes stale connections from ap and returns them so
// the caller can close them after releasing ap.mu (which must be held here).
//
// Pass 1 drops nil entries, connections that are already closed, and
// unpinned connections older than maxConnAge() that are neither leased nor
// waited on. Pass 2 trims the pool down towards maxIdlePerAccount() (capped
// at maxConns, which itself falls back to the hard cap when not positive) by
// evicting the least recently used idle connections; those evictions are
// counted in scaleDownTotal.
func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn {
	if ap == nil {
		return nil
	}
	maxAge := p.maxConnAge()
	evicted := make([]*openAIWSConn, 0)

	// remove forgets id in both maps; delete is a no-op for absent keys and
	// nil maps, so no length guard is needed.
	remove := func(id string) {
		delete(ap.conns, id)
		delete(ap.pinnedConns, id)
	}

	for id, conn := range ap.conns {
		if conn == nil {
			remove(id)
			continue
		}
		select {
		case <-conn.closedCh:
			remove(id)
			evicted = append(evicted, conn)
			continue
		default:
		}
		if p.isConnPinnedLocked(ap, id) {
			continue
		}
		// A connection with queued waiters must not be evicted, or the
		// waiting acquire would receive a closed error. This is the same rule
		// the idle trim below applies.
		if maxAge > 0 && !conn.isLeased() && conn.waiters.Load() == 0 && conn.age(now) > maxAge {
			remove(id)
			evicted = append(evicted, conn)
		}
	}

	if maxConns <= 0 {
		maxConns = p.maxConnsHardCap()
	}
	maxIdle := p.maxIdlePerAccount()
	if maxIdle < 0 || maxIdle > maxConns {
		maxIdle = maxConns
	}
	if len(ap.conns) <= maxIdle {
		return evicted
	}

	idleConns := make([]*openAIWSConn, 0, len(ap.conns))
	for id, conn := range ap.conns {
		if conn == nil {
			remove(id)
			continue
		}
		// Connections with waiters cannot be evicted during cleanup,
		// otherwise the waiting acquire would receive a closed error.
		if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) {
			continue
		}
		idleConns = append(idleConns, conn)
	}
	// Least recently used first.
	sort.SliceStable(idleConns, func(i, j int) bool {
		return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt())
	})
	redundant := len(ap.conns) - maxIdle
	if redundant > len(idleConns) {
		redundant = len(idleConns)
	}
	for i := 0; i < redundant; i++ {
		conn := idleConns[i]
		remove(conn.id)
		evicted = append(evicted, conn)
	}
	if redundant > 0 {
		p.metrics.scaleDownTotal.Add(int64(redundant))
	}
	return evicted
}
// pickLeastBusyConnLocked returns the map entry for preferredConnID when the
// pool has one; otherwise the connection with the fewest waiters, ties going
// to the least recently used. It returns nil for an empty pool. The caller
// must hold ap.mu.
func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn {
	if ap == nil || len(ap.conns) == 0 {
		return nil
	}
	if id := stringsTrim(preferredConnID); id != "" {
		if conn, ok := ap.conns[id]; ok {
			return conn
		}
	}
	var (
		best    *openAIWSConn
		fewest  int32
		stalest time.Time
	)
	for _, candidate := range ap.conns {
		if candidate == nil {
			continue
		}
		queued := candidate.waiters.Load()
		used := candidate.lastUsedAt()
		better := best == nil || queued < fewest || (queued == fewest && used.Before(stalest))
		if !better {
			continue
		}
		best, fewest, stalest = candidate, queued, used
	}
	return best
}
// accountPoolLoadLocked returns the number of leased connections and the
// total number of queued waiters in ap. The caller must hold ap.mu.
func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) {
	if ap == nil {
		return 0, 0
	}
	for _, c := range ap.conns {
		if c != nil {
			if c.isLeased() {
				inflight++
			}
			waiters += int(c.waiters.Load())
		}
	}
	return inflight, waiters
}
// AccountPoolLoad returns a snapshot of the given account's pool: leased
// connections, queued waiters, and total connections. All three are zero when
// the account has no pool.
func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) {
	if p == nil || accountID <= 0 {
		return 0, 0, 0
	}
	ap, ok := p.getAccountPool(accountID)
	if !ok || ap == nil {
		return 0, 0, 0
	}
	ap.mu.Lock()
	inflight, waiters = accountPoolLoadLocked(ap)
	conns = len(ap.conns)
	ap.mu.Unlock()
	return inflight, waiters, conns
}
// ensureTargetIdleAsync starts a background prewarm when the account has
// fewer connections (existing plus in-flight dials) than its target count.
// It does nothing while a prewarm is already running, during the prewarm
// cooldown, when prewarming is suppressed, or before any acquire has been
// recorded for the account.
func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) {
	if p == nil || accountID <= 0 {
		return
	}
	ap, ok := p.getAccountPool(accountID)
	if !ok || ap == nil {
		return
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()

	last := ap.lastAcquire
	if last == nil || ap.prewarmActive {
		return
	}
	now := time.Now()
	if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) {
		return
	}
	if p.shouldSuppressPrewarmLocked(ap, now) {
		return
	}

	limit := p.maxConnsHardCap()
	if last.Account != nil {
		limit = p.effectiveMaxConnsByAccount(last.Account)
	}
	target := p.targetConnCountLocked(ap, limit)
	need := target - (len(ap.conns) + ap.creating)
	if need <= 0 {
		return
	}

	req := cloneOpenAIWSAcquireRequest(*last)
	ap.prewarmActive = true
	if cooldown := p.prewarmCooldown(); cooldown > 0 {
		ap.prewarmUntil = now.Add(cooldown)
	}
	// Reserve the slots up front; prewarmConns gives one back per dial.
	ap.creating += need
	p.metrics.scaleUpTotal.Add(int64(need))
	go p.prewarmConns(accountID, req, need)
}
// targetConnCountLocked computes how many connections the account should
// have. With no demand (leased + waiting) it is minIdlePerAccount() clamped
// to [0, maxConns]. Otherwise it is demand divided by targetUtilization(),
// rounded up (1 for a demand of one), raised to len(conns)+1 when anyone is
// waiting, and finally clamped to [minIdle, maxConns]. The caller must hold
// ap.mu.
func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int {
	if ap == nil || maxConns <= 0 {
		return 0
	}
	floor := p.minIdlePerAccount()
	switch {
	case floor < 0:
		floor = 0
	case floor > maxConns:
		floor = maxConns
	}

	inflight, waiters := accountPoolLoadLocked(ap)
	utilization := p.targetUtilization()
	demand := inflight + waiters
	if demand <= 0 {
		return floor
	}

	target := 1
	if demand > 1 {
		target = int(math.Ceil(float64(demand) / utilization))
	}
	// Someone is queued: always aim for at least one more connection.
	if grown := len(ap.conns) + 1; waiters > 0 && target < grown {
		target = grown
	}
	if target < floor {
		target = floor
	}
	if target > maxConns {
		target = maxConns
	}
	return target
}
// prewarmConns dials up to total connections for the account, one after the
// other, and adds each to the pool. Every attempt gives back one reserved
// "creating" slot. A failed dial is recorded in the prewarm failure counters;
// a connection that no longer fits under the account limit is closed. The
// prewarmActive flag is cleared when the function returns.
func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) {
	defer func() {
		if ap, ok := p.getAccountPool(accountID); ok && ap != nil {
			ap.mu.Lock()
			ap.prewarmActive = false
			ap.mu.Unlock()
		}
	}()
	for remaining := total; remaining > 0; remaining-- {
		dialCtx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay)
		conn, dialErr := p.dialConn(dialCtx, req)
		cancel()

		ap, ok := p.getAccountPool(accountID)
		if !ok || ap == nil {
			if conn != nil {
				conn.close()
			}
			return
		}

		ap.mu.Lock()
		if ap.creating > 0 {
			ap.creating--
		}
		switch {
		case dialErr != nil:
			ap.prewarmFails++
			ap.prewarmFailAt = time.Now()
			ap.mu.Unlock()
		case len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account):
			ap.mu.Unlock()
			conn.close()
		default:
			ap.conns[conn.id] = conn
			ap.prewarmFails = 0
			ap.prewarmFailAt = time.Time{}
			ap.mu.Unlock()
		}
	}
}
// evictConn removes connID (and its pin count) from the account's pool and
// closes the connection after releasing the account lock. Unknown accounts or
// connection IDs are ignored.
func (p *openAIWSConnPool) evictConn(accountID int64, connID string) {
	if p == nil || accountID <= 0 || stringsTrim(connID) == "" {
		return
	}
	ap, ok := p.getAccountPool(accountID)
	if !ok || ap == nil {
		return
	}
	ap.mu.Lock()
	victim, found := ap.conns[connID]
	if found {
		delete(ap.conns, connID)
		// delete is a no-op when the pin map is empty or nil.
		delete(ap.pinnedConns, connID)
	}
	ap.mu.Unlock()
	if victim != nil {
		victim.close()
	}
}
// PinConn increments the pin count of connID so cleanup will not evict it.
// It returns false when the account or the connection is not in the pool.
// Each successful PinConn should be balanced by an UnpinConn.
func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool {
	if p == nil || accountID <= 0 {
		return false
	}
	id := stringsTrim(connID)
	if id == "" {
		return false
	}
	ap, ok := p.getAccountPool(accountID)
	if !ok || ap == nil {
		return false
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()
	if _, known := ap.conns[id]; !known {
		return false
	}
	if ap.pinnedConns == nil {
		ap.pinnedConns = make(map[string]int)
	}
	ap.pinnedConns[id]++
	return true
}
// UnpinConn decrements the pin count of connID and removes the entry once the
// count would drop to zero. Unknown accounts or connection IDs are ignored.
func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) {
	if p == nil || accountID <= 0 {
		return
	}
	id := stringsTrim(connID)
	if id == "" {
		return
	}
	ap, ok := p.getAccountPool(accountID)
	if !ok || ap == nil {
		return
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()
	// A missing key reads as 0, so remaining is negative and the delete
	// below is a harmless no-op (also for an empty or nil map).
	if remaining := ap.pinnedConns[id] - 1; remaining > 0 {
		ap.pinnedConns[id] = remaining
	} else {
		delete(ap.pinnedConns, id)
	}
}
// dialConn opens a new upstream WebSocket connection for req through
// p.clientDialer and wraps it in an openAIWSConn with a freshly generated
// connection ID.
//
// req.Account is validated before dialing: its ID is needed to build the
// connection, and checking it up front avoids a nil dereference after a
// successful dial, which would also leave the dialed connection unclosed.
//
// A dial error, or a nil connection returned without an error, is reported as
// an *openAIWSDialError carrying the status code and a copy of the handshake
// response headers. A plain error is returned when p, its dialer, or
// req.Account is nil.
func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) {
	if p == nil || p.clientDialer == nil {
		return nil, errors.New("openai ws client dialer is nil")
	}
	if req.Account == nil {
		return nil, errors.New("openai ws acquire request account is nil")
	}
	conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL)
	if err != nil {
		return nil, &openAIWSDialError{
			StatusCode:      status,
			ResponseHeaders: cloneHeader(handshakeHeaders),
			Err:             err,
		}
	}
	if conn == nil {
		// Treat "no error but no connection" as a dial failure so callers
		// never receive a wrapper around a nil transport.
		return nil, &openAIWSDialError{
			StatusCode:      status,
			ResponseHeaders: cloneHeader(handshakeHeaders),
			Err:             errors.New("openai ws dialer returned nil connection"),
		}
	}
	accountID := req.Account.ID
	return newOpenAIWSConn(p.nextConnID(accountID), accountID, conn, handshakeHeaders), nil
}
// nextConnID returns a new connection ID of the form
// "oa_ws_<accountID>_<seq>", where seq is the value obtained by incrementing
// the pool's sequence counter by one.
func (p *openAIWSConnPool) nextConnID(accountID int64) string {
	// 32 bytes is the initial capacity; append grows the buffer when an ID
	// turns out longer.
	var scratch [32]byte
	id := append(scratch[:0], "oa_ws_"...)
	id = strconv.AppendInt(id, accountID, 10)
	id = append(id, '_')
	id = strconv.AppendUint(id, p.seq.Add(1), 10)
	return string(id)
}
// shouldHealthCheckConn reports whether conn's idle duration, measured at the
// current time, has reached openAIWSConnHealthCheckIdle. A nil conn never
// needs a health check.
func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool {
	return conn != nil && conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle
}
// maxConnsHardCap returns the configured MaxConnsPerAccount when it is
// positive, and 8 when the pool, its config, or a positive setting is absent.
func (p *openAIWSConnPool) maxConnsHardCap() int {
	const fallback = 8
	if p == nil || p.cfg == nil {
		return fallback
	}
	if configured := p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount; configured > 0 {
		return configured
	}
	return fallback
}
// dynamicMaxConnsEnabled reports the DynamicMaxConnsByAccountConcurrencyEnabled
// config flag, and false when the pool or its config is nil.
func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool {
	return p != nil && p.cfg != nil &&
		p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled
}
// modeRouterV2Enabled reports the ModeRouterV2Enabled config flag, and false
// when the pool or its config is nil.
func (p *openAIWSConnPool) modeRouterV2Enabled() bool {
	return p != nil && p.cfg != nil &&
		p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
}
// maxConnsFactorByAccount returns the connection-count multiplier configured
// for the account's type: OAuthMaxConnsFactor for AccountTypeOAuth and
// APIKeyMaxConnsFactor for AccountTypeAPIKey. It returns 1.0 when the pool,
// config, or account is nil, when the account type is neither of those, or
// when the configured factor is not positive.
func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 {
	const neutral = 1.0
	if p == nil || p.cfg == nil || account == nil {
		return neutral
	}
	// Stays zero for unrecognized account types, which falls through to neutral.
	var configured float64
	switch account.Type {
	case AccountTypeOAuth:
		configured = p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor
	case AccountTypeAPIKey:
		configured = p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor
	}
	if configured > 0 {
		return configured
	}
	return neutral
}
// effectiveMaxConnsByAccount returns the maximum number of connections the
// pool allows for account. The rules are applied in this order:
//
//   - A non-positive hard cap yields 0.
//   - With router v2 enabled: a nil account gets the hard cap, a non-positive
//     Concurrency yields 0, otherwise Concurrency is returned as-is.
//   - With dynamic sizing disabled, a nil account, or a non-positive
//     Concurrency: the hard cap.
//   - Otherwise ceil(Concurrency * factor), clamped to [1, hard cap].
func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int {
	limit := p.maxConnsHardCap()
	if limit <= 0 {
		return 0
	}

	if p.modeRouterV2Enabled() {
		switch {
		case account == nil:
			return limit
		case account.Concurrency <= 0:
			return 0
		default:
			return account.Concurrency
		}
	}

	// For "unlimited" concurrency values such as 0 or -1, the global hard cap
	// still acts as the backstop.
	if account == nil || !p.dynamicMaxConnsEnabled() || account.Concurrency <= 0 {
		return limit
	}

	scale := p.maxConnsFactorByAccount(account)
	if scale <= 0 {
		scale = 1.0
	}
	scaled := int(math.Ceil(float64(account.Concurrency) * scale))
	switch {
	case scaled < 1:
		return 1
	case scaled > limit:
		return limit
	}
	return scaled
}
// minIdlePerAccount returns the configured MinIdlePerAccount when it is
// non-negative, and 0 when the pool or config is nil or the setting is
// negative.
func (p *openAIWSConnPool) minIdlePerAccount() int {
	if p == nil || p.cfg == nil {
		return 0
	}
	if configured := p.cfg.Gateway.OpenAIWS.MinIdlePerAccount; configured >= 0 {
		return configured
	}
	return 0
}
// maxIdlePerAccount returns the configured MaxIdlePerAccount when it is
// non-negative (zero included), and 4 when the pool or config is nil or the
// setting is negative.
func (p *openAIWSConnPool) maxIdlePerAccount() int {
	const fallback = 4
	if p == nil || p.cfg == nil {
		return fallback
	}
	if configured := p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount; configured >= 0 {
		return configured
	}
	return fallback
}
// maxConnAge returns the fixed openAIWSConnMaxAge constant; it does not
// consult the pool's config.
func (p *openAIWSConnPool) maxConnAge() time.Duration {
	const maxAge time.Duration = openAIWSConnMaxAge
	return maxAge
}
// queueLimitPerConn returns the configured QueueLimitPerConn when it is
// positive, and 256 when the pool or config is nil or the setting is not
// positive.
func (p *openAIWSConnPool) queueLimitPerConn() int {
	const fallback = 256
	if p == nil || p.cfg == nil {
		return fallback
	}
	if configured := p.cfg.Gateway.OpenAIWS.QueueLimitPerConn; configured > 0 {
		return configured
	}
	return fallback
}
// targetUtilization returns the configured PoolTargetUtilization when it lies
// in the range (0, 1], and 0.7 when the pool or config is nil or the setting
// is outside that range.
func (p *openAIWSConnPool) targetUtilization() float64 {
	const fallback = 0.7
	if p == nil || p.cfg == nil {
		return fallback
	}
	// Keep the positive-range test in this form so a NaN setting also falls
	// back to the default.
	if configured := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization; configured > 0 && configured <= 1 {
		return configured
	}
	return fallback
}
// prewarmCooldown returns the configured PrewarmCooldownMS converted from
// milliseconds to a time.Duration when it is positive, and 0 when the pool or
// config is nil or the setting is not positive.
func (p *openAIWSConnPool) prewarmCooldown() time.Duration {
	if p == nil || p.cfg == nil {
		return 0
	}
	if ms := p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS; ms > 0 {
		return time.Duration(ms) * time.Millisecond
	}
	return 0
}
// shouldSuppressPrewarmLocked reports whether prewarming should be skipped for
// ap because of recent prewarm failures.
//
// A nil ap is always suppressed. Otherwise suppression applies only while at
// least openAIWSPrewarmFailureSuppress failures are recorded and the last one
// happened within openAIWSPrewarmFailureWindow of now. Stale or inconsistent
// failure state is reset as a side effect: the counter is zeroed when no
// failure timestamp exists, and both fields are cleared once the window has
// elapsed.
//
// NOTE(review): the "Locked" suffix suggests the caller must hold ap.mu;
// confirm against the call sites.
func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool {
	if ap == nil {
		return true
	}
	switch {
	case ap.prewarmFails <= 0:
		return false
	case ap.prewarmFailAt.IsZero():
		// A failure count without a timestamp cannot be aged out; discard it.
		ap.prewarmFails = 0
		return false
	case now.Sub(ap.prewarmFailAt) > openAIWSPrewarmFailureWindow:
		// The failure window has passed; start counting from scratch.
		ap.prewarmFails = 0
		ap.prewarmFailAt = time.Time{}
		return false
	default:
		return ap.prewarmFails >= openAIWSPrewarmFailureSuppress
	}
}
// dialTimeout returns the configured DialTimeoutSeconds converted from seconds
// to a time.Duration when it is positive, and 10 seconds when the pool or
// config is nil or the setting is not positive.
func (p *openAIWSConnPool) dialTimeout() time.Duration {
	const fallback = 10 * time.Second
	if p == nil || p.cfg == nil {
		return fallback
	}
	if secs := p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds; secs > 0 {
		return time.Duration(secs) * time.Second
	}
	return fallback
}
// cloneOpenAIWSAcquireRequest returns a copy of req whose Headers map is
// duplicated via cloneHeader and whose WSURL, ProxyURL, and PreferredConnID
// are whitespace-trimmed. All other fields, including the Account pointer,
// are copied as-is.
func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest {
	// req was passed by value, so editing it here leaves the caller's copy untouched.
	req.Headers = cloneHeader(req.Headers)
	req.WSURL = stringsTrim(req.WSURL)
	req.ProxyURL = stringsTrim(req.ProxyURL)
	req.PreferredConnID = stringsTrim(req.PreferredConnID)
	return req
}
// cloneOpenAIWSAcquireRequestPtr returns a pointer to a new request produced
// by cloneOpenAIWSAcquireRequest from *req, or nil when req is nil.
func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest {
	if req == nil {
		return nil
	}
	out := new(openAIWSAcquireRequest)
	*out = cloneOpenAIWSAcquireRequest(*req)
	return out
}
// cloneHeader returns a deep copy of src: every key is preserved and each
// value slice is copied into a fresh slice, so modifying the result never
// affects src. A nil src yields nil, and a key whose value slice is empty
// (nil or zero-length) maps to a nil slice in the copy.
func cloneHeader(src http.Header) http.Header {
	if src == nil {
		return nil
	}
	out := make(http.Header, len(src))
	for name, values := range src {
		// dup stays nil for empty value lists, keeping the key present
		// without allocating.
		var dup []string
		if n := len(values); n > 0 {
			dup = make([]string, n)
			copy(dup, values)
		}
		out[name] = dup
	}
	return out
}
// closeOpenAIWSConns calls close on every non-nil connection in conns, in
// slice order. A nil or empty slice is a no-op.
func closeOpenAIWSConns(conns []*openAIWSConn) {
	for i := range conns {
		if c := conns[i]; c != nil {
			c.close()
		}
	}
}
// stringsTrim returns value with all leading and trailing white space removed,
// as defined by strings.TrimSpace.
func stringsTrim(value string) string {
	trimmed := strings.TrimSpace(value)
	return trimmed
}