286 lines
7.4 KiB
Go
286 lines
7.4 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
coderws "github.com/coder/websocket"
|
||
"github.com/coder/websocket/wsjson"
|
||
)
|
||
|
||
// openAIWSMessageReadLimitBytes is the per-message read limit (16 MiB) that
// Dial installs on every new upstream WebSocket connection.
const openAIWSMessageReadLimitBytes int64 = 16 << 20
|
||
// Connection-pool settings for the http.Transport built per proxy URL.
const (
	// openAIWSProxyTransportMaxIdleConns is the transport's MaxIdleConns.
	openAIWSProxyTransportMaxIdleConns = 128
	// openAIWSProxyTransportMaxIdleConnsPerHost is the transport's MaxIdleConnsPerHost.
	openAIWSProxyTransportMaxIdleConnsPerHost = 64
	// openAIWSProxyTransportIdleConnTimeout is the transport's IdleConnTimeout.
	openAIWSProxyTransportIdleConnTimeout = 90 * time.Second
)

// Bounds for the per-proxy http.Client cache.
const (
	// openAIWSProxyClientCacheMaxEntries caps the number of cached clients;
	// the least recently used entries are evicted beyond it.
	openAIWSProxyClientCacheMaxEntries = 256
	// openAIWSProxyClientCacheIdleTTL is how long an entry may stay unused
	// before a cleanup pass drops it.
	openAIWSProxyClientCacheIdleTTL = 15 * time.Minute
)
|
||
|
||
// OpenAIWSTransportMetricsSnapshot is a point-in-time copy of the proxy
// client cache counters exposed by a WS dialer.
type OpenAIWSTransportMetricsSnapshot struct {
	// ProxyClientCacheHits counts lookups served by an already cached client.
	ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"`
	// ProxyClientCacheMisses counts lookups that had to build a new client.
	ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"`
	// TransportReuseRatio is hits / (hits + misses), or 0 when there were no lookups.
	TransportReuseRatio float64 `json:"transport_reuse_ratio"`
}
|
||
|
||
// openAIWSClientConn abstracts a client-side WebSocket connection so the
// underlying implementation can be swapped out.
type openAIWSClientConn interface {
	// WriteJSON sends value to the peer as a JSON message.
	WriteJSON(ctx context.Context, value any) error
	// ReadMessage returns the payload of the next message.
	ReadMessage(ctx context.Context) ([]byte, error)
	// Ping sends a ping over the connection.
	Ping(ctx context.Context) error
	// Close closes the connection.
	Close() error
}
|
||
|
||
// openAIWSClientDialer 抽象 WS 建连器。
|
||
type openAIWSClientDialer interface {
|
||
Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
|
||
}
|
||
|
||
type openAIWSTransportMetricsDialer interface {
|
||
SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
|
||
}
|
||
|
||
func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
|
||
return &coderOpenAIWSClientDialer{
|
||
proxyClients: make(map[string]*openAIWSProxyClientEntry),
|
||
}
|
||
}
|
||
|
||
type coderOpenAIWSClientDialer struct {
|
||
proxyMu sync.Mutex
|
||
proxyClients map[string]*openAIWSProxyClientEntry
|
||
proxyHits atomic.Int64
|
||
proxyMisses atomic.Int64
|
||
}
|
||
|
||
// openAIWSProxyClientEntry is one slot of the proxy client cache.
type openAIWSProxyClientEntry struct {
	// client is the HTTP client whose transport routes through the proxy.
	client *http.Client
	// lastUsedUnixNano is when the entry was created or last returned from
	// the cache, as Unix nanoseconds.
	lastUsedUnixNano int64
}
|
||
|
||
func (d *coderOpenAIWSClientDialer) Dial(
|
||
ctx context.Context,
|
||
wsURL string,
|
||
headers http.Header,
|
||
proxyURL string,
|
||
) (openAIWSClientConn, int, http.Header, error) {
|
||
targetURL := strings.TrimSpace(wsURL)
|
||
if targetURL == "" {
|
||
return nil, 0, nil, errors.New("ws url is empty")
|
||
}
|
||
|
||
opts := &coderws.DialOptions{
|
||
HTTPHeader: cloneHeader(headers),
|
||
CompressionMode: coderws.CompressionContextTakeover,
|
||
}
|
||
if proxy := strings.TrimSpace(proxyURL); proxy != "" {
|
||
proxyClient, err := d.proxyHTTPClient(proxy)
|
||
if err != nil {
|
||
return nil, 0, nil, err
|
||
}
|
||
opts.HTTPClient = proxyClient
|
||
}
|
||
|
||
conn, resp, err := coderws.Dial(ctx, targetURL, opts)
|
||
if err != nil {
|
||
status := 0
|
||
respHeaders := http.Header(nil)
|
||
if resp != nil {
|
||
status = resp.StatusCode
|
||
respHeaders = cloneHeader(resp.Header)
|
||
}
|
||
return nil, status, respHeaders, err
|
||
}
|
||
// coder/websocket 默认单消息读取上限为 32KB,Codex WS 事件(如 rate_limits/大 delta)
|
||
// 可能超过该阈值,需显式提高上限,避免本地 read_fail(message too big)。
|
||
conn.SetReadLimit(openAIWSMessageReadLimitBytes)
|
||
respHeaders := http.Header(nil)
|
||
if resp != nil {
|
||
respHeaders = cloneHeader(resp.Header)
|
||
}
|
||
return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil
|
||
}
|
||
|
||
func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
|
||
if d == nil {
|
||
return nil, errors.New("openai ws dialer is nil")
|
||
}
|
||
normalizedProxy := strings.TrimSpace(proxy)
|
||
if normalizedProxy == "" {
|
||
return nil, errors.New("proxy url is empty")
|
||
}
|
||
parsedProxyURL, err := url.Parse(normalizedProxy)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("invalid proxy url: %w", err)
|
||
}
|
||
now := time.Now().UnixNano()
|
||
|
||
d.proxyMu.Lock()
|
||
defer d.proxyMu.Unlock()
|
||
if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil {
|
||
entry.lastUsedUnixNano = now
|
||
d.proxyHits.Add(1)
|
||
return entry.client, nil
|
||
}
|
||
d.cleanupProxyClientsLocked(now)
|
||
transport := &http.Transport{
|
||
Proxy: http.ProxyURL(parsedProxyURL),
|
||
MaxIdleConns: openAIWSProxyTransportMaxIdleConns,
|
||
MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
|
||
IdleConnTimeout: openAIWSProxyTransportIdleConnTimeout,
|
||
TLSHandshakeTimeout: 10 * time.Second,
|
||
ForceAttemptHTTP2: true,
|
||
}
|
||
client := &http.Client{Transport: transport}
|
||
d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{
|
||
client: client,
|
||
lastUsedUnixNano: now,
|
||
}
|
||
d.ensureProxyClientCapacityLocked()
|
||
d.proxyMisses.Add(1)
|
||
return client, nil
|
||
}
|
||
|
||
func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
|
||
if d == nil || len(d.proxyClients) == 0 {
|
||
return
|
||
}
|
||
idleTTL := openAIWSProxyClientCacheIdleTTL
|
||
if idleTTL <= 0 {
|
||
return
|
||
}
|
||
now := time.Unix(0, nowUnixNano)
|
||
for key, entry := range d.proxyClients {
|
||
if entry == nil || entry.client == nil {
|
||
delete(d.proxyClients, key)
|
||
continue
|
||
}
|
||
lastUsed := time.Unix(0, entry.lastUsedUnixNano)
|
||
if now.Sub(lastUsed) > idleTTL {
|
||
closeOpenAIWSProxyClient(entry.client)
|
||
delete(d.proxyClients, key)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
|
||
if d == nil {
|
||
return
|
||
}
|
||
maxEntries := openAIWSProxyClientCacheMaxEntries
|
||
if maxEntries <= 0 {
|
||
return
|
||
}
|
||
for len(d.proxyClients) > maxEntries {
|
||
var oldestKey string
|
||
var oldestLastUsed int64
|
||
hasOldest := false
|
||
for key, entry := range d.proxyClients {
|
||
lastUsed := int64(0)
|
||
if entry != nil {
|
||
lastUsed = entry.lastUsedUnixNano
|
||
}
|
||
if !hasOldest || lastUsed < oldestLastUsed {
|
||
hasOldest = true
|
||
oldestKey = key
|
||
oldestLastUsed = lastUsed
|
||
}
|
||
}
|
||
if !hasOldest {
|
||
return
|
||
}
|
||
if entry := d.proxyClients[oldestKey]; entry != nil {
|
||
closeOpenAIWSProxyClient(entry.client)
|
||
}
|
||
delete(d.proxyClients, oldestKey)
|
||
}
|
||
}
|
||
|
||
// closeOpenAIWSProxyClient closes the idle connections of client's transport.
// It does nothing when client is nil or its transport is not a non-nil
// *http.Transport.
func closeOpenAIWSProxyClient(client *http.Client) {
	if client == nil {
		return
	}
	// The assertion fails for a nil RoundTripper and for custom
	// implementations, so both are skipped here.
	transport, isStd := client.Transport.(*http.Transport)
	if !isStd || transport == nil {
		return
	}
	transport.CloseIdleConnections()
}
|
||
|
||
func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
|
||
if d == nil {
|
||
return OpenAIWSTransportMetricsSnapshot{}
|
||
}
|
||
hits := d.proxyHits.Load()
|
||
misses := d.proxyMisses.Load()
|
||
total := hits + misses
|
||
reuseRatio := 0.0
|
||
if total > 0 {
|
||
reuseRatio = float64(hits) / float64(total)
|
||
}
|
||
return OpenAIWSTransportMetricsSnapshot{
|
||
ProxyClientCacheHits: hits,
|
||
ProxyClientCacheMisses: misses,
|
||
TransportReuseRatio: reuseRatio,
|
||
}
|
||
}
|
||
|
||
type coderOpenAIWSClientConn struct {
|
||
conn *coderws.Conn
|
||
}
|
||
|
||
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
||
if c == nil || c.conn == nil {
|
||
return errOpenAIWSConnClosed
|
||
}
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
return wsjson.Write(ctx, c.conn, value)
|
||
}
|
||
|
||
func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
|
||
if c == nil || c.conn == nil {
|
||
return nil, errOpenAIWSConnClosed
|
||
}
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
|
||
msgType, payload, err := c.conn.Read(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
switch msgType {
|
||
case coderws.MessageText, coderws.MessageBinary:
|
||
return payload, nil
|
||
default:
|
||
return nil, errOpenAIWSConnClosed
|
||
}
|
||
}
|
||
|
||
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
||
if c == nil || c.conn == nil {
|
||
return errOpenAIWSConnClosed
|
||
}
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
return c.conn.Ping(ctx)
|
||
}
|
||
|
||
func (c *coderOpenAIWSClientConn) Close() error {
|
||
if c == nil || c.conn == nil {
|
||
return nil
|
||
}
|
||
// Close 为幂等,忽略重复关闭错误。
|
||
_ = c.conn.Close(coderws.StatusNormalClosure, "")
|
||
_ = c.conn.CloseNow()
|
||
return nil
|
||
}
|