package service import ( "context" "errors" "fmt" "math" "net/http" "sort" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" "golang.org/x/sync/errgroup" ) const ( openAIWSConnMaxAge = 60 * time.Minute openAIWSConnHealthCheckIdle = 90 * time.Second openAIWSConnHealthCheckTO = 2 * time.Second openAIWSConnPrewarmExtraDelay = 2 * time.Second openAIWSAcquireCleanupInterval = 3 * time.Second openAIWSBackgroundPingInterval = 30 * time.Second openAIWSBackgroundSweepTicker = 30 * time.Second openAIWSPrewarmFailureWindow = 30 * time.Second openAIWSPrewarmFailureSuppress = 2 ) var ( errOpenAIWSConnClosed = errors.New("openai ws connection closed") errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full") errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable") ) type openAIWSDialError struct { StatusCode int ResponseHeaders http.Header Err error } func (e *openAIWSDialError) Error() string { if e == nil { return "" } if e.StatusCode > 0 { return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err) } return fmt.Sprintf("openai ws dial failed: %v", e.Err) } func (e *openAIWSDialError) Unwrap() error { if e == nil { return nil } return e.Err } type openAIWSAcquireRequest struct { Account *Account WSURL string Headers http.Header ProxyURL string PreferredConnID string // ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。 ForceNewConn bool // ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。 ForcePreferredConn bool } type openAIWSConnLease struct { pool *openAIWSConnPool accountID int64 conn *openAIWSConn queueWait time.Duration connPick time.Duration reused bool released atomic.Bool } func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) { if l == nil || l.conn == nil { return nil, errOpenAIWSConnClosed } if l.released.Load() { return nil, errOpenAIWSConnClosed } return l.conn, nil } func (l *openAIWSConnLease) ConnID() string { if l == nil || l.conn == nil { return "" } return l.conn.id } func (l *openAIWSConnLease) QueueWaitDuration() time.Duration { if l == nil { return 0 } return l.queueWait } func (l *openAIWSConnLease) ConnPickDuration() time.Duration { if l == nil { return 0 } return l.connPick } func (l *openAIWSConnLease) Reused() bool { if l == nil { return false } return l.reused } func (l *openAIWSConnLease) HandshakeHeader(name string) string { if l == nil || l.conn == nil { return "" } return l.conn.handshakeHeader(name) } func (l *openAIWSConnLease) IsPrewarmed() bool { if l == nil || l.conn == nil { return false } return l.conn.isPrewarmed() } func (l *openAIWSConnLease) MarkPrewarmed() { if l == nil || l.conn == nil { return } l.conn.markPrewarmed() } func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error { conn, err := l.activeConn() if err != nil { return err } return conn.writeJSONWithTimeout(context.Background(), value, timeout) } func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error { conn, err := l.activeConn() if err != nil { return err } return conn.writeJSONWithTimeout(ctx, value, timeout) } func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error { conn, err := l.activeConn() if err != nil { return err } return conn.writeJSON(value, ctx) } func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) { conn, err := l.activeConn() if err != nil { return nil, err } return conn.readMessageWithTimeout(timeout) } func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) { conn, err := l.activeConn() if err != nil { return nil, err } return conn.readMessage(ctx) } func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) { conn, err := l.activeConn() if err != nil { return nil, err } return conn.readMessageWithContextTimeout(ctx, timeout) } func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error { conn, err := l.activeConn() if err != nil { return err } return conn.pingWithTimeout(timeout) } func (l *openAIWSConnLease) MarkBroken() { if l == nil || l.pool == nil || l.conn == nil || l.released.Load() { return } l.pool.evictConn(l.accountID, l.conn.id) } func (l *openAIWSConnLease) Release() { if l == nil || l.conn == nil { return } if !l.released.CompareAndSwap(false, true) { return } l.conn.release() } type openAIWSConn struct { id string ws openAIWSClientConn handshakeHeaders http.Header leaseCh chan struct{} closedCh chan struct{} closeOnce sync.Once readMu sync.Mutex writeMu sync.Mutex waiters atomic.Int32 createdAtNano atomic.Int64 lastUsedNano atomic.Int64 prewarmed atomic.Bool } func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn { now := time.Now() conn := &openAIWSConn{ id: id, ws: ws, handshakeHeaders: cloneHeader(handshakeHeaders), leaseCh: make(chan struct{}, 1), closedCh: make(chan struct{}), } conn.leaseCh <- struct{}{} conn.createdAtNano.Store(now.UnixNano()) conn.lastUsedNano.Store(now.UnixNano()) return conn } func (c *openAIWSConn) tryAcquire() bool { if c == nil { return false } select { case <-c.closedCh: return false default: } select { case <-c.leaseCh: select { case <-c.closedCh: c.release() return false default: } return true default: return false } } func (c *openAIWSConn) acquire(ctx context.Context) error { if c == nil { return errOpenAIWSConnClosed } for { select { case <-ctx.Done(): return ctx.Err() case <-c.closedCh: return errOpenAIWSConnClosed case <-c.leaseCh: select { case <-c.closedCh: c.release() return errOpenAIWSConnClosed default: } return nil } } } func (c *openAIWSConn) release() { if c == nil { return } select { case c.leaseCh <- struct{}{}: default: } c.touch() } func (c *openAIWSConn) close() { if c == nil { return } c.closeOnce.Do(func() { close(c.closedCh) if c.ws != nil { _ = c.ws.Close() } select { case c.leaseCh <- struct{}{}: default: } }) } func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error { if c == nil { return errOpenAIWSConnClosed } select { case <-c.closedCh: return errOpenAIWSConnClosed default: } writeCtx := parent if writeCtx == nil { writeCtx = context.Background() } if timeout <= 0 { return c.writeJSON(value, writeCtx) } var cancel context.CancelFunc writeCtx, cancel = context.WithTimeout(writeCtx, timeout) defer cancel() return c.writeJSON(value, writeCtx) } func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error { c.writeMu.Lock() defer c.writeMu.Unlock() if c.ws == nil { return errOpenAIWSConnClosed } if writeCtx == nil { writeCtx = context.Background() } if err := c.ws.WriteJSON(writeCtx, value); err != nil { return err } c.touch() return nil } func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) { return c.readMessageWithContextTimeout(context.Background(), timeout) } func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) { if c == nil { return nil, errOpenAIWSConnClosed } select { case <-c.closedCh: return nil, errOpenAIWSConnClosed default: } if parent == nil { parent = context.Background() } if timeout <= 0 { return c.readMessage(parent) } readCtx, cancel := context.WithTimeout(parent, timeout) defer cancel() return c.readMessage(readCtx) } func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) { c.readMu.Lock() defer c.readMu.Unlock() if c.ws == nil { return nil, errOpenAIWSConnClosed } if readCtx == nil { readCtx = context.Background() } payload, err := c.ws.ReadMessage(readCtx) if err != nil { return nil, err } c.touch() return payload, nil } func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error { if c == nil { return errOpenAIWSConnClosed } select { case <-c.closedCh: return errOpenAIWSConnClosed default: } c.writeMu.Lock() defer c.writeMu.Unlock() if c.ws == nil { return errOpenAIWSConnClosed } if timeout <= 0 { timeout = openAIWSConnHealthCheckTO } pingCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() if err := c.ws.Ping(pingCtx); err != nil { return err } return nil } func (c *openAIWSConn) touch() { if c == nil { return } c.lastUsedNano.Store(time.Now().UnixNano()) } func (c *openAIWSConn) createdAt() time.Time { if c == nil { return time.Time{} } nano := c.createdAtNano.Load() if nano <= 0 { return time.Time{} } return time.Unix(0, nano) } func (c *openAIWSConn) lastUsedAt() time.Time { if c == nil { return time.Time{} } nano := c.lastUsedNano.Load() if nano <= 0 { return time.Time{} } return time.Unix(0, nano) } func (c *openAIWSConn) idleDuration(now time.Time) time.Duration { if c == nil { return 0 } last := c.lastUsedAt() if last.IsZero() { return 0 } return now.Sub(last) } func (c *openAIWSConn) age(now time.Time) time.Duration { if c == nil { return 0 } created := c.createdAt() if created.IsZero() { return 0 } return now.Sub(created) } func (c *openAIWSConn) isLeased() bool { if c == nil { return false } return len(c.leaseCh) == 0 } func (c *openAIWSConn) handshakeHeader(name string) string { if c == nil || c.handshakeHeaders == nil { return "" } return strings.TrimSpace(c.handshakeHeaders.Get(strings.TrimSpace(name))) } func (c *openAIWSConn) isPrewarmed() bool { if c == nil { return false } return c.prewarmed.Load() } func (c *openAIWSConn) markPrewarmed() { if c == nil { return } c.prewarmed.Store(true) } type openAIWSAccountPool struct { mu sync.Mutex conns map[string]*openAIWSConn pinnedConns map[string]int creating int lastCleanupAt time.Time lastAcquire *openAIWSAcquireRequest prewarmActive bool prewarmUntil time.Time prewarmFails int prewarmFailAt time.Time } type OpenAIWSPoolMetricsSnapshot struct { AcquireTotal int64 AcquireReuseTotal int64 AcquireCreateTotal int64 AcquireQueueWaitTotal int64 AcquireQueueWaitMsTotal int64 ConnPickTotal int64 ConnPickMsTotal int64 ScaleUpTotal int64 ScaleDownTotal int64 } type openAIWSPoolMetrics struct { acquireTotal atomic.Int64 acquireReuseTotal atomic.Int64 acquireCreateTotal atomic.Int64 acquireQueueWaitTotal atomic.Int64 acquireQueueWaitMs atomic.Int64 connPickTotal atomic.Int64 connPickMs atomic.Int64 scaleUpTotal atomic.Int64 scaleDownTotal atomic.Int64 } type openAIWSConnPool struct { cfg *config.Config // 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。 clientDialer openAIWSClientDialer accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool seq atomic.Uint64 metrics openAIWSPoolMetrics workerStopCh chan struct{} workerWg sync.WaitGroup closeOnce sync.Once } func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool { pool := &openAIWSConnPool{ cfg: cfg, clientDialer: newDefaultOpenAIWSClientDialer(), workerStopCh: make(chan struct{}), } pool.startBackgroundWorkers() return pool } func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot { if p == nil { return OpenAIWSPoolMetricsSnapshot{} } return OpenAIWSPoolMetricsSnapshot{ AcquireTotal: p.metrics.acquireTotal.Load(), AcquireReuseTotal: p.metrics.acquireReuseTotal.Load(), AcquireCreateTotal: p.metrics.acquireCreateTotal.Load(), AcquireQueueWaitTotal: p.metrics.acquireQueueWaitTotal.Load(), AcquireQueueWaitMsTotal: p.metrics.acquireQueueWaitMs.Load(), ConnPickTotal: p.metrics.connPickTotal.Load(), ConnPickMsTotal: p.metrics.connPickMs.Load(), ScaleUpTotal: p.metrics.scaleUpTotal.Load(), ScaleDownTotal: p.metrics.scaleDownTotal.Load(), } } func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { if p == nil { return OpenAIWSTransportMetricsSnapshot{} } if dialer, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok { return dialer.SnapshotTransportMetrics() } return OpenAIWSTransportMetricsSnapshot{} } func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) { if p == nil || dialer == nil { return } p.clientDialer = dialer } // Close 停止后台 worker 并关闭所有空闲连接,应在优雅关闭时调用。 func (p *openAIWSConnPool) Close() { if p == nil { return } p.closeOnce.Do(func() { if p.workerStopCh != nil { close(p.workerStopCh) } p.workerWg.Wait() // 遍历所有账户池,关闭全部空闲连接。 p.accounts.Range(func(key, value any) bool { ap, ok := value.(*openAIWSAccountPool) if !ok || ap == nil { return true } ap.mu.Lock() for _, conn := range ap.conns { if conn != nil && !conn.isLeased() { conn.close() } } ap.mu.Unlock() return true }) }) } func (p *openAIWSConnPool) startBackgroundWorkers() { if p == nil || p.workerStopCh == nil { return } p.workerWg.Add(2) go func() { defer p.workerWg.Done() p.runBackgroundPingWorker() }() go func() { defer p.workerWg.Done() p.runBackgroundCleanupWorker() }() } type openAIWSIdlePingCandidate struct { accountID int64 conn *openAIWSConn } func (p *openAIWSConnPool) runBackgroundPingWorker() { if p == nil { return } ticker := time.NewTicker(openAIWSBackgroundPingInterval) defer ticker.Stop() for { select { case <-ticker.C: p.runBackgroundPingSweep() case <-p.workerStopCh: return } } } func (p *openAIWSConnPool) runBackgroundPingSweep() { if p == nil { return } candidates := p.snapshotIdleConnsForPing() var g errgroup.Group g.SetLimit(10) for _, item := range candidates { item := item if item.conn == nil || item.conn.isLeased() || item.conn.waiters.Load() > 0 { continue } g.Go(func() error { if err := item.conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { p.evictConn(item.accountID, item.conn.id) } return nil }) } _ = g.Wait() } func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate { if p == nil { return nil } candidates := make([]openAIWSIdlePingCandidate, 0) p.accounts.Range(func(key, value any) bool { accountID, ok := key.(int64) if !ok || accountID <= 0 { return true } ap, ok := value.(*openAIWSAccountPool) if !ok || ap == nil { return true } ap.mu.Lock() for _, conn := range ap.conns { if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 { continue } candidates = append(candidates, openAIWSIdlePingCandidate{ accountID: accountID, conn: conn, }) } ap.mu.Unlock() return true }) return candidates } func (p *openAIWSConnPool) runBackgroundCleanupWorker() { if p == nil { return } ticker := time.NewTicker(openAIWSBackgroundSweepTicker) defer ticker.Stop() for { select { case <-ticker.C: p.runBackgroundCleanupSweep(time.Now()) case <-p.workerStopCh: return } } } func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) { if p == nil { return } type cleanupResult struct { evicted []*openAIWSConn } results := make([]cleanupResult, 0) p.accounts.Range(func(_ any, value any) bool { ap, ok := value.(*openAIWSAccountPool) if !ok || ap == nil { return true } maxConns := p.maxConnsHardCap() ap.mu.Lock() if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { maxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) } evicted := p.cleanupAccountLocked(ap, now, maxConns) ap.lastCleanupAt = now ap.mu.Unlock() if len(evicted) > 0 { results = append(results, cleanupResult{evicted: evicted}) } return true }) for _, result := range results { closeOpenAIWSConns(result.evicted) } } func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) { if p != nil { p.metrics.acquireTotal.Add(1) } return p.acquire(ctx, cloneOpenAIWSAcquireRequest(req), 0) } func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) { if p == nil || req.Account == nil || req.Account.ID <= 0 { return nil, errors.New("invalid ws acquire request") } if stringsTrim(req.WSURL) == "" { return nil, errors.New("ws url is empty") } accountID := req.Account.ID effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account) if effectiveMaxConns <= 0 { return nil, errOpenAIWSConnQueueFull } var evicted []*openAIWSConn ap := p.getOrCreateAccountPool(accountID) ap.mu.Lock() ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req) now := time.Now() if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval { evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns) ap.lastCleanupAt = now } pickStartedAt := time.Now() allowReuse := !req.ForceNewConn preferredConnID := stringsTrim(req.PreferredConnID) forcePreferredConn := allowReuse && req.ForcePreferredConn if allowReuse { if forcePreferredConn { if preferredConnID == "" { p.recordConnPickDuration(time.Since(pickStartedAt)) ap.mu.Unlock() closeOpenAIWSConns(evicted) return nil, errOpenAIWSPreferredConnUnavailable } preferredConn, ok := ap.conns[preferredConnID] if !ok || preferredConn == nil { p.recordConnPickDuration(time.Since(pickStartedAt)) ap.mu.Unlock() closeOpenAIWSConns(evicted) return nil, errOpenAIWSPreferredConnUnavailable } if preferredConn.tryAcquire() { connPick := time.Since(pickStartedAt) p.recordConnPickDuration(connPick) ap.mu.Unlock() closeOpenAIWSConns(evicted) if p.shouldHealthCheckConn(preferredConn) { if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { preferredConn.close() p.evictConn(accountID, preferredConn.id) if retry < 1 { return p.acquire(ctx, req, retry+1) } return nil, err } } lease := &openAIWSConnLease{ pool: p, accountID: accountID, conn: preferredConn, connPick: connPick, reused: true, } p.metrics.acquireReuseTotal.Add(1) p.ensureTargetIdleAsync(accountID) return lease, nil } connPick := time.Since(pickStartedAt) p.recordConnPickDuration(connPick) if int(preferredConn.waiters.Load()) >= p.queueLimitPerConn() { ap.mu.Unlock() closeOpenAIWSConns(evicted) return nil, errOpenAIWSConnQueueFull } preferredConn.waiters.Add(1) ap.mu.Unlock() closeOpenAIWSConns(evicted) defer preferredConn.waiters.Add(-1) waitStart := time.Now() p.metrics.acquireQueueWaitTotal.Add(1) if err := preferredConn.acquire(ctx); err != nil { if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { return p.acquire(ctx, req, retry+1) } return nil, err } if p.shouldHealthCheckConn(preferredConn) { if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { preferredConn.release() preferredConn.close() p.evictConn(accountID, preferredConn.id) if retry < 1 { return p.acquire(ctx, req, retry+1) } return nil, err } } queueWait := time.Since(waitStart) p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) lease := &openAIWSConnLease{ pool: p, accountID: accountID, conn: preferredConn, queueWait: queueWait, connPick: connPick, reused: true, } p.metrics.acquireReuseTotal.Add(1) p.ensureTargetIdleAsync(accountID) return lease, nil } if preferredConnID != "" { if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() { connPick := time.Since(pickStartedAt) p.recordConnPickDuration(connPick) ap.mu.Unlock() closeOpenAIWSConns(evicted) if p.shouldHealthCheckConn(conn) { if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { conn.close() p.evictConn(accountID, conn.id) if retry < 1 { return p.acquire(ctx, req, retry+1) } return nil, err } } lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} p.metrics.acquireReuseTotal.Add(1) p.ensureTargetIdleAsync(accountID) return lease, nil } } best := p.pickLeastBusyConnLocked(ap, "") if best != nil && best.tryAcquire() { connPick := time.Since(pickStartedAt) p.recordConnPickDuration(connPick) ap.mu.Unlock() closeOpenAIWSConns(evicted) if p.shouldHealthCheckConn(best) { if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { best.close() p.evictConn(accountID, best.id) if retry < 1 { return p.acquire(ctx, req, retry+1) } return nil, err } } lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true} p.metrics.acquireReuseTotal.Add(1) p.ensureTargetIdleAsync(accountID) return lease, nil } for _, conn := range ap.conns { if conn == nil || conn == best { continue } if conn.tryAcquire() { connPick := time.Since(pickStartedAt) p.recordConnPickDuration(connPick) ap.mu.Unlock() closeOpenAIWSConns(evicted) if p.shouldHealthCheckConn(conn) { if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { conn.close() p.evictConn(accountID, conn.id) if retry < 1 { return p.acquire(ctx, req, retry+1) } return nil, err } } lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} p.metrics.acquireReuseTotal.Add(1) p.ensureTargetIdleAsync(accountID) return lease, nil } } } if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns { if idle := p.pickOldestIdleConnLocked(ap); idle != nil { delete(ap.conns, idle.id) evicted = append(evicted, idle) p.metrics.scaleDownTotal.Add(1) } } if len(ap.conns)+ap.creating < effectiveMaxConns { connPick := time.Since(pickStartedAt) p.recordConnPickDuration(connPick) ap.creating++ ap.mu.Unlock() closeOpenAIWSConns(evicted) conn, dialErr := p.dialConn(ctx, req) ap = p.getOrCreateAccountPool(accountID) ap.mu.Lock() ap.creating-- if dialErr != nil { ap.prewarmFails++ ap.prewarmFailAt = time.Now() ap.mu.Unlock() return nil, dialErr } ap.conns[conn.id] = conn ap.prewarmFails = 0 ap.prewarmFailAt = time.Time{} ap.mu.Unlock() p.metrics.acquireCreateTotal.Add(1) if !conn.tryAcquire() { if err := conn.acquire(ctx); err != nil { conn.close() p.evictConn(accountID, conn.id) return nil, err } } lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick} p.ensureTargetIdleAsync(accountID) return lease, nil } if req.ForceNewConn { p.recordConnPickDuration(time.Since(pickStartedAt)) ap.mu.Unlock() closeOpenAIWSConns(evicted) return nil, errOpenAIWSConnQueueFull } target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID) connPick := time.Since(pickStartedAt) p.recordConnPickDuration(connPick) if target == nil { ap.mu.Unlock() closeOpenAIWSConns(evicted) return nil, errOpenAIWSConnClosed } if int(target.waiters.Load()) >= p.queueLimitPerConn() { ap.mu.Unlock() closeOpenAIWSConns(evicted) return nil, errOpenAIWSConnQueueFull } target.waiters.Add(1) ap.mu.Unlock() closeOpenAIWSConns(evicted) defer target.waiters.Add(-1) waitStart := time.Now() p.metrics.acquireQueueWaitTotal.Add(1) if err := target.acquire(ctx); err != nil { if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { return p.acquire(ctx, req, retry+1) } return nil, err } if p.shouldHealthCheckConn(target) { if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { target.release() target.close() p.evictConn(accountID, target.id) if retry < 1 { return p.acquire(ctx, req, retry+1) } return nil, err } } queueWait := time.Since(waitStart) p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true} p.metrics.acquireReuseTotal.Add(1) p.ensureTargetIdleAsync(accountID) return lease, nil } func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) { if p == nil { return } if duration < 0 { duration = 0 } p.metrics.connPickTotal.Add(1) p.metrics.connPickMs.Add(duration.Milliseconds()) } func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn { if ap == nil || len(ap.conns) == 0 { return nil } var oldest *openAIWSConn for _, conn := range ap.conns { if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { continue } if oldest == nil || conn.lastUsedAt().Before(oldest.lastUsedAt()) { oldest = conn } } return oldest } func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool { if p == nil || accountID <= 0 { return nil } if existing, ok := p.accounts.Load(accountID); ok { if ap, typed := existing.(*openAIWSAccountPool); typed && ap != nil { return ap } } ap := &openAIWSAccountPool{ conns: make(map[string]*openAIWSConn), pinnedConns: make(map[string]int), } actual, _ := p.accounts.LoadOrStore(accountID, ap) if typed, ok := actual.(*openAIWSAccountPool); ok && typed != nil { return typed } return ap } // ensureAccountPoolLocked 兼容旧调用。 func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool { return p.getOrCreateAccountPool(accountID) } func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) { if p == nil || accountID <= 0 { return nil, false } value, ok := p.accounts.Load(accountID) if !ok || value == nil { return nil, false } ap, typed := value.(*openAIWSAccountPool) return ap, typed && ap != nil } func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool { if ap == nil || connID == "" || len(ap.pinnedConns) == 0 { return false } return ap.pinnedConns[connID] > 0 } func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn { if ap == nil { return nil } maxAge := p.maxConnAge() evicted := make([]*openAIWSConn, 0) for id, conn := range ap.conns { if conn == nil { delete(ap.conns, id) if len(ap.pinnedConns) > 0 { delete(ap.pinnedConns, id) } continue } select { case <-conn.closedCh: delete(ap.conns, id) if len(ap.pinnedConns) > 0 { delete(ap.pinnedConns, id) } evicted = append(evicted, conn) continue default: } if p.isConnPinnedLocked(ap, id) { continue } if maxAge > 0 && !conn.isLeased() && conn.age(now) > maxAge { delete(ap.conns, id) if len(ap.pinnedConns) > 0 { delete(ap.pinnedConns, id) } evicted = append(evicted, conn) } } if maxConns <= 0 { maxConns = p.maxConnsHardCap() } maxIdle := p.maxIdlePerAccount() if maxIdle < 0 || maxIdle > maxConns { maxIdle = maxConns } if maxIdle >= 0 && len(ap.conns) > maxIdle { idleConns := make([]*openAIWSConn, 0, len(ap.conns)) for id, conn := range ap.conns { if conn == nil { delete(ap.conns, id) if len(ap.pinnedConns) > 0 { delete(ap.pinnedConns, id) } continue } // 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。 if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { continue } idleConns = append(idleConns, conn) } sort.SliceStable(idleConns, func(i, j int) bool { return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt()) }) redundant := len(ap.conns) - maxIdle if redundant > len(idleConns) { redundant = len(idleConns) } for i := 0; i < redundant; i++ { conn := idleConns[i] delete(ap.conns, conn.id) if len(ap.pinnedConns) > 0 { delete(ap.pinnedConns, conn.id) } evicted = append(evicted, conn) } if redundant > 0 { p.metrics.scaleDownTotal.Add(int64(redundant)) } } return evicted } func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn { if ap == nil || len(ap.conns) == 0 { return nil } preferredConnID = stringsTrim(preferredConnID) if preferredConnID != "" { if conn, ok := ap.conns[preferredConnID]; ok { return conn } } var best *openAIWSConn var bestWaiters int32 var bestLastUsed time.Time for _, conn := range ap.conns { if conn == nil { continue } waiters := conn.waiters.Load() lastUsed := conn.lastUsedAt() if best == nil || waiters < bestWaiters || (waiters == bestWaiters && lastUsed.Before(bestLastUsed)) { best = conn bestWaiters = waiters bestLastUsed = lastUsed } } return best } func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) { if ap == nil { return 0, 0 } for _, conn := range ap.conns { if conn == nil { continue } if conn.isLeased() { inflight++ } waiters += int(conn.waiters.Load()) } return inflight, waiters } // AccountPoolLoad 返回指定账号连接池的并发与排队快照。 func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) { if p == nil || accountID <= 0 { return 0, 0, 0 } ap, ok := p.getAccountPool(accountID) if !ok || ap == nil { return 0, 0, 0 } ap.mu.Lock() defer ap.mu.Unlock() inflight, waiters = accountPoolLoadLocked(ap) return inflight, waiters, len(ap.conns) } func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) { if p == nil || accountID <= 0 { return } var req openAIWSAcquireRequest need := 0 ap, ok := p.getAccountPool(accountID) if !ok || ap == nil { return } ap.mu.Lock() defer ap.mu.Unlock() if ap.lastAcquire == nil { return } if ap.prewarmActive { return } now := time.Now() if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) { return } if p.shouldSuppressPrewarmLocked(ap, now) { return } effectiveMaxConns := p.maxConnsHardCap() if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { effectiveMaxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) } target := p.targetConnCountLocked(ap, effectiveMaxConns) current := len(ap.conns) + ap.creating if current >= target { return } need = target - current if need <= 0 { return } req = cloneOpenAIWSAcquireRequest(*ap.lastAcquire) ap.prewarmActive = true if cooldown := p.prewarmCooldown(); cooldown > 0 { ap.prewarmUntil = now.Add(cooldown) } ap.creating += need p.metrics.scaleUpTotal.Add(int64(need)) go p.prewarmConns(accountID, req, need) } func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int { if ap == nil { return 0 } if maxConns <= 0 { return 0 } minIdle := p.minIdlePerAccount() if minIdle < 0 { minIdle = 0 } if minIdle > maxConns { minIdle = maxConns } inflight, waiters := accountPoolLoadLocked(ap) utilization := p.targetUtilization() demand := inflight + waiters if demand <= 0 { return minIdle } target := 1 if demand > 1 { target = int(math.Ceil(float64(demand) / utilization)) } if waiters > 0 && target < len(ap.conns)+1 { target = len(ap.conns) + 1 } if target < minIdle { target = minIdle } if target > maxConns { target = maxConns } return target } func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) { defer func() { if ap, ok := p.getAccountPool(accountID); ok && ap != nil { ap.mu.Lock() ap.prewarmActive = false ap.mu.Unlock() } }() for i := 0; i < total; i++ { ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay) conn, err := p.dialConn(ctx, req) cancel() ap, ok := p.getAccountPool(accountID) if !ok || ap == nil { if conn != nil { conn.close() } return } ap.mu.Lock() if ap.creating > 0 { ap.creating-- } if err != nil { ap.prewarmFails++ ap.prewarmFailAt = time.Now() ap.mu.Unlock() continue } if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) { ap.mu.Unlock() conn.close() continue } ap.conns[conn.id] = conn ap.prewarmFails = 0 ap.prewarmFailAt = time.Time{} ap.mu.Unlock() } } func (p *openAIWSConnPool) evictConn(accountID int64, connID string) { if p == nil || accountID <= 0 || stringsTrim(connID) == "" { return } var conn *openAIWSConn ap, ok := p.getAccountPool(accountID) if ok && ap != nil { ap.mu.Lock() if c, exists := ap.conns[connID]; exists { conn = c delete(ap.conns, connID) if len(ap.pinnedConns) > 0 { delete(ap.pinnedConns, connID) } } ap.mu.Unlock() } if conn != nil { conn.close() } } func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool { if p == nil || accountID <= 0 { return false } connID = stringsTrim(connID) if connID == "" { return false } ap, ok := p.getAccountPool(accountID) if !ok || ap == nil { return false } ap.mu.Lock() defer ap.mu.Unlock() if _, exists := ap.conns[connID]; !exists { return false } if ap.pinnedConns == nil { ap.pinnedConns = make(map[string]int) } ap.pinnedConns[connID]++ return true } func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) { if p == nil || accountID <= 0 { return } connID = stringsTrim(connID) if connID == "" { return } ap, ok := p.getAccountPool(accountID) if !ok || ap == nil { return } ap.mu.Lock() defer ap.mu.Unlock() if len(ap.pinnedConns) == 0 { return } count := ap.pinnedConns[connID] if count <= 1 { delete(ap.pinnedConns, connID) return } ap.pinnedConns[connID] = count - 1 } func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) { if p == nil || p.clientDialer == nil { return nil, errors.New("openai ws client dialer is nil") } conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL) if err != nil { return nil, &openAIWSDialError{ StatusCode: status, ResponseHeaders: cloneHeader(handshakeHeaders), Err: err, } } if conn == nil { return nil, &openAIWSDialError{ StatusCode: status, ResponseHeaders: cloneHeader(handshakeHeaders), Err: errors.New("openai ws dialer returned nil connection"), } } id := p.nextConnID(req.Account.ID) return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil } func (p *openAIWSConnPool) nextConnID(accountID int64) string { seq := p.seq.Add(1) buf := make([]byte, 0, 32) buf = append(buf, "oa_ws_"...) buf = strconv.AppendInt(buf, accountID, 10) buf = append(buf, '_') buf = strconv.AppendUint(buf, seq, 10) return string(buf) } func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool { if conn == nil { return false } return conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle } func (p *openAIWSConnPool) maxConnsHardCap() int { if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 { return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount } return 8 } func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool { if p != nil && p.cfg != nil { return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled } return false } func (p *openAIWSConnPool) modeRouterV2Enabled() bool { if p != nil && p.cfg != nil { return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled } return false } func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 { if p == nil || p.cfg == nil || account == nil { return 1.0 } switch account.Type { case AccountTypeOAuth: if p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor > 0 { return p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor } case AccountTypeAPIKey: if p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor > 0 { return p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor } } return 1.0 } func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int { hardCap := p.maxConnsHardCap() if hardCap <= 0 { return 0 } if p.modeRouterV2Enabled() { if account == nil { return hardCap } if account.Concurrency <= 0 { return 0 } return account.Concurrency } if account == nil || !p.dynamicMaxConnsEnabled() { return hardCap } if account.Concurrency <= 0 { // 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。 return hardCap } factor := p.maxConnsFactorByAccount(account) if factor <= 0 { factor = 1.0 } effective := int(math.Ceil(float64(account.Concurrency) * factor)) if effective < 1 { effective = 1 } if effective > hardCap { effective = hardCap } return effective } func (p *openAIWSConnPool) minIdlePerAccount() int { if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MinIdlePerAccount >= 0 { return p.cfg.Gateway.OpenAIWS.MinIdlePerAccount } return 0 } func (p *openAIWSConnPool) maxIdlePerAccount() int { if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount >= 0 { return p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount } return 4 } func (p *openAIWSConnPool) maxConnAge() time.Duration { return openAIWSConnMaxAge } func (p *openAIWSConnPool) queueLimitPerConn() int { if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.QueueLimitPerConn > 0 { return p.cfg.Gateway.OpenAIWS.QueueLimitPerConn } return 256 } func (p *openAIWSConnPool) targetUtilization() float64 { if p != nil && p.cfg != nil { ratio := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization if ratio > 0 && ratio <= 1 { return ratio } } return 0.7 } func (p *openAIWSConnPool) prewarmCooldown() time.Duration { if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS > 0 { return time.Duration(p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS) * time.Millisecond } return 0 } func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool { if ap == nil { return true } if ap.prewarmFails <= 0 { return false } if ap.prewarmFailAt.IsZero() { ap.prewarmFails = 0 return false } if now.Sub(ap.prewarmFailAt) > openAIWSPrewarmFailureWindow { ap.prewarmFails = 0 ap.prewarmFailAt = time.Time{} return false } return ap.prewarmFails >= openAIWSPrewarmFailureSuppress } func (p *openAIWSConnPool) dialTimeout() time.Duration { if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { return time.Duration(p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second } return 10 * time.Second } func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest { copied := req copied.Headers = cloneHeader(req.Headers) copied.WSURL = stringsTrim(req.WSURL) copied.ProxyURL = stringsTrim(req.ProxyURL) copied.PreferredConnID = stringsTrim(req.PreferredConnID) return copied } func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest { if req == nil { return nil } copied := cloneOpenAIWSAcquireRequest(*req) return &copied } func cloneHeader(src http.Header) http.Header { if src == nil { return nil } dst := make(http.Header, len(src)) for k, vals := range src { if len(vals) == 0 { dst[k] = nil continue } copied := make([]string, len(vals)) copy(copied, vals) dst[k] = copied } return dst } func closeOpenAIWSConns(conns []*openAIWSConn) { if len(conns) == 0 { return } for _, conn := range conns { if conn == nil { continue } conn.close() } } func stringsTrim(value string) string { return strings.TrimSpace(value) }