Files
sub2api/backend/internal/service/openai_ws_client.go
2026-02-28 15:01:20 +08:00

286 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
coderws "github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// openAIWSMessageReadLimitBytes is the per-message read limit (16 MiB) that
// Dial installs on every new connection through SetReadLimit.
const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024
// Settings for the HTTP clients built for proxied dials and for the cache
// that keeps one client per proxy URL.
const (
// Assigned to http.Transport.MaxIdleConns of each proxy transport.
openAIWSProxyTransportMaxIdleConns = 128
// Assigned to http.Transport.MaxIdleConnsPerHost of each proxy transport.
openAIWSProxyTransportMaxIdleConnsPerHost = 64
// Assigned to http.Transport.IdleConnTimeout of each proxy transport.
openAIWSProxyTransportIdleConnTimeout = 90 * time.Second
// Upper bound on cached proxy clients; the least recently used entries are
// evicted once the cache grows past it.
openAIWSProxyClientCacheMaxEntries = 256
// A cached proxy client unused for longer than this is dropped by the
// cleanup pass.
openAIWSProxyClientCacheIdleTTL = 15 * time.Minute
)
// OpenAIWSTransportMetricsSnapshot is a point-in-time copy of the proxy HTTP
// client cache counters.
type OpenAIWSTransportMetricsSnapshot struct {
// ProxyClientCacheHits counts proxy client lookups answered from the cache.
ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"`
// ProxyClientCacheMisses counts lookups that had to build a new client.
ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"`
// TransportReuseRatio is hits / (hits + misses), or 0 when nothing has been
// looked up yet.
TransportReuseRatio float64 `json:"transport_reuse_ratio"`
}
// openAIWSClientConn abstracts a WS client connection so that the underlying
// implementation can be swapped out.
type openAIWSClientConn interface {
// WriteJSON sends value over the connection as a JSON message.
WriteJSON(ctx context.Context, value any) error
// ReadMessage returns the payload of the next message.
ReadMessage(ctx context.Context) ([]byte, error)
// Ping sends a ping on the connection.
Ping(ctx context.Context) error
// Close shuts the connection down.
Close() error
}
// openAIWSClientDialer abstracts the component that establishes WS connections.
type openAIWSClientDialer interface {
// Dial connects to wsURL with the given request headers, going through
// proxyURL when it is non-empty. Besides the connection it returns an int
// and the handshake response headers; in the coder-based dialer in this
// file the int is the HTTP status of a failed handshake and 0 otherwise.
Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
}
// openAIWSTransportMetricsDialer is implemented by dialers that can report
// transport-level metrics.
type openAIWSTransportMetricsDialer interface {
// SnapshotTransportMetrics returns the current metric values.
SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
}
// newDefaultOpenAIWSClientDialer builds the coder/websocket-backed dialer with
// an empty proxy client cache.
func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
	dialer := new(coderOpenAIWSClientDialer)
	dialer.proxyClients = map[string]*openAIWSProxyClientEntry{}
	return dialer
}
// coderOpenAIWSClientDialer dials WS connections with coder/websocket and
// caches one HTTP client per proxy URL.
type coderOpenAIWSClientDialer struct {
// proxyMu is held while proxyClients is read or modified.
proxyMu sync.Mutex
// proxyClients maps a trimmed proxy URL string to its cached client entry.
proxyClients map[string]*openAIWSProxyClientEntry
// proxyHits counts proxy client lookups served from proxyClients.
proxyHits atomic.Int64
// proxyMisses counts proxy client lookups that created a new client.
proxyMisses atomic.Int64
}
// openAIWSProxyClientEntry is one slot of the proxy client cache.
type openAIWSProxyClientEntry struct {
// client is the HTTP client configured to go through the proxy.
client *http.Client
// lastUsedUnixNano is the time.Now().UnixNano() value recorded when the
// entry was created or last returned from the cache.
lastUsedUnixNano int64
}
// Dial opens a WebSocket connection to wsURL using coder/websocket.
//
// headers are cloned into the handshake request. When proxyURL is non-blank,
// the handshake goes through the cached HTTP client for that proxy.
//
// On failure the returned int is the handshake response status (0 when no
// response was received) and the header is a clone of the response headers
// (nil when no response was received). On success the int is always 0 and the
// header is a clone of the handshake response headers.
func (d *coderOpenAIWSClientDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	endpoint := strings.TrimSpace(wsURL)
	if endpoint == "" {
		return nil, 0, nil, errors.New("ws url is empty")
	}

	dialOpts := &coderws.DialOptions{
		HTTPHeader:      cloneHeader(headers),
		CompressionMode: coderws.CompressionContextTakeover,
	}
	if trimmedProxy := strings.TrimSpace(proxyURL); trimmedProxy != "" {
		httpClient, clientErr := d.proxyHTTPClient(trimmedProxy)
		if clientErr != nil {
			return nil, 0, nil, clientErr
		}
		dialOpts.HTTPClient = httpClient
	}

	wsConn, handshake, dialErr := coderws.Dial(ctx, endpoint, dialOpts)

	// Pull out whatever the handshake response offers; both stay at their zero
	// values when there was no response at all.
	var (
		handshakeStatus  int
		handshakeHeaders http.Header
	)
	if handshake != nil {
		handshakeStatus = handshake.StatusCode
		handshakeHeaders = cloneHeader(handshake.Header)
	}
	if dialErr != nil {
		return nil, handshakeStatus, handshakeHeaders, dialErr
	}

	// coder/websocket's default per-message read limit is 32KB. Codex WS
	// events (e.g. rate_limits or large deltas) may exceed that threshold, so
	// the limit is raised explicitly to avoid a local read_fail (message too
	// big).
	wsConn.SetReadLimit(openAIWSMessageReadLimitBytes)

	// The status is deliberately reported as 0 on success.
	return &coderOpenAIWSClientConn{conn: wsConn}, 0, handshakeHeaders, nil
}
// proxyHTTPClient returns the HTTP client used to dial through proxy, creating
// and caching it on first use.
//
// proxy is trimmed and parsed with url.Parse; the trimmed string is the cache
// key. A cache hit refreshes the entry's last-used timestamp and increments
// the hit counter. A miss first drops entries that have been idle for longer
// than the TTL, then builds a new client with a dedicated http.Transport,
// stores it, evicts least recently used entries if the cache is over
// capacity, and increments the miss counter.
//
// It returns an error when the receiver is nil, when proxy is blank, or when
// proxy cannot be parsed as a URL.
func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
	if d == nil {
		return nil, errors.New("openai ws dialer is nil")
	}
	normalizedProxy := strings.TrimSpace(proxy)
	if normalizedProxy == "" {
		return nil, errors.New("proxy url is empty")
	}
	parsedProxyURL, err := url.Parse(normalizedProxy)
	if err != nil {
		return nil, fmt.Errorf("invalid proxy url: %w", err)
	}

	now := time.Now().UnixNano()

	d.proxyMu.Lock()
	defer d.proxyMu.Unlock()

	// A dialer that was not built by the constructor has a nil map; writing to
	// it below would panic, so make the zero value usable.
	if d.proxyClients == nil {
		d.proxyClients = make(map[string]*openAIWSProxyClientEntry)
	}

	if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil {
		entry.lastUsedUnixNano = now
		d.proxyHits.Add(1)
		return entry.client, nil
	}

	// Expired entries are only swept on a miss, right before the cache grows.
	d.cleanupProxyClientsLocked(now)

	transport := &http.Transport{
		Proxy:               http.ProxyURL(parsedProxyURL),
		MaxIdleConns:        openAIWSProxyTransportMaxIdleConns,
		MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
		IdleConnTimeout:     openAIWSProxyTransportIdleConnTimeout,
		TLSHandshakeTimeout: 10 * time.Second,
		ForceAttemptHTTP2:   true,
	}
	client := &http.Client{Transport: transport}
	d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{
		client:           client,
		lastUsedUnixNano: now,
	}
	d.ensureProxyClientCapacityLocked()
	d.proxyMisses.Add(1)
	return client, nil
}
// cleanupProxyClientsLocked removes cache entries that are unusable (nil
// entry or nil client) or that have been idle, relative to nowUnixNano, for
// longer than openAIWSProxyClientCacheIdleTTL. Idle connections of expired
// clients are closed before the entry is dropped.
//
// As the name indicates, it is meant to be called with d.proxyMu held.
func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
	if d == nil {
		return
	}
	if len(d.proxyClients) == 0 {
		return
	}
	ttl := openAIWSProxyClientCacheIdleTTL
	if ttl <= 0 {
		return
	}

	reference := time.Unix(0, nowUnixNano)
	for proxyKey, cached := range d.proxyClients {
		switch {
		case cached == nil || cached.client == nil:
			// Nothing to close; just drop the broken slot.
			delete(d.proxyClients, proxyKey)
		case reference.Sub(time.Unix(0, cached.lastUsedUnixNano)) > ttl:
			closeOpenAIWSProxyClient(cached.client)
			delete(d.proxyClients, proxyKey)
		}
	}
}
// ensureProxyClientCapacityLocked evicts entries, smallest last-used timestamp
// first, until the cache holds at most openAIWSProxyClientCacheMaxEntries
// items. A nil entry counts as timestamp 0. Idle connections of an evicted
// client are closed.
//
// As the name indicates, it is meant to be called with d.proxyMu held.
func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
	if d == nil {
		return
	}
	limit := openAIWSProxyClientCacheMaxEntries
	if limit <= 0 {
		return
	}

	for len(d.proxyClients) > limit {
		victimKey, victimStamp, found := "", int64(0), false
		for key, entry := range d.proxyClients {
			var stamp int64
			if entry != nil {
				stamp = entry.lastUsedUnixNano
			}
			// Keep the current candidate unless this one is strictly older.
			if found && stamp >= victimStamp {
				continue
			}
			victimKey, victimStamp, found = key, stamp, true
		}
		if !found {
			return
		}
		if victim := d.proxyClients[victimKey]; victim != nil {
			closeOpenAIWSProxyClient(victim.client)
		}
		delete(d.proxyClients, victimKey)
	}
}
// closeOpenAIWSProxyClient closes the idle connections of client's transport.
// It is a no-op unless client is non-nil and its Transport is a non-nil
// *http.Transport; other RoundTripper implementations are left untouched.
func closeOpenAIWSProxyClient(client *http.Client) {
	if client == nil {
		return
	}
	// The assertion also fails for a nil Transport interface value.
	transport, isHTTPTransport := client.Transport.(*http.Transport)
	if !isHTTPTransport || transport == nil {
		return
	}
	transport.CloseIdleConnections()
}
// SnapshotTransportMetrics reports the proxy client cache hit and miss
// counters together with the reuse ratio hits / (hits + misses). The ratio is
// 0 when there have been no lookups. A nil receiver yields the zero snapshot.
func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
	var snapshot OpenAIWSTransportMetricsSnapshot
	if d == nil {
		return snapshot
	}

	snapshot.ProxyClientCacheHits = d.proxyHits.Load()
	snapshot.ProxyClientCacheMisses = d.proxyMisses.Load()

	lookups := snapshot.ProxyClientCacheHits + snapshot.ProxyClientCacheMisses
	if lookups > 0 {
		snapshot.TransportReuseRatio = float64(snapshot.ProxyClientCacheHits) / float64(lookups)
	}
	return snapshot
}
// coderOpenAIWSClientConn implements the connection methods on top of a
// coder/websocket connection.
type coderOpenAIWSClientConn struct {
// conn is the wrapped coder/websocket connection; the methods treat a nil
// conn as an already closed connection.
conn *coderws.Conn
}
// WriteJSON writes value to the connection via wsjson.Write. It returns
// errOpenAIWSConnClosed when the receiver or its connection is nil. A nil ctx
// is replaced with context.Background().
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	if c.conn == nil {
		return errOpenAIWSConnClosed
	}
	writeCtx := ctx
	if writeCtx == nil {
		writeCtx = context.Background()
	}
	return wsjson.Write(writeCtx, c.conn, value)
}
// ReadMessage reads the next message and returns its payload when the message
// is of text or binary type. Read errors are returned unchanged. It returns
// errOpenAIWSConnClosed when the receiver or its connection is nil, or when
// the message type is neither text nor binary. A nil ctx is replaced with
// context.Background().
func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
	if c == nil {
		return nil, errOpenAIWSConnClosed
	}
	if c.conn == nil {
		return nil, errOpenAIWSConnClosed
	}
	readCtx := ctx
	if readCtx == nil {
		readCtx = context.Background()
	}

	kind, data, readErr := c.conn.Read(readCtx)
	if readErr != nil {
		return nil, readErr
	}
	if kind == coderws.MessageText || kind == coderws.MessageBinary {
		return data, nil
	}
	return nil, errOpenAIWSConnClosed
}
// Ping forwards to the underlying connection's Ping. It returns
// errOpenAIWSConnClosed when the receiver or its connection is nil. A nil ctx
// is replaced with context.Background().
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	if c.conn == nil {
		return errOpenAIWSConnClosed
	}
	pingCtx := ctx
	if pingCtx == nil {
		pingCtx = context.Background()
	}
	return c.conn.Ping(pingCtx)
}
// Close closes the underlying connection with a normal-closure status and
// then calls CloseNow. It always returns nil, including for a nil receiver or
// nil connection.
func (c *coderOpenAIWSClientConn) Close() error {
	if c != nil && c.conn != nil {
		// Close is idempotent; errors from closing again are ignored on purpose.
		_ = c.conn.Close(coderws.StatusNormalClosure, "")
		_ = c.conn.CloseNow()
	}
	return nil
}