Files
xinghuoapi/backend/internal/service/openai_ws_v2/passthrough_relay.go
yangjianbo 1d0872e7ca feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层,
支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。

同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough,
并补充 i18n 文案与对应单测。

新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置,
用于宿主机场景下的反向代理与服务编排。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 11:50:58 +08:00

808 lines
22 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
coderws "github.com/coder/websocket"
"github.com/tidwall/gjson"
)
// FrameConn is the minimal frame-oriented connection the relay depends on:
// read one frame, write one frame, and close the connection.
type FrameConn interface {
// ReadFrame returns the message type and payload of the next frame, or an error.
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
// WriteFrame sends payload as one frame of the given message type.
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
// Close closes the connection.
Close() error
}
// Usage holds the token counters extracted from upstream usage payloads.
type Usage struct {
// InputTokens is taken from response.usage.input_tokens.
InputTokens int
// OutputTokens is taken from response.usage.output_tokens.
OutputTokens int
// CacheCreationInputTokens is not populated by the usage parsing visible in this file.
CacheCreationInputTokens int
// CacheReadInputTokens is taken from response.usage.input_tokens_details.cached_tokens.
CacheReadInputTokens int
}
// RelayResult summarizes a whole relay session as returned by Relay.
type RelayResult struct {
// RequestModel is the trimmed "model" field of the first client message.
RequestModel string
// Usage is the sum of all usage payloads that parsed successfully.
Usage Usage
// RequestID is the response ID of the most recent terminal event that carried one.
RequestID string
// TerminalEventType is the type of the most recent terminal event observed.
TerminalEventType string
// FirstTokenMs is the time from relay start to the first token event, in
// milliseconds; nil when no token event was observed.
FirstTokenMs *int
// Duration is the elapsed time between relay start and return.
Duration time.Duration
// ClientToUpstreamFrames counts frames written upstream, including the first message.
ClientToUpstreamFrames int64
// UpstreamToClientFrames counts frames written to the client.
UpstreamToClientFrames int64
// DroppedDownstreamFrames counts upstream frames that were read but not
// forwarded because downstream writes had been switched off.
DroppedDownstreamFrames int64
}
// RelayTurnResult describes a single completed turn; it is the argument
// passed to RelayOptions.OnTurnComplete.
type RelayTurnResult struct {
// RequestModel is the model recorded for the relay session.
RequestModel string
// Usage is the usage parsed from the terminal event of this turn only.
Usage Usage
// RequestID is the response ID of the turn.
RequestID string
// TerminalEventType is the event type that ended the turn.
TerminalEventType string
// Duration is the time between the first event seen for this response ID
// and its terminal event.
Duration time.Duration
// FirstTokenMs is the time from the first event seen for this response ID
// to its first token event, in milliseconds; nil when none was observed.
FirstTokenMs *int
}
// RelayExit describes why Relay ended abnormally. Relay returns a nil
// *RelayExit when the session completed gracefully.
type RelayExit struct {
// Stage names the step that ended the relay, e.g. "relay_init",
// "write_upstream", "client_disconnected", or the stage of the failing pump.
Stage string
// Err is the error associated with the exit.
Err error
// WroteDownstream reports whether at least one frame reached the client.
WroteDownstream bool
}
// RelayOptions tunes Relay. The zero value is usable; see each field for the
// default applied by Relay.
type RelayOptions struct {
// WriteTimeout bounds each frame write; values <= 0 fall back to 2 minutes.
WriteTimeout time.Duration
// IdleTimeout ends the relay after this long without activity; values <= 0
// disable the idle watchdog.
IdleTimeout time.Duration
// UpstreamDrainTimeout is how long Relay keeps reading upstream after a
// graceful client disconnect; values <= 0 fall back to 1200ms.
UpstreamDrainTimeout time.Duration
// FirstMessageType is the frame type used for the first client message;
// anything other than binary is sent as text.
FirstMessageType coderws.MessageType
// OnUsageParseFailure, if set, is called with the event type and the raw
// usage JSON when a usage payload cannot be parsed.
OnUsageParseFailure func(eventType string, usageRaw string)
// OnTurnComplete, if set, is called for each terminal event that carries a response ID.
OnTurnComplete func(turn RelayTurnResult)
// OnTrace, if set, receives trace events emitted during the relay.
OnTrace func(event RelayTraceEvent)
// Now supplies the current time; nil falls back to time.Now.
Now func() time.Time
}
// RelayTraceEvent is one diagnostic event delivered to RelayOptions.OnTrace.
// Fields that do not apply to a given stage are left at their zero value.
type RelayTraceEvent struct {
// Stage names the step being traced, e.g. "relay_start" or "first_exit".
Stage string
// Direction is "client_to_upstream", "upstream_to_client", "watchdog", or empty.
Direction string
// MessageType is "text", "binary", or "unknown(N)" for the frame involved.
MessageType string
// PayloadBytes is the length of the frame payload involved.
PayloadBytes int
// Graceful reports whether the traced exit was classified as a normal disconnect.
Graceful bool
// WroteDownstream reports whether any frame had reached the client at that point.
WroteDownstream bool
// Error is the error text, or empty when there was no error.
Error string
}
// relayState accumulates what is observed on the upstream stream during a
// relay session. It is updated by observeUpstreamMessage and
// parseUsageAndAccumulate and copied into the result by enrichResult.
type relayState struct {
// usage is the running total of parsed usage payloads.
usage Usage
// requestModel is the model taken from the first client message.
requestModel string
// lastResponseID is the response ID of the latest terminal event that had one.
lastResponseID string
// terminalEventType is the type of the latest terminal event.
terminalEventType string
// firstTokenMs is the session-level time to first token in milliseconds.
firstTokenMs *int
// turnTimingByID tracks per-response timing; entries are removed when the
// matching terminal event arrives.
turnTimingByID map[string]*relayTurnTiming
}
// relayExitSignal is the message a relay goroutine sends on the exit channel
// when it stops.
type relayExitSignal struct {
// stage identifies the step that stopped, e.g. "read_client" or "idle_timeout".
stage string
// err is the error that caused the stop, if any.
err error
// graceful marks the stop as a normal disconnect rather than a failure.
graceful bool
// wroteDownstream reports whether the sender had written any frame to the client.
wroteDownstream bool
}
// observedUpstreamEvent is what observeUpstreamMessage extracted from a
// single upstream text frame.
type observedUpstreamEvent struct {
// terminal is true when eventType is a terminal event.
terminal bool
// eventType is the trimmed "type" field of the frame.
eventType string
// responseID is the response ID resolved for the frame, possibly empty.
responseID string
// usage is the usage parsed from this frame alone.
usage Usage
// duration is the turn duration; set only for terminal events with known timing.
duration time.Duration
// firstToken is the turn's time to first token in milliseconds; set only
// for terminal events with known timing.
firstToken *int
}
// relayTurnTiming records timing for one response ID.
type relayTurnTiming struct {
// startAt is when the first event for the response ID was observed.
startAt time.Time
// firstTokenMs is the time from startAt to the first token event, in
// milliseconds; nil until a token event is seen.
firstTokenMs *int
}
// Relay forwards frames between clientConn and upstreamConn until one side
// stops, observing upstream text frames for usage and turn timing on the way.
//
// It first writes firstClientMessage upstream, then starts three goroutines:
// a client-to-upstream pump, an upstream-to-client pump, and an idle watchdog.
// The first exit signal decides how the session is wound down:
//   - a graceful client read exit switches downstream writes off and keeps
//     reading upstream for the drain timeout;
//   - any other exit cancels the relay context, closes upstreamConn, and
//     waits up to 200ms for a second exit signal.
//
// The returned RelayResult is populated from the observed state in every path
// that gets past the first write. The returned *RelayExit is nil when the
// session completed gracefully (in which case clientConn is closed as well);
// otherwise it names the stage and error that ended the relay.
func Relay(
ctx context.Context,
clientConn FrameConn,
upstreamConn FrameConn,
firstClientMessage []byte,
options RelayOptions,
) (RelayResult, *RelayExit) {
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
if clientConn == nil || upstreamConn == nil {
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
}
if ctx == nil {
ctx = context.Background()
}
// Apply defaults for every option left unset.
nowFn := options.Now
if nowFn == nil {
nowFn = time.Now
}
writeTimeout := options.WriteTimeout
if writeTimeout <= 0 {
writeTimeout = 2 * time.Minute
}
drainTimeout := options.UpstreamDrainTimeout
if drainTimeout <= 0 {
drainTimeout = 1200 * time.Millisecond
}
// Any first message type other than binary is sent as text.
firstMessageType := options.FirstMessageType
if firstMessageType != coderws.MessageBinary {
firstMessageType = coderws.MessageText
}
startAt := nowFn()
state := &relayState{requestModel: result.RequestModel}
onTrace := options.OnTrace
relayCtx, relayCancel := context.WithCancel(ctx)
defer relayCancel()
// lastActivity holds the UnixNano timestamp of the latest read or write; the
// idle watchdog compares it against the idle timeout.
lastActivity := atomic.Int64{}
lastActivity.Store(nowFn().UnixNano())
markActivity := func() {
lastActivity.Store(nowFn().UnixNano())
}
// Both write helpers bound each frame write with writeTimeout.
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
defer cancel()
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
}
writeClient := func(msgType coderws.MessageType, payload []byte) error {
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
defer cancel()
return clientConn.WriteFrame(writeCtx, msgType, payload)
}
clientToUpstreamFrames := &atomic.Int64{}
upstreamToClientFrames := &atomic.Int64{}
droppedDownstreamFrames := &atomic.Int64{}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_start",
PayloadBytes: len(firstClientMessage),
MessageType: relayMessageTypeString(firstMessageType),
})
// The first client message is forwarded synchronously before any goroutine starts.
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
result.Duration = nowFn().Sub(startAt)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
Error: err.Error(),
})
return result, &RelayExit{Stage: "write_upstream", Err: err}
}
clientToUpstreamFrames.Add(1)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
markActivity()
// The buffer of 3 matches the three goroutines below, each of which sends at
// most one signal, so a send never blocks even after Relay has returned.
exitCh := make(chan relayExitSignal, 3)
dropDownstreamWrites := atomic.Bool{}
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
go runUpstreamToClient(
relayCtx,
upstreamConn,
writeClient,
startAt,
nowFn,
state,
options.OnUsageParseFailure,
options.OnTurnComplete,
&dropDownstreamWrites,
upstreamToClientFrames,
droppedDownstreamFrames,
markActivity,
onTrace,
exitCh,
)
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
firstExit := <-exitCh
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "first_exit",
Direction: relayDirectionFromStage(firstExit.stage),
Graceful: firstExit.graceful,
WroteDownstream: firstExit.wroteDownstream,
Error: relayErrorString(firstExit.err),
})
combinedWroteDownstream := firstExit.wroteDownstream
secondExit := relayExitSignal{graceful: true}
hasSecondExit := false
// After the client disconnects, keep reading upstream for a short best-effort
// window to capture late usage/terminal events for billing.
if firstExit.stage == "read_client" && firstExit.graceful {
dropDownstreamWrites.Store(true)
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
} else {
relayCancel()
_ = upstreamConn.Close()
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
}
if hasSecondExit {
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "second_exit",
Direction: relayDirectionFromStage(secondExit.stage),
Graceful: secondExit.graceful,
WroteDownstream: secondExit.wroteDownstream,
Error: relayErrorString(secondExit.err),
})
}
// Tear everything down; in the non-drain branch above this repeats the
// cancel and close already performed there.
relayCancel()
_ = upstreamConn.Close()
// NOTE(review): the upstream pump is not joined before state is read here. If
// waitRelayExit timed out, that goroutine may still be updating state —
// confirm this cannot race.
enrichResult(&result, state, nowFn().Sub(startAt))
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
// A graceful client disconnect is still reported as a non-nil exit, using the
// second exit's stage and error when that one was not graceful.
if firstExit.stage == "read_client" && firstExit.graceful {
stage := "client_disconnected"
exitErr := firstExit.err
if hasSecondExit && !secondExit.graceful {
stage = secondExit.stage
exitErr = secondExit.err
}
if exitErr == nil {
exitErr = io.EOF
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(exitErr),
})
return result, &RelayExit{
Stage: stage,
Err: exitErr,
WroteDownstream: combinedWroteDownstream,
}
}
// Both exits graceful (or no second exit observed): normal completion.
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
_ = clientConn.Close()
return result, nil
}
// The first failure takes precedence over a failing second exit.
if !firstExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(firstExit.stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(firstExit.err),
})
return result, &RelayExit{
Stage: firstExit.stage,
Err: firstExit.err,
WroteDownstream: combinedWroteDownstream,
}
}
if hasSecondExit && !secondExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_exit",
Direction: relayDirectionFromStage(secondExit.stage),
Graceful: false,
WroteDownstream: combinedWroteDownstream,
Error: relayErrorString(secondExit.err),
})
return result, &RelayExit{
Stage: secondExit.stage,
Err: secondExit.err,
WroteDownstream: combinedWroteDownstream,
}
}
// NOTE(review): this tail looks unreachable — the three conditions above
// already cover every combination of firstExit/secondExit; confirm before removing.
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
_ = clientConn.Close()
return result, nil
}
// runClientToUpstream pumps frames from the client to the upstream connection.
//
// It loops until reading from the client or writing upstream fails, then sends
// exactly one relayExitSignal on exitCh and returns. A read failure is marked
// graceful when isDisconnectError classifies it as a disconnect; a write
// failure is never graceful. Activity is marked after each successful read and
// after each successful write, and forwardedFrames (when non-nil) is
// incremented per forwarded frame.
func runClientToUpstream(
	ctx context.Context,
	clientConn FrameConn,
	writeUpstream func(msgType coderws.MessageType, payload []byte) error,
	markActivity func(),
	forwardedFrames *atomic.Int64,
	onTrace func(event RelayTraceEvent),
	exitCh chan<- relayExitSignal,
) {
	const direction = "client_to_upstream"

	for {
		frameType, frame, readErr := clientConn.ReadFrame(ctx)
		if readErr != nil {
			disconnected := isDisconnectError(readErr)
			emitRelayTrace(onTrace, RelayTraceEvent{
				Stage:     "read_client_failed",
				Direction: direction,
				Error:     readErr.Error(),
				Graceful:  disconnected,
			})
			exitCh <- relayExitSignal{stage: "read_client", err: readErr, graceful: disconnected}
			return
		}
		markActivity()

		writeErr := writeUpstream(frameType, frame)
		if writeErr != nil {
			emitRelayTrace(onTrace, RelayTraceEvent{
				Stage:        "write_upstream_failed",
				Direction:    direction,
				MessageType:  relayMessageTypeString(frameType),
				PayloadBytes: len(frame),
				Error:        writeErr.Error(),
			})
			exitCh <- relayExitSignal{stage: "write_upstream", err: writeErr}
			return
		}

		if forwardedFrames != nil {
			forwardedFrames.Add(1)
		}
		markActivity()
	}
}
// runUpstreamToClient pumps frames from the upstream connection to the client.
//
// Every text frame is passed through observeUpstreamMessage so usage and turn
// timing are recorded in state, and emitTurnComplete is invoked for the
// observation. Binary frames are forwarded untouched and are not parsed.
//
// While dropDownstreamWrites is set, frames are counted in droppedFrames
// instead of being written to the client, and the loop ends with a graceful
// "drain_terminal" signal as soon as a terminal event is observed.
//
// The function sends exactly one relayExitSignal on exitCh before returning;
// its wroteDownstream field reports whether any frame reached the client.
func runUpstreamToClient(
	ctx context.Context,
	upstreamConn FrameConn,
	writeClient func(msgType coderws.MessageType, payload []byte) error,
	startAt time.Time,
	nowFn func() time.Time,
	state *relayState,
	onUsageParseFailure func(eventType string, usageRaw string),
	onTurnComplete func(turn RelayTurnResult),
	dropDownstreamWrites *atomic.Bool,
	forwardedFrames *atomic.Int64,
	droppedFrames *atomic.Int64,
	markActivity func(),
	onTrace func(event RelayTraceEvent),
	exitCh chan<- relayExitSignal,
) {
	const direction = "upstream_to_client"
	delivered := false

	for {
		frameType, frame, readErr := upstreamConn.ReadFrame(ctx)
		if readErr != nil {
			disconnected := isDisconnectError(readErr)
			emitRelayTrace(onTrace, RelayTraceEvent{
				Stage:           "read_upstream_failed",
				Direction:       direction,
				Error:           readErr.Error(),
				Graceful:        disconnected,
				WroteDownstream: delivered,
			})
			exitCh <- relayExitSignal{
				stage:           "read_upstream",
				err:             readErr,
				graceful:        disconnected,
				wroteDownstream: delivered,
			}
			return
		}
		markActivity()

		// Only text frames go through the JSON observation path; binary frames
		// are passed through as-is to avoid pointless parsing work.
		var seen observedUpstreamEvent
		if frameType == coderws.MessageText {
			seen = observeUpstreamMessage(state, frame, startAt, nowFn, onUsageParseFailure)
		}
		emitTurnComplete(onTurnComplete, state, seen)

		dropping := dropDownstreamWrites != nil && dropDownstreamWrites.Load()
		if !dropping {
			if writeErr := writeClient(frameType, frame); writeErr != nil {
				emitRelayTrace(onTrace, RelayTraceEvent{
					Stage:           "write_client_failed",
					Direction:       direction,
					MessageType:     relayMessageTypeString(frameType),
					PayloadBytes:    len(frame),
					WroteDownstream: delivered,
					Error:           writeErr.Error(),
				})
				exitCh <- relayExitSignal{stage: "write_client", err: writeErr, wroteDownstream: delivered}
				return
			}
			delivered = true
			if forwardedFrames != nil {
				forwardedFrames.Add(1)
			}
			markActivity()
			continue
		}

		// Drain mode: the frame has been observed but is not forwarded.
		if droppedFrames != nil {
			droppedFrames.Add(1)
		}
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "drop_downstream_frame",
			Direction:       direction,
			MessageType:     relayMessageTypeString(frameType),
			PayloadBytes:    len(frame),
			WroteDownstream: delivered,
		})
		if seen.terminal {
			exitCh <- relayExitSignal{
				stage:           "drain_terminal",
				graceful:        true,
				wroteDownstream: delivered,
			}
			return
		}
		markActivity()
	}
}
// runIdleWatchdog ends the relay when no activity has been recorded for
// idleTimeout.
//
// A non-positive idleTimeout disables the watchdog entirely. Otherwise it
// polls lastActivity (a UnixNano timestamp) every idleTimeout/4, clamped to
// the range [1s, 5s]. On timeout it emits a trace event, sends an
// "idle_timeout" signal carrying context.DeadlineExceeded on exitCh, and
// returns. It also returns, without signalling, when ctx is cancelled.
func runIdleWatchdog(
	ctx context.Context,
	nowFn func() time.Time,
	idleTimeout time.Duration,
	lastActivity *atomic.Int64,
	onTrace func(event RelayTraceEvent),
	exitCh chan<- relayExitSignal,
) {
	if idleTimeout <= 0 {
		return
	}

	pollEvery := minDuration(idleTimeout/4, 5*time.Second)
	if pollEvery < time.Second {
		pollEvery = time.Second
	}
	poll := time.NewTicker(pollEvery)
	defer poll.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-poll.C:
		}

		lastSeen := time.Unix(0, lastActivity.Load())
		if idleFor := nowFn().Sub(lastSeen); idleFor < idleTimeout {
			continue
		}

		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:     "idle_timeout_triggered",
			Direction: "watchdog",
			Error:     context.DeadlineExceeded.Error(),
		})
		exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
		return
	}
}
// emitRelayTrace delivers event to onTrace; it does nothing when onTrace is nil.
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
	if onTrace != nil {
		onTrace(event)
	}
}
// relayMessageTypeString renders a frame type for trace output: "text",
// "binary", or "unknown(N)" where N is the numeric value of msgType.
func relayMessageTypeString(msgType coderws.MessageType) string {
	if msgType == coderws.MessageText {
		return "text"
	}
	if msgType == coderws.MessageBinary {
		return "binary"
	}
	return "unknown(" + strconv.Itoa(int(msgType)) + ")"
}
// relayDirectionFromStage maps an exit stage to the trace direction it belongs
// to. Stages it does not recognize map to the empty string.
func relayDirectionFromStage(stage string) string {
	if stage == "read_client" || stage == "write_upstream" {
		return "client_to_upstream"
	}
	if stage == "read_upstream" || stage == "write_client" || stage == "drain_terminal" {
		return "upstream_to_client"
	}
	if stage == "idle_timeout" {
		return "watchdog"
	}
	return ""
}
// relayErrorString returns err.Error(), or the empty string for a nil error.
func relayErrorString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
// observeUpstreamMessage inspects one upstream JSON text frame and records
// what it learns in state.
//
// It reads the event type and response ID, tracks session-level and per-turn
// time-to-first-token, accumulates usage via parseUsageAndAccumulate, and on
// terminal events stores the terminal event type and response ID and removes
// the turn's timing entry.
//
// The returned observation is the zero value when state is nil, the message is
// empty, or the frame has no "type". For terminal events it also carries the
// turn duration and first-token latency when timing for the response ID was
// being tracked.
func observeUpstreamMessage(
	state *relayState,
	message []byte,
	startAt time.Time,
	nowFn func() time.Time,
	onUsageParseFailure func(eventType string, usageRaw string),
) observedUpstreamEvent {
	var observed observedUpstreamEvent
	if state == nil || len(message) == 0 {
		return observed
	}

	fields := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
	eventType := strings.TrimSpace(fields[0].String())
	if eventType == "" {
		return observed
	}
	terminal := isTerminalEvent(eventType)
	tokenEvent := isTokenEvent(eventType)

	responseID := strings.TrimSpace(fields[1].String())
	if responseID == "" {
		responseID = strings.TrimSpace(fields[2].String())
	}
	// Only terminal events fall back to the top-level id, so that an event_id is
	// never mistaken for a response_id and tied to a turn.
	if responseID == "" && terminal {
		responseID = strings.TrimSpace(fields[3].String())
	}

	now := nowFn()
	if tokenEvent && state.firstTokenMs == nil {
		if elapsed := int(now.Sub(startAt).Milliseconds()); elapsed >= 0 {
			state.firstTokenMs = &elapsed
		}
	}

	observed.eventType = eventType
	observed.responseID = responseID
	observed.usage = parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)

	if responseID != "" {
		timing := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
		if tokenEvent && timing != nil && timing.firstTokenMs == nil {
			if elapsed := int(now.Sub(timing.startAt).Milliseconds()); elapsed >= 0 {
				timing.firstTokenMs = &elapsed
			}
		}
	}

	if !terminal {
		return observed
	}
	observed.terminal = true
	state.terminalEventType = eventType
	if responseID == "" {
		return observed
	}

	state.lastResponseID = responseID
	finished, tracked := openAIWSRelayDeleteTurnTiming(state, responseID)
	if !tracked {
		return observed
	}
	turnDuration := now.Sub(finished.startAt)
	if turnDuration < 0 {
		// Clamp in case the injected clock went backwards.
		turnDuration = 0
	}
	observed.duration = turnDuration
	observed.firstToken = openAIWSRelayCloneIntPtr(finished.firstTokenMs)
	return observed
}
// emitTurnComplete invokes onTurnComplete for a terminal observation that
// carries a non-blank response ID. It is a no-op when the callback is nil, the
// observation is not terminal, or the response ID is blank. RequestModel is
// left empty when state is nil.
func emitTurnComplete(
	onTurnComplete func(turn RelayTurnResult),
	state *relayState,
	observed observedUpstreamEvent,
) {
	if onTurnComplete == nil || !observed.terminal {
		return
	}
	turnID := strings.TrimSpace(observed.responseID)
	if turnID == "" {
		return
	}

	turn := RelayTurnResult{
		Usage:             observed.usage,
		RequestID:         turnID,
		TerminalEventType: observed.eventType,
		Duration:          observed.duration,
		// Hand the callback its own copy so it cannot alias the observation.
		FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
	}
	if state != nil {
		turn.RequestModel = state.requestModel
	}
	onTurnComplete(turn)
}
// openAIWSRelayGetOrInitTurnTiming returns the timing entry for responseID,
// creating one that starts at now when the entry is missing, nil, or has a
// zero start time. The map is allocated lazily. It returns nil only when state
// is nil.
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
	if state == nil {
		return nil
	}
	if state.turnTimingByID == nil {
		state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
	}

	if existing := state.turnTimingByID[responseID]; existing != nil && !existing.startAt.IsZero() {
		return existing
	}

	fresh := &relayTurnTiming{startAt: now}
	state.turnTimingByID[responseID] = fresh
	return fresh
}
// openAIWSRelayDeleteTurnTiming removes the timing entry for responseID and
// returns a copy of it. The boolean is false — and nothing is removed — when
// state or its map is nil, or when the entry is missing or nil.
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
	var none relayTurnTiming
	if state == nil || state.turnTimingByID == nil {
		return none, false
	}

	// A missing key yields a nil pointer, so one nil check covers both cases.
	entry := state.turnTimingByID[responseID]
	if entry == nil {
		return none, false
	}

	delete(state.turnTimingByID, responseID)
	return *entry, true
}
// openAIWSRelayCloneIntPtr returns a pointer to a fresh copy of *v, or nil
// when v is nil, so the caller's value cannot be aliased.
func openAIWSRelayCloneIntPtr(v *int) *int {
	if v != nil {
		dup := *v
		return &dup
	}
	return nil
}
// parseUsageAndAccumulate extracts token usage from the "response.usage"
// object of a usage-bearing event, adds it to state.usage, and returns the
// usage parsed from this message alone.
//
// It returns the zero Usage without touching state when state is nil, the
// message is empty, the event type is not one that carries usage, or the
// message has no "response.usage". When "response.usage" is present but
// malformed — not a JSON object, or with missing/non-numeric required fields —
// the failure is recorded through recordUsageParseFailure, onParseFailure (if
// set) is called with the raw usage JSON, and the zero Usage is returned.
func parseUsageAndAccumulate(
	state *relayState,
	message []byte,
	eventType string,
	onParseFailure func(eventType string, usageRaw string),
) Usage {
	var none Usage
	if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
		return none
	}

	usageNode := gjson.GetBytes(message, "response.usage")
	if !usageNode.Exists() {
		return none
	}
	usageRaw := strings.TrimSpace(usageNode.Raw)

	reportFailure := func() {
		recordUsageParseFailure()
		if onParseFailure != nil {
			onParseFailure(eventType, usageRaw)
		}
	}

	// An empty string has no "{" prefix either, so this single check also
	// rejects a blank raw value.
	if !strings.HasPrefix(usageRaw, "{") {
		reportFailure()
		return none
	}

	counters := gjson.GetManyBytes(
		message,
		"response.usage.input_tokens",
		"response.usage.output_tokens",
		"response.usage.input_tokens_details.cached_tokens",
	)
	inputTokens, inputOK := parseUsageIntField(counters[0], true)
	outputTokens, outputOK := parseUsageIntField(counters[1], true)
	cachedTokens, cachedOK := parseUsageIntField(counters[2], false)
	if !(inputOK && outputOK && cachedOK) {
		reportFailure()
		// On a parse failure nothing is accumulated at all, so billing usage
		// never ends up in a "half-valid" state.
		return none
	}

	state.usage.InputTokens += inputTokens
	state.usage.OutputTokens += outputTokens
	state.usage.CacheReadInputTokens += cachedTokens
	return Usage{
		InputTokens:          inputTokens,
		OutputTokens:         outputTokens,
		CacheReadInputTokens: cachedTokens,
	}
}
// parseUsageIntField converts a usage counter to an int. A missing field
// yields (0, true) when it is optional and (0, false) when it is required; a
// present field that is not a JSON number yields (0, false).
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
	switch {
	case !value.Exists():
		return 0, !required
	case value.Type != gjson.Number:
		return 0, false
	default:
		return int(value.Int()), true
	}
}
// enrichResult copies the session duration and everything observed in state
// into result. A nil result is ignored; with a nil state only Duration is set.
// FirstTokenMs is assigned as the same pointer held by state, not a copy.
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
	if result == nil {
		return
	}
	result.Duration = duration

	if state != nil {
		result.RequestModel, result.Usage = state.requestModel, state.usage
		result.RequestID, result.TerminalEventType = state.lastResponseID, state.terminalEventType
		result.FirstTokenMs = state.firstTokenMs
	}
}
// isDisconnectError reports whether err is treated as an ordinary connection
// teardown rather than a real failure.
//
// It matches, in order: io.EOF, net.ErrClosed and context.Canceled (via
// errors.Is); the WebSocket close statuses normal closure, going away, no
// status received and abnormal closure; and finally a set of well-known
// substrings in the lower-cased error text. A nil error is not a disconnect.
func isDisconnectError(err error) bool {
	if err == nil {
		return false
	}

	for _, sentinel := range []error{io.EOF, net.ErrClosed, context.Canceled} {
		if errors.Is(err, sentinel) {
			return true
		}
	}

	switch coderws.CloseStatus(err) {
	case coderws.StatusNormalClosure,
		coderws.StatusGoingAway,
		coderws.StatusNoStatusRcvd,
		coderws.StatusAbnormalClosure:
		return true
	}

	// Last resort: match on the error text. An empty text matches none of the
	// markers and therefore falls through to false.
	text := strings.ToLower(strings.TrimSpace(err.Error()))
	markers := []string{
		"failed to read frame header: eof",
		"unexpected eof",
		"use of closed network connection",
		"connection reset by peer",
		"broken pipe",
	}
	for _, marker := range markers {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// isTerminalEvent reports whether eventType ends a response turn. Both
// spellings of "cancelled"/"canceled" are accepted.
func isTerminalEvent(eventType string) bool {
	switch eventType {
	case "response.completed", "response.done":
		return true
	case "response.failed", "response.incomplete":
		return true
	case "response.cancelled", "response.canceled":
		return true
	}
	return false
}
// shouldParseUsage reports whether usage is parsed for eventType: only
// "response.completed", "response.done" and "response.failed" qualify.
func shouldParseUsage(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed"
}
// isTokenEvent reports whether eventType counts as output having started, for
// first-token latency purposes.
//
// The lifecycle events "response.created", "response.in_progress",
// "response.output_item.added" and "response.output_item.done" never count,
// nor does the empty string. Otherwise an event counts when it contains
// ".delta", starts with "response.output", or is "response.completed" /
// "response.done".
func isTokenEvent(eventType string) bool {
	switch eventType {
	case "", "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
		return false
	case "response.completed", "response.done":
		return true
	}
	// The "response.output" prefix also covers every "response.output_text…" event.
	return strings.Contains(eventType, ".delta") || strings.HasPrefix(eventType, "response.output")
}
// minDuration returns the smaller of a and b, treating a non-positive value as
// "unset": if a <= 0 it returns b, otherwise if b <= 0 it returns a.
func minDuration(a, b time.Duration) time.Duration {
	switch {
	case a <= 0:
		return b
	case b <= 0:
		return a
	case a < b:
		return a
	default:
		return b
	}
}
// waitRelayExit waits up to timeout for the next signal on exitCh.
//
// It returns the signal and true when one arrives in time, or the zero signal
// and false when the timeout elapses first. A non-positive timeout falls back
// to 200ms.
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
	if timeout <= 0 {
		timeout = 200 * time.Millisecond
	}

	// Use an explicit timer instead of time.After so it is released as soon as
	// a signal arrives rather than staying pending until the timeout expires.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case sig := <-exitCh:
		return sig, true
	case <-timer.C:
		return relayExitSignal{}, false
	}
}