feat(sync): full code sync from release
This commit is contained in:
285
backend/internal/service/openai_ws_client.go
Normal file
285
backend/internal/service/openai_ws_client.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
|
||||
const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024
|
||||
const (
|
||||
openAIWSProxyTransportMaxIdleConns = 128
|
||||
openAIWSProxyTransportMaxIdleConnsPerHost = 64
|
||||
openAIWSProxyTransportIdleConnTimeout = 90 * time.Second
|
||||
openAIWSProxyClientCacheMaxEntries = 256
|
||||
openAIWSProxyClientCacheIdleTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
type OpenAIWSTransportMetricsSnapshot struct {
|
||||
ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"`
|
||||
ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"`
|
||||
TransportReuseRatio float64 `json:"transport_reuse_ratio"`
|
||||
}
|
||||
|
||||
// openAIWSClientConn 抽象 WS 客户端连接,便于替换底层实现。
|
||||
type openAIWSClientConn interface {
|
||||
WriteJSON(ctx context.Context, value any) error
|
||||
ReadMessage(ctx context.Context) ([]byte, error)
|
||||
Ping(ctx context.Context) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// openAIWSClientDialer 抽象 WS 建连器。
|
||||
type openAIWSClientDialer interface {
|
||||
Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
|
||||
}
|
||||
|
||||
type openAIWSTransportMetricsDialer interface {
|
||||
SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
|
||||
}
|
||||
|
||||
func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
|
||||
return &coderOpenAIWSClientDialer{
|
||||
proxyClients: make(map[string]*openAIWSProxyClientEntry),
|
||||
}
|
||||
}
|
||||
|
||||
type coderOpenAIWSClientDialer struct {
|
||||
proxyMu sync.Mutex
|
||||
proxyClients map[string]*openAIWSProxyClientEntry
|
||||
proxyHits atomic.Int64
|
||||
proxyMisses atomic.Int64
|
||||
}
|
||||
|
||||
type openAIWSProxyClientEntry struct {
|
||||
client *http.Client
|
||||
lastUsedUnixNano int64
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) Dial(
|
||||
ctx context.Context,
|
||||
wsURL string,
|
||||
headers http.Header,
|
||||
proxyURL string,
|
||||
) (openAIWSClientConn, int, http.Header, error) {
|
||||
targetURL := strings.TrimSpace(wsURL)
|
||||
if targetURL == "" {
|
||||
return nil, 0, nil, errors.New("ws url is empty")
|
||||
}
|
||||
|
||||
opts := &coderws.DialOptions{
|
||||
HTTPHeader: cloneHeader(headers),
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
}
|
||||
if proxy := strings.TrimSpace(proxyURL); proxy != "" {
|
||||
proxyClient, err := d.proxyHTTPClient(proxy)
|
||||
if err != nil {
|
||||
return nil, 0, nil, err
|
||||
}
|
||||
opts.HTTPClient = proxyClient
|
||||
}
|
||||
|
||||
conn, resp, err := coderws.Dial(ctx, targetURL, opts)
|
||||
if err != nil {
|
||||
status := 0
|
||||
respHeaders := http.Header(nil)
|
||||
if resp != nil {
|
||||
status = resp.StatusCode
|
||||
respHeaders = cloneHeader(resp.Header)
|
||||
}
|
||||
return nil, status, respHeaders, err
|
||||
}
|
||||
// coder/websocket 默认单消息读取上限为 32KB,Codex WS 事件(如 rate_limits/大 delta)
|
||||
// 可能超过该阈值,需显式提高上限,避免本地 read_fail(message too big)。
|
||||
conn.SetReadLimit(openAIWSMessageReadLimitBytes)
|
||||
respHeaders := http.Header(nil)
|
||||
if resp != nil {
|
||||
respHeaders = cloneHeader(resp.Header)
|
||||
}
|
||||
return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
|
||||
if d == nil {
|
||||
return nil, errors.New("openai ws dialer is nil")
|
||||
}
|
||||
normalizedProxy := strings.TrimSpace(proxy)
|
||||
if normalizedProxy == "" {
|
||||
return nil, errors.New("proxy url is empty")
|
||||
}
|
||||
parsedProxyURL, err := url.Parse(normalizedProxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy url: %w", err)
|
||||
}
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
d.proxyMu.Lock()
|
||||
defer d.proxyMu.Unlock()
|
||||
if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil {
|
||||
entry.lastUsedUnixNano = now
|
||||
d.proxyHits.Add(1)
|
||||
return entry.client, nil
|
||||
}
|
||||
d.cleanupProxyClientsLocked(now)
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedProxyURL),
|
||||
MaxIdleConns: openAIWSProxyTransportMaxIdleConns,
|
||||
MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: openAIWSProxyTransportIdleConnTimeout,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{
|
||||
client: client,
|
||||
lastUsedUnixNano: now,
|
||||
}
|
||||
d.ensureProxyClientCapacityLocked()
|
||||
d.proxyMisses.Add(1)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
|
||||
if d == nil || len(d.proxyClients) == 0 {
|
||||
return
|
||||
}
|
||||
idleTTL := openAIWSProxyClientCacheIdleTTL
|
||||
if idleTTL <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Unix(0, nowUnixNano)
|
||||
for key, entry := range d.proxyClients {
|
||||
if entry == nil || entry.client == nil {
|
||||
delete(d.proxyClients, key)
|
||||
continue
|
||||
}
|
||||
lastUsed := time.Unix(0, entry.lastUsedUnixNano)
|
||||
if now.Sub(lastUsed) > idleTTL {
|
||||
closeOpenAIWSProxyClient(entry.client)
|
||||
delete(d.proxyClients, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
|
||||
if d == nil {
|
||||
return
|
||||
}
|
||||
maxEntries := openAIWSProxyClientCacheMaxEntries
|
||||
if maxEntries <= 0 {
|
||||
return
|
||||
}
|
||||
for len(d.proxyClients) > maxEntries {
|
||||
var oldestKey string
|
||||
var oldestLastUsed int64
|
||||
hasOldest := false
|
||||
for key, entry := range d.proxyClients {
|
||||
lastUsed := int64(0)
|
||||
if entry != nil {
|
||||
lastUsed = entry.lastUsedUnixNano
|
||||
}
|
||||
if !hasOldest || lastUsed < oldestLastUsed {
|
||||
hasOldest = true
|
||||
oldestKey = key
|
||||
oldestLastUsed = lastUsed
|
||||
}
|
||||
}
|
||||
if !hasOldest {
|
||||
return
|
||||
}
|
||||
if entry := d.proxyClients[oldestKey]; entry != nil {
|
||||
closeOpenAIWSProxyClient(entry.client)
|
||||
}
|
||||
delete(d.proxyClients, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
func closeOpenAIWSProxyClient(client *http.Client) {
|
||||
if client == nil || client.Transport == nil {
|
||||
return
|
||||
}
|
||||
if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
|
||||
if d == nil {
|
||||
return OpenAIWSTransportMetricsSnapshot{}
|
||||
}
|
||||
hits := d.proxyHits.Load()
|
||||
misses := d.proxyMisses.Load()
|
||||
total := hits + misses
|
||||
reuseRatio := 0.0
|
||||
if total > 0 {
|
||||
reuseRatio = float64(hits) / float64(total)
|
||||
}
|
||||
return OpenAIWSTransportMetricsSnapshot{
|
||||
ProxyClientCacheHits: hits,
|
||||
ProxyClientCacheMisses: misses,
|
||||
TransportReuseRatio: reuseRatio,
|
||||
}
|
||||
}
|
||||
|
||||
type coderOpenAIWSClientConn struct {
|
||||
conn *coderws.Conn
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return wsjson.Write(ctx, c.conn, value)
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
|
||||
if c == nil || c.conn == nil {
|
||||
return nil, errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
msgType, payload, err := c.conn.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch msgType {
|
||||
case coderws.MessageText, coderws.MessageBinary:
|
||||
return payload, nil
|
||||
default:
|
||||
return nil, errOpenAIWSConnClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) Close() error {
|
||||
if c == nil || c.conn == nil {
|
||||
return nil
|
||||
}
|
||||
// Close 为幂等,忽略重复关闭错误。
|
||||
_ = c.conn.Close(coderws.StatusNormalClosure, "")
|
||||
_ = c.conn.CloseNow()
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user