package service

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	coderws "github.com/coder/websocket"
	"github.com/coder/websocket/wsjson"
)

// openAIWSMessageReadLimitBytes is the per-message read limit (16 MiB) applied
// to every connection returned by coderOpenAIWSClientDialer.Dial.
const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024

// Tuning knobs for the per-proxy HTTP transports and for the cache that holds
// one HTTP client per distinct proxy URL.
const (
	openAIWSProxyTransportMaxIdleConns        = 128
	openAIWSProxyTransportMaxIdleConnsPerHost = 64
	openAIWSProxyTransportIdleConnTimeout     = 90 * time.Second
	openAIWSProxyClientCacheMaxEntries        = 256
	openAIWSProxyClientCacheIdleTTL           = 15 * time.Minute
)

// OpenAIWSTransportMetricsSnapshot is a point-in-time copy of the proxy client
// cache counters. TransportReuseRatio is hits / (hits + misses), or 0 when no
// lookups have been recorded.
type OpenAIWSTransportMetricsSnapshot struct {
	ProxyClientCacheHits   int64   `json:"proxy_client_cache_hits"`
	ProxyClientCacheMisses int64   `json:"proxy_client_cache_misses"`
	TransportReuseRatio    float64 `json:"transport_reuse_ratio"`
}

// openAIWSClientConn abstracts a WS client connection so the underlying
// implementation can be swapped.
type openAIWSClientConn interface {
	WriteJSON(ctx context.Context, value any) error
	ReadMessage(ctx context.Context) ([]byte, error)
	Ping(ctx context.Context) error
	Close() error
}

// openAIWSClientDialer abstracts the WS connection dialer.
type openAIWSClientDialer interface {
	Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
}

// openAIWSTransportMetricsDialer is implemented by dialers that can report
// transport-level cache metrics.
type openAIWSTransportMetricsDialer interface {
	SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
}

// newDefaultOpenAIWSClientDialer returns the coder/websocket-backed dialer
// with an empty proxy client cache.
func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
	return &coderOpenAIWSClientDialer{
		proxyClients: make(map[string]*openAIWSProxyClientEntry),
	}
}

// coderOpenAIWSClientDialer dials WS connections through coder/websocket and
// caches one HTTP client per proxy URL. proxyMu guards proxyClients; the hit
// and miss counters are atomics and are read without the lock.
type coderOpenAIWSClientDialer struct {
	proxyMu      sync.Mutex
	proxyClients map[string]*openAIWSProxyClientEntry
	proxyHits    atomic.Int64
	proxyMisses  atomic.Int64
}

// openAIWSProxyClientEntry is a cached proxy HTTP client together with the
// time (Unix nanoseconds) it was last handed out.
type openAIWSProxyClientEntry struct {
	client           *http.Client
	lastUsedUnixNano int64
}

// Dial opens a WS connection to wsURL with a copy of headers, optionally
// routed through proxyURL (blank means no proxy client is set).
//
// On failure it returns a nil connection plus the handshake response status
// and a copy of its headers when a response is available (otherwise 0 and
// nil). On success the returned status is always 0; only the connection and
// the copied response headers are meaningful.
func (d *coderOpenAIWSClientDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	targetURL := strings.TrimSpace(wsURL)
	if targetURL == "" {
		return nil, 0, nil, errors.New("ws url is empty")
	}

	opts := &coderws.DialOptions{
		HTTPHeader:      cloneHeader(headers),
		CompressionMode: coderws.CompressionContextTakeover,
	}
	if proxy := strings.TrimSpace(proxyURL); proxy != "" {
		proxyClient, err := d.proxyHTTPClient(proxy)
		if err != nil {
			return nil, 0, nil, err
		}
		opts.HTTPClient = proxyClient
	}

	conn, resp, err := coderws.Dial(ctx, targetURL, opts)
	if err != nil {
		status := 0
		respHeaders := http.Header(nil)
		if resp != nil {
			status = resp.StatusCode
			respHeaders = cloneHeader(resp.Header)
		}
		return nil, status, respHeaders, err
	}

	// coder/websocket defaults to a 32KB per-message read limit. Codex WS
	// events (e.g. rate_limits or large deltas) may exceed that threshold, so
	// the limit is raised explicitly to avoid a local read_fail (message too
	// big).
	conn.SetReadLimit(openAIWSMessageReadLimitBytes)

	respHeaders := http.Header(nil)
	if resp != nil {
		respHeaders = cloneHeader(resp.Header)
	}
	return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil
}

// proxyHTTPClient returns the cached HTTP client for proxy, creating and
// caching a new one on a miss. Idle-expired entries are purged before a new
// entry is inserted, and the cache is trimmed to its maximum size afterwards.
//
// It returns an error when the receiver is nil, when proxy is blank, or when
// proxy does not parse to a URL with a host.
func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
	if d == nil {
		return nil, errors.New("openai ws dialer is nil")
	}
	normalizedProxy := strings.TrimSpace(proxy)
	if normalizedProxy == "" {
		return nil, errors.New("proxy url is empty")
	}
	parsedProxyURL, err := url.Parse(normalizedProxy)
	if err != nil {
		return nil, fmt.Errorf("invalid proxy url: %w", err)
	}
	// Inputs such as "localhost:8080" parse successfully as scheme "localhost"
	// with no host. Reject them here instead of caching a transport that
	// cannot reach the intended proxy. The URL is deliberately left out of the
	// message because it may carry credentials.
	if parsedProxyURL.Host == "" {
		return nil, errors.New("invalid proxy url: missing host")
	}
	now := time.Now().UnixNano()

	d.proxyMu.Lock()
	defer d.proxyMu.Unlock()

	// Tolerate a zero-value dialer: writing to a nil map would panic.
	if d.proxyClients == nil {
		d.proxyClients = make(map[string]*openAIWSProxyClientEntry)
	}

	if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil {
		entry.lastUsedUnixNano = now
		d.proxyHits.Add(1)
		return entry.client, nil
	}

	d.cleanupProxyClientsLocked(now)

	transport := &http.Transport{
		Proxy:               http.ProxyURL(parsedProxyURL),
		MaxIdleConns:        openAIWSProxyTransportMaxIdleConns,
		MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
		IdleConnTimeout:     openAIWSProxyTransportIdleConnTimeout,
		TLSHandshakeTimeout: 10 * time.Second,
		ForceAttemptHTTP2:   true,
	}
	client := &http.Client{Transport: transport}
	d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{
		client:           client,
		lastUsedUnixNano: now,
	}
	d.ensureProxyClientCapacityLocked()
	d.proxyMisses.Add(1)
	return client, nil
}

// cleanupProxyClientsLocked removes nil/invalid entries and entries that have
// been idle longer than openAIWSProxyClientCacheIdleTTL as of nowUnixNano,
// closing the idle connections of each expired client. The caller must hold
// proxyMu.
func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
	if d == nil || len(d.proxyClients) == 0 {
		return
	}
	idleTTL := openAIWSProxyClientCacheIdleTTL
	if idleTTL <= 0 {
		return
	}
	now := time.Unix(0, nowUnixNano)
	for key, entry := range d.proxyClients {
		if entry == nil || entry.client == nil {
			delete(d.proxyClients, key)
			continue
		}
		lastUsed := time.Unix(0, entry.lastUsedUnixNano)
		if now.Sub(lastUsed) > idleTTL {
			closeOpenAIWSProxyClient(entry.client)
			delete(d.proxyClients, key)
		}
	}
}

// ensureProxyClientCapacityLocked evicts least-recently-used entries until the
// cache holds at most openAIWSProxyClientCacheMaxEntries entries. Nil entries
// count as last used at time 0 and are therefore evicted first. The caller
// must hold proxyMu.
func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
	if d == nil {
		return
	}
	maxEntries := openAIWSProxyClientCacheMaxEntries
	if maxEntries <= 0 {
		return
	}
	for len(d.proxyClients) > maxEntries {
		var oldestKey string
		var oldestLastUsed int64
		hasOldest := false
		for key, entry := range d.proxyClients {
			lastUsed := int64(0)
			if entry != nil {
				lastUsed = entry.lastUsedUnixNano
			}
			if !hasOldest || lastUsed < oldestLastUsed {
				hasOldest = true
				oldestKey = key
				oldestLastUsed = lastUsed
			}
		}
		if !hasOldest {
			return
		}
		if entry := d.proxyClients[oldestKey]; entry != nil {
			closeOpenAIWSProxyClient(entry.client)
		}
		delete(d.proxyClients, oldestKey)
	}
}

// closeOpenAIWSProxyClient closes the idle connections of client's transport
// when that transport is an *http.Transport; otherwise it does nothing.
func closeOpenAIWSProxyClient(client *http.Client) {
	if client == nil || client.Transport == nil {
		return
	}
	if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
		transport.CloseIdleConnections()
	}
}

// SnapshotTransportMetrics returns the current proxy client cache hit/miss
// counters and the derived reuse ratio. A nil receiver yields a zero snapshot.
func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
	if d == nil {
		return OpenAIWSTransportMetricsSnapshot{}
	}
	hits := d.proxyHits.Load()
	misses := d.proxyMisses.Load()
	total := hits + misses
	reuseRatio := 0.0
	if total > 0 {
		reuseRatio = float64(hits) / float64(total)
	}
	return OpenAIWSTransportMetricsSnapshot{
		ProxyClientCacheHits:   hits,
		ProxyClientCacheMisses: misses,
		TransportReuseRatio:    reuseRatio,
	}
}

// coderOpenAIWSClientConn adapts a coder/websocket connection to
// openAIWSClientConn. Every method tolerates a nil receiver or nil conn.
type coderOpenAIWSClientConn struct {
	conn *coderws.Conn
}

// WriteJSON sends value as a JSON message via wsjson. A nil ctx is replaced
// with context.Background().
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	return wsjson.Write(ctx, c.conn, value)
}

// ReadMessage reads the next message and returns its payload for text and
// binary messages; any other message type yields errOpenAIWSConnClosed. A nil
// ctx is replaced with context.Background().
func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
	if c == nil || c.conn == nil {
		return nil, errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}

	msgType, payload, err := c.conn.Read(ctx)
	if err != nil {
		return nil, err
	}
	switch msgType {
	case coderws.MessageText, coderws.MessageBinary:
		return payload, nil
	default:
		return nil, errOpenAIWSConnClosed
	}
}

// Ping sends a ping on the connection. A nil ctx is replaced with
// context.Background().
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	return c.conn.Ping(ctx)
}

// Close attempts a normal-closure close and then force-closes the connection.
// It always returns nil.
func (c *coderOpenAIWSClientConn) Close() error {
	if c == nil || c.conn == nil {
		return nil
	}
	// Close is idempotent; errors from repeated closes are ignored.
	_ = c.conn.Close(coderws.StatusNormalClosure, "")
	_ = c.conn.CloseNow()
	return nil
}