feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode

新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层,
支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。

同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough,
并补充 i18n 文案与对应单测。

新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置,
用于宿主机场景下的反向代理与服务编排。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-03-05 11:50:58 +08:00
parent 078fefed03
commit 1d0872e7ca
27 changed files with 3322 additions and 81 deletions

View File

@@ -0,0 +1,24 @@
package openai_ws_v2
import (
"context"
)
// runCaddyStyleRelay follows the bidirectional tunnel idea of Caddy's
// reverseproxy: once the connection is established both directions are copied
// concurrently, and the exit of either direction triggers a converging close.
//
// Reference:
//   - Project: caddyserver/caddy (Apache-2.0)
//   - Commit: f283062d37c50627d53ca682ebae2ce219b35515
//   - Files:
//   - modules/caddyhttp/reverseproxy/streaming.go
//   - modules/caddyhttp/reverseproxy/reverseproxy.go
func runCaddyStyleRelay(
	ctx context.Context,
	clientConn FrameConn,
	upstreamConn FrameConn,
	firstClientMessage []byte,
	options RelayOptions,
) (RelayResult, *RelayExit) {
	// This wrapper adds nothing of its own: every argument is handed to Relay as-is.
	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
	return result, relayExit
}

View File

@@ -0,0 +1,23 @@
package openai_ws_v2
import "context"
// EntryInput bundles the entry parameters of the passthrough v2 data plane.
// RunEntry forwards every field unchanged to runCaddyStyleRelay.
type EntryInput struct {
// Ctx is the context handed to the relay.
Ctx context.Context
// ClientConn is the downstream (client-side) frame connection.
ClientConn FrameConn
// UpstreamConn is the upstream frame connection.
UpstreamConn FrameConn
// FirstClientMessage is the first client payload to forward upstream.
FirstClientMessage []byte
// Options carries the relay tuning knobs and callbacks.
Options RelayOptions
}
// RunEntry is the single public entry point of the openai_ws_v2 package. It
// unpacks input and passes the fields, unchanged, to runCaddyStyleRelay.
func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
	ctx, options := input.Ctx, input.Options
	client, upstream := input.ClientConn, input.UpstreamConn
	return runCaddyStyleRelay(ctx, client, upstream, input.FirstClientMessage, options)
}

View File

@@ -0,0 +1,29 @@
package openai_ws_v2
import (
"sync/atomic"
)
// MetricsSnapshot is a lightweight runtime metrics snapshot of the OpenAI WS v2
// passthrough path.
type MetricsSnapshot struct {
// SemanticMutationTotal is the value of the semantic-mutation counter at snapshot time.
SemanticMutationTotal int64 `json:"semantic_mutation_total"`
// UsageParseFailureTotal is the number of usage parse failures recorded at snapshot time.
UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
}
var (
// The passthrough path does not perform semantic rewrites by default, so this
// counter is normally expected to stay at 0; it is kept for future defensive checks.
passthroughSemanticMutationTotal atomic.Int64
// passthroughUsageParseFailureTotal counts usage parse failures; it is
// incremented by recordUsageParseFailure.
passthroughUsageParseFailureTotal atomic.Int64
)
// recordUsageParseFailure increments the package-level usage parse failure counter by one.
func recordUsageParseFailure() {
passthroughUsageParseFailureTotal.Add(1)
}
// SnapshotMetrics returns the current values of the passthrough counters.
// The two counters are loaded one after the other, not as a single atomic unit.
func SnapshotMetrics() MetricsSnapshot {
	var snapshot MetricsSnapshot
	snapshot.SemanticMutationTotal = passthroughSemanticMutationTotal.Load()
	snapshot.UsageParseFailureTotal = passthroughUsageParseFailureTotal.Load()
	return snapshot
}

View File

@@ -0,0 +1,807 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
coderws "github.com/coder/websocket"
"github.com/tidwall/gjson"
)
// FrameConn is the frame-level connection abstraction the relay uses for both
// the client side and the upstream side.
type FrameConn interface {
// ReadFrame reads the next frame and returns its message type and payload.
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
// WriteFrame writes one frame of the given message type.
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
// Close closes the connection.
Close() error
}
// Usage holds token counts extracted from upstream usage payloads.
type Usage struct {
// InputTokens is read from response.usage.input_tokens.
InputTokens int
// OutputTokens is read from response.usage.output_tokens.
OutputTokens int
// CacheCreationInputTokens is not populated by the usage parser visible in this file.
CacheCreationInputTokens int
// CacheReadInputTokens is read from response.usage.input_tokens_details.cached_tokens.
CacheReadInputTokens int
}
// RelayResult summarizes one Relay call across the whole connection.
type RelayResult struct {
// RequestModel is the trimmed "model" field of the first client message.
RequestModel string
// Usage is the usage accumulated over all parsed upstream events.
Usage Usage
// RequestID is the response ID of the last terminal event that carried one.
RequestID string
// TerminalEventType is the type of the last terminal upstream event observed.
TerminalEventType string
// FirstTokenMs is the delay in milliseconds from relay start to the first
// token event; nil when no token event was observed.
FirstTokenMs *int
// Duration is the time elapsed between relay start and relay end.
Duration time.Duration
// ClientToUpstreamFrames counts frames written upstream, including the first message.
ClientToUpstreamFrames int64
// UpstreamToClientFrames counts upstream frames written to the client.
UpstreamToClientFrames int64
// DroppedDownstreamFrames counts upstream frames read but not forwarded
// while downstream writes were being dropped.
DroppedDownstreamFrames int64
}
// RelayTurnResult describes a single completed turn; it is the argument of
// RelayOptions.OnTurnComplete.
type RelayTurnResult struct {
// RequestModel is the model taken from the relay state.
RequestModel string
// Usage is the usage parsed from the terminal event of this turn.
Usage Usage
// RequestID is the response ID of the turn.
RequestID string
// TerminalEventType is the type of the terminal event that closed the turn.
TerminalEventType string
// Duration is the time between the first observed event of the turn and its terminal event.
Duration time.Duration
// FirstTokenMs is the per-turn first-token delay in milliseconds; nil when none was recorded.
FirstTokenMs *int
}
// RelayExit describes why a relay ended without completing gracefully.
type RelayExit struct {
// Stage names the relay stage that ended the relay (for example "read_client").
Stage string
// Err is the error associated with that stage.
Err error
// WroteDownstream reports whether any upstream frame had been written to the client.
WroteDownstream bool
}
// RelayOptions tunes a single Relay call. Zero values select the defaults
// described on each field.
type RelayOptions struct {
// WriteTimeout bounds each frame write in either direction; Relay uses 2 minutes when it is <= 0.
WriteTimeout time.Duration
// IdleTimeout is the inactivity limit enforced by the idle watchdog; a value <= 0 disables the watchdog.
IdleTimeout time.Duration
// UpstreamDrainTimeout is how long Relay keeps reading the upstream after a
// graceful client disconnect; Relay uses 1200ms when it is <= 0.
UpstreamDrainTimeout time.Duration
// FirstMessageType is the frame type for the first client message; any value
// other than MessageBinary is sent as MessageText.
FirstMessageType coderws.MessageType
// OnUsageParseFailure, when set, receives the event type and the raw usage JSON of a payload that failed to parse.
OnUsageParseFailure func(eventType string, usageRaw string)
// OnTurnComplete, when set, is called for each terminal upstream event that carries a response ID.
OnTurnComplete func(turn RelayTurnResult)
// OnTrace, when set, receives the trace events emitted at relay stages.
OnTrace func(event RelayTraceEvent)
// Now supplies the clock; Relay uses time.Now when it is nil.
Now func() time.Time
}
// RelayTraceEvent is a diagnostic event passed to RelayOptions.OnTrace.
type RelayTraceEvent struct {
// Stage names the relay step being traced (for example "relay_start").
Stage string
// Direction is "client_to_upstream", "upstream_to_client", "watchdog" or empty.
Direction string
// MessageType is the textual frame type ("text", "binary" or "unknown(n)") when a frame is involved.
MessageType string
// PayloadBytes is the payload length of the frame involved, if any.
PayloadBytes int
// Graceful reports whether the traced exit or error counts as a normal disconnect.
Graceful bool
// WroteDownstream reports whether any upstream frame had been written to the client so far.
WroteDownstream bool
// Error is the error text, empty when there is no error.
Error string
}
// relayState is the mutable observation state that is built up while upstream
// messages are inspected and copied into RelayResult at the end.
type relayState struct {
// usage accumulates token counts across parsed events.
usage Usage
// requestModel is the model name taken from the first client message.
requestModel string
// lastResponseID is the response ID of the most recent terminal event that had one.
lastResponseID string
// terminalEventType is the type of the most recent terminal event.
terminalEventType string
// firstTokenMs is the connection-wide first-token delay in milliseconds.
firstTokenMs *int
// turnTimingByID tracks per-turn timing keyed by response ID; entries are removed on terminal events.
turnTimingByID map[string]*relayTurnTiming
}
// relayExitSignal is what a relay goroutine sends on the exit channel when it stops.
type relayExitSignal struct {
// stage names the step that caused the goroutine to stop.
stage string
// err is the associated error, if any.
err error
// graceful marks the exit as a normal disconnect or terminal drain.
graceful bool
// wroteDownstream reports whether the sender had written any frame to the client.
wroteDownstream bool
}
// observedUpstreamEvent is the result of inspecting one upstream text message.
type observedUpstreamEvent struct {
// terminal is true when the event type is a terminal event.
terminal bool
// eventType is the trimmed "type" field of the message.
eventType string
// responseID is the response ID associated with the message, if one was found.
responseID string
// usage is the usage parsed from this message only.
usage Usage
// duration is the turn duration; set only for terminal events with tracked timing.
duration time.Duration
// firstToken is the per-turn first-token delay in milliseconds; set only for terminal events with tracked timing.
firstToken *int
}
// relayTurnTiming holds timing data for one turn, keyed by response ID in relayState.
type relayTurnTiming struct {
// startAt is the time at which the first event for this response ID was observed.
startAt time.Time
// firstTokenMs is the delay in milliseconds from startAt to the first token event of the turn.
firstTokenMs *int
}
// Relay forwards WebSocket frames between clientConn and upstreamConn until one
// side exits, inspecting upstream text frames to collect usage and timing data.
//
// firstClientMessage is written to the upstream before the copy loops start and
// its "model" field seeds RelayResult.RequestModel. Zero-valued options fall
// back to defaults: a 2-minute write timeout, a 1200ms upstream drain window,
// time.Now as the clock, and text as the first frame type unless binary is
// requested. A non-positive IdleTimeout disables the idle watchdog.
//
// upstreamConn is always closed before Relay returns once the first message
// has been written; clientConn is closed only on graceful completion.
//
// The returned *RelayExit is nil when the relay completed gracefully without
// the client disconnecting first. A graceful client disconnect is reported
// with stage "client_disconnected", or with the stage of a failing second exit.
func Relay(
	ctx context.Context,
	clientConn FrameConn,
	upstreamConn FrameConn,
	firstClientMessage []byte,
	options RelayOptions,
) (RelayResult, *RelayExit) {
	// settleTimeout bounds both the wait for a second exit signal after teardown
	// and the wait for the upstream goroutine to stop touching shared state.
	const settleTimeout = 200 * time.Millisecond

	result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
	if clientConn == nil || upstreamConn == nil {
		return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
	}
	if ctx == nil {
		ctx = context.Background()
	}

	nowFn := options.Now
	if nowFn == nil {
		nowFn = time.Now
	}
	writeTimeout := options.WriteTimeout
	if writeTimeout <= 0 {
		writeTimeout = 2 * time.Minute
	}
	drainTimeout := options.UpstreamDrainTimeout
	if drainTimeout <= 0 {
		drainTimeout = 1200 * time.Millisecond
	}
	firstMessageType := options.FirstMessageType
	if firstMessageType != coderws.MessageBinary {
		firstMessageType = coderws.MessageText
	}

	startAt := nowFn()
	state := &relayState{requestModel: result.RequestModel}
	onTrace := options.OnTrace

	relayCtx, relayCancel := context.WithCancel(ctx)
	defer relayCancel()

	var lastActivity atomic.Int64
	lastActivity.Store(nowFn().UnixNano())
	markActivity := func() {
		lastActivity.Store(nowFn().UnixNano())
	}

	writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
		writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
		defer cancel()
		return upstreamConn.WriteFrame(writeCtx, msgType, payload)
	}
	writeClient := func(msgType coderws.MessageType, payload []byte) error {
		writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
		defer cancel()
		return clientConn.WriteFrame(writeCtx, msgType, payload)
	}

	clientToUpstreamFrames := new(atomic.Int64)
	upstreamToClientFrames := new(atomic.Int64)
	droppedDownstreamFrames := new(atomic.Int64)

	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:        "relay_start",
		PayloadBytes: len(firstClientMessage),
		MessageType:  relayMessageTypeString(firstMessageType),
	})

	if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
		result.Duration = nowFn().Sub(startAt)
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:        "write_first_message_failed",
			Direction:    "client_to_upstream",
			MessageType:  relayMessageTypeString(firstMessageType),
			PayloadBytes: len(firstClientMessage),
			Error:        err.Error(),
		})
		return result, &RelayExit{Stage: "write_upstream", Err: err}
	}
	clientToUpstreamFrames.Add(1)
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:        "write_first_message_ok",
		Direction:    "client_to_upstream",
		MessageType:  relayMessageTypeString(firstMessageType),
		PayloadBytes: len(firstClientMessage),
	})
	markActivity()

	// Three goroutines send at most one signal each, so a buffer of 3 guarantees
	// that no sender ever blocks, even when Relay stops receiving.
	exitCh := make(chan relayExitSignal, 3)
	var dropDownstreamWrites atomic.Bool

	go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)

	// upstreamDone is closed when the upstream reader returns. That goroutine is
	// the only writer of state, so waiting on this channel orders its writes
	// before the read in enrichResult below.
	upstreamDone := make(chan struct{})
	go func() {
		defer close(upstreamDone)
		runUpstreamToClient(
			relayCtx,
			upstreamConn,
			writeClient,
			startAt,
			nowFn,
			state,
			options.OnUsageParseFailure,
			options.OnTurnComplete,
			&dropDownstreamWrites,
			upstreamToClientFrames,
			droppedDownstreamFrames,
			markActivity,
			onTrace,
			exitCh,
		)
	}()

	go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)

	firstExit := <-exitCh
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:           "first_exit",
		Direction:       relayDirectionFromStage(firstExit.stage),
		Graceful:        firstExit.graceful,
		WroteDownstream: firstExit.wroteDownstream,
		Error:           relayErrorString(firstExit.err),
	})
	combinedWroteDownstream := firstExit.wroteDownstream
	clientGone := firstExit.stage == "read_client" && firstExit.graceful

	var (
		secondExit    relayExitSignal
		hasSecondExit bool
	)
	if clientGone {
		// After a client disconnect, keep reading the upstream for a short window
		// on a best-effort basis to capture late usage/terminal events for billing.
		dropDownstreamWrites.Store(true)
		secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
	} else {
		relayCancel()
		// Best-effort close: the relay is being torn down regardless of the outcome.
		_ = upstreamConn.Close()
		secondExit, hasSecondExit = waitRelayExit(exitCh, settleTimeout)
	}
	if hasSecondExit {
		combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "second_exit",
			Direction:       relayDirectionFromStage(secondExit.stage),
			Graceful:        secondExit.graceful,
			WroteDownstream: secondExit.wroteDownstream,
			Error:           relayErrorString(secondExit.err),
		})
	}

	relayCancel()
	// Best-effort close; calling Close again after the branch above is intentional.
	_ = upstreamConn.Close()

	// Give the upstream reader a bounded chance to stop before state is read.
	// Without this, a reader that is still processing a frame (for example when
	// the drain window elapsed mid-stream) races with enrichResult. If the reader
	// does not stop within settleTimeout the result is built anyway.
	settle := time.NewTimer(settleTimeout)
	select {
	case <-upstreamDone:
	case <-settle.C:
	}
	settle.Stop()

	enrichResult(&result, state, nowFn().Sub(startAt))
	result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
	result.UpstreamToClientFrames = upstreamToClientFrames.Load()
	result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()

	// fail traces and builds the non-graceful exit for the given stage and error.
	fail := func(stage string, err error) (RelayResult, *RelayExit) {
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_exit",
			Direction:       relayDirectionFromStage(stage),
			Graceful:        false,
			WroteDownstream: combinedWroteDownstream,
			Error:           relayErrorString(err),
		})
		return result, &RelayExit{
			Stage:           stage,
			Err:             err,
			WroteDownstream: combinedWroteDownstream,
		}
	}

	if clientGone {
		stage := "client_disconnected"
		exitErr := firstExit.err
		if hasSecondExit && !secondExit.graceful {
			stage = secondExit.stage
			exitErr = secondExit.err
		}
		if exitErr == nil {
			exitErr = io.EOF
		}
		return fail(stage, exitErr)
	}
	if !firstExit.graceful {
		return fail(firstExit.stage, firstExit.err)
	}
	if hasSecondExit && !secondExit.graceful {
		return fail(secondExit.stage, secondExit.err)
	}

	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:           "relay_complete",
		Graceful:        true,
		WroteDownstream: combinedWroteDownstream,
	})
	_ = clientConn.Close()
	return result, nil
}
// runClientToUpstream copies frames from the client to the upstream until a
// read or write fails, then reports exactly one exit signal on exitCh.
//
// A read failure is reported with stage "read_client" and is marked graceful
// when isDisconnectError accepts the error; a write failure is reported with
// stage "write_upstream". forwardedFrames and onTrace may be nil.
func runClientToUpstream(
	ctx context.Context,
	clientConn FrameConn,
	writeUpstream func(msgType coderws.MessageType, payload []byte) error,
	markActivity func(),
	forwardedFrames *atomic.Int64,
	onTrace func(event RelayTraceEvent),
	exitCh chan<- relayExitSignal,
) {
	for {
		frameType, frame, readErr := clientConn.ReadFrame(ctx)
		if readErr != nil {
			graceful := isDisconnectError(readErr)
			emitRelayTrace(onTrace, RelayTraceEvent{
				Stage:     "read_client_failed",
				Direction: "client_to_upstream",
				Error:     readErr.Error(),
				Graceful:  graceful,
			})
			exitCh <- relayExitSignal{stage: "read_client", err: readErr, graceful: graceful}
			return
		}
		// Activity is marked both when a frame arrives and after it is forwarded.
		markActivity()

		writeErr := writeUpstream(frameType, frame)
		if writeErr == nil {
			if forwardedFrames != nil {
				forwardedFrames.Add(1)
			}
			markActivity()
			continue
		}

		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:        "write_upstream_failed",
			Direction:    "client_to_upstream",
			MessageType:  relayMessageTypeString(frameType),
			PayloadBytes: len(frame),
			Error:        writeErr.Error(),
		})
		exitCh <- relayExitSignal{stage: "write_upstream", err: writeErr}
		return
	}
}
// runUpstreamToClient copies frames from the upstream to the client until a
// read or write fails, or until a terminal event is seen while downstream
// writes are being dropped. It reports exactly one exit signal on exitCh.
//
// Text frames are inspected with observeUpstreamMessage before forwarding, and
// onTurnComplete is notified for terminal events. Binary frames are passed
// through without JSON inspection. When dropDownstreamWrites is set, frames are
// counted in droppedFrames instead of being written to the client.
// dropDownstreamWrites, forwardedFrames, droppedFrames and the callbacks may be nil.
func runUpstreamToClient(
	ctx context.Context,
	upstreamConn FrameConn,
	writeClient func(msgType coderws.MessageType, payload []byte) error,
	startAt time.Time,
	nowFn func() time.Time,
	state *relayState,
	onUsageParseFailure func(eventType string, usageRaw string),
	onTurnComplete func(turn RelayTurnResult),
	dropDownstreamWrites *atomic.Bool,
	forwardedFrames *atomic.Int64,
	droppedFrames *atomic.Int64,
	markActivity func(),
	onTrace func(event RelayTraceEvent),
	exitCh chan<- relayExitSignal,
) {
	delivered := false
	for {
		frameType, frame, readErr := upstreamConn.ReadFrame(ctx)
		if readErr != nil {
			graceful := isDisconnectError(readErr)
			emitRelayTrace(onTrace, RelayTraceEvent{
				Stage:           "read_upstream_failed",
				Direction:       "upstream_to_client",
				Error:           readErr.Error(),
				Graceful:        graceful,
				WroteDownstream: delivered,
			})
			exitCh <- relayExitSignal{
				stage:           "read_upstream",
				err:             readErr,
				graceful:        graceful,
				wroteDownstream: delivered,
			}
			return
		}
		markActivity()

		// Only text frames go through the JSON observation path; every other
		// frame type leaves the observation at its zero value.
		var observed observedUpstreamEvent
		if frameType == coderws.MessageText {
			observed = observeUpstreamMessage(state, frame, startAt, nowFn, onUsageParseFailure)
		}
		emitTurnComplete(onTurnComplete, state, observed)

		draining := dropDownstreamWrites != nil && dropDownstreamWrites.Load()
		if !draining {
			if writeErr := writeClient(frameType, frame); writeErr != nil {
				emitRelayTrace(onTrace, RelayTraceEvent{
					Stage:           "write_client_failed",
					Direction:       "upstream_to_client",
					MessageType:     relayMessageTypeString(frameType),
					PayloadBytes:    len(frame),
					WroteDownstream: delivered,
					Error:           writeErr.Error(),
				})
				exitCh <- relayExitSignal{stage: "write_client", err: writeErr, wroteDownstream: delivered}
				return
			}
			delivered = true
			if forwardedFrames != nil {
				forwardedFrames.Add(1)
			}
			markActivity()
			continue
		}

		// Drain mode: the frame is observed but not forwarded to the client.
		if droppedFrames != nil {
			droppedFrames.Add(1)
		}
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "drop_downstream_frame",
			Direction:       "upstream_to_client",
			MessageType:     relayMessageTypeString(frameType),
			PayloadBytes:    len(frame),
			WroteDownstream: delivered,
		})
		if observed.terminal {
			exitCh <- relayExitSignal{
				stage:           "drain_terminal",
				graceful:        true,
				wroteDownstream: delivered,
			}
			return
		}
		markActivity()
	}
}
// runIdleWatchdog periodically compares the clock against lastActivity and
// sends an "idle_timeout" exit signal once the gap reaches idleTimeout.
//
// It returns immediately when idleTimeout is not positive and stops silently
// when ctx is done. The polling interval is a quarter of idleTimeout, capped at
// 5 seconds and never below 1 second. lastActivity holds a UnixNano timestamp.
func runIdleWatchdog(
	ctx context.Context,
	nowFn func() time.Time,
	idleTimeout time.Duration,
	lastActivity *atomic.Int64,
	onTrace func(event RelayTraceEvent),
	exitCh chan<- relayExitSignal,
) {
	if idleTimeout <= 0 {
		return
	}

	interval := minDuration(idleTimeout/4, 5*time.Second)
	if interval < time.Second {
		interval = time.Second
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		lastSeen := time.Unix(0, lastActivity.Load())
		if idleFor := nowFn().Sub(lastSeen); idleFor < idleTimeout {
			continue
		}

		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:     "idle_timeout_triggered",
			Direction: "watchdog",
			Error:     context.DeadlineExceeded.Error(),
		})
		exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
		return
	}
}
// emitRelayTrace delivers event to onTrace; a nil callback makes it a no-op.
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
	if onTrace != nil {
		onTrace(event)
	}
}
// relayMessageTypeString renders a frame type for tracing: "text", "binary",
// or "unknown(<n>)" with the numeric value for anything else.
func relayMessageTypeString(msgType coderws.MessageType) string {
	if msgType == coderws.MessageText {
		return "text"
	}
	if msgType == coderws.MessageBinary {
		return "binary"
	}
	numeric := strconv.Itoa(int(msgType))
	return "unknown(" + numeric + ")"
}
// relayDirectionFromStage maps an exit stage name to the trace direction it
// belongs to. Stages it does not know yield the empty string.
func relayDirectionFromStage(stage string) string {
	if stage == "read_client" || stage == "write_upstream" {
		return "client_to_upstream"
	}
	if stage == "read_upstream" || stage == "write_client" || stage == "drain_terminal" {
		return "upstream_to_client"
	}
	if stage == "idle_timeout" {
		return "watchdog"
	}
	return ""
}
// relayErrorString returns err.Error(), or the empty string for a nil error.
func relayErrorString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
// observeUpstreamMessage inspects one upstream JSON message, updates state and
// returns what was observed.
//
// Messages without a "type" field (and nil state or empty payloads) yield the
// zero value and leave state untouched. Otherwise the function records the
// connection-wide first-token delay, accumulates usage, tracks per-turn timing
// by response ID and, for terminal events, records the terminal type, the last
// response ID and the finished turn's duration and first-token delay.
func observeUpstreamMessage(
	state *relayState,
	message []byte,
	startAt time.Time,
	nowFn func() time.Time,
	onUsageParseFailure func(eventType string, usageRaw string),
) observedUpstreamEvent {
	if state == nil || len(message) == 0 {
		return observedUpstreamEvent{}
	}

	fields := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
	eventType := strings.TrimSpace(fields[0].String())
	if eventType == "" {
		return observedUpstreamEvent{}
	}
	terminal := isTerminalEvent(eventType)
	tokenEvent := isTokenEvent(eventType)

	responseID := strings.TrimSpace(fields[1].String())
	if responseID == "" {
		responseID = strings.TrimSpace(fields[2].String())
	}
	// Only terminal events fall back to the top-level id, so that an event_id is
	// never mistaken for a response_id and associated with a turn.
	if responseID == "" && terminal {
		responseID = strings.TrimSpace(fields[3].String())
	}

	now := nowFn()
	if tokenEvent && state.firstTokenMs == nil {
		if elapsed := int(now.Sub(startAt).Milliseconds()); elapsed >= 0 {
			state.firstTokenMs = &elapsed
		}
	}

	observed := observedUpstreamEvent{
		eventType:  eventType,
		responseID: responseID,
		usage:      parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure),
	}

	hasResponseID := responseID != ""
	if hasResponseID {
		timing := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
		if tokenEvent && timing != nil && timing.firstTokenMs == nil {
			if elapsed := int(now.Sub(timing.startAt).Milliseconds()); elapsed >= 0 {
				timing.firstTokenMs = &elapsed
			}
		}
	}

	if !terminal {
		return observed
	}

	observed.terminal = true
	state.terminalEventType = eventType
	if !hasResponseID {
		return observed
	}

	state.lastResponseID = responseID
	finished, found := openAIWSRelayDeleteTurnTiming(state, responseID)
	if found {
		turnDuration := now.Sub(finished.startAt)
		if turnDuration < 0 {
			turnDuration = 0
		}
		observed.duration = turnDuration
		observed.firstToken = openAIWSRelayCloneIntPtr(finished.firstTokenMs)
	}
	return observed
}
// emitTurnComplete notifies onTurnComplete about a finished turn.
//
// Nothing is emitted when the callback is nil, the observed event is not
// terminal, or the observed response ID is blank after trimming. The model is
// taken from state when state is non-nil and is otherwise left empty.
func emitTurnComplete(
	onTurnComplete func(turn RelayTurnResult),
	state *relayState,
	observed observedUpstreamEvent,
) {
	if onTurnComplete == nil {
		return
	}
	if !observed.terminal {
		return
	}
	turnID := strings.TrimSpace(observed.responseID)
	if turnID == "" {
		return
	}

	turn := RelayTurnResult{
		Usage:             observed.usage,
		RequestID:         turnID,
		TerminalEventType: observed.eventType,
		Duration:          observed.duration,
		FirstTokenMs:      openAIWSRelayCloneIntPtr(observed.firstToken),
	}
	if state != nil {
		turn.RequestModel = state.requestModel
	}
	onTurnComplete(turn)
}
// openAIWSRelayGetOrInitTurnTiming returns the timing entry for responseID,
// creating one that starts at now when the entry is missing, nil, or has a
// zero start time. It returns nil only when state is nil.
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
	if state == nil {
		return nil
	}
	if state.turnTimingByID == nil {
		state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
	}

	if existing := state.turnTimingByID[responseID]; existing != nil && !existing.startAt.IsZero() {
		return existing
	}

	fresh := &relayTurnTiming{startAt: now}
	state.turnTimingByID[responseID] = fresh
	return fresh
}
// openAIWSRelayDeleteTurnTiming removes the timing entry for responseID and
// returns a copy of it. The boolean is false, and nothing is removed, when
// state or its map is nil or when no non-nil entry exists for responseID.
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
	if state == nil || state.turnTimingByID == nil {
		return relayTurnTiming{}, false
	}

	entry := state.turnTimingByID[responseID]
	if entry == nil {
		return relayTurnTiming{}, false
	}

	removed := *entry
	delete(state.turnTimingByID, responseID)
	return removed, true
}
// openAIWSRelayCloneIntPtr returns a pointer to a fresh copy of *v, so that
// the caller's value can no longer be modified through the result. A nil
// input yields nil.
func openAIWSRelayCloneIntPtr(v *int) *int {
	if v != nil {
		copied := new(int)
		*copied = *v
		return copied
	}
	return nil
}
// parseUsageAndAccumulate extracts response.usage from message, adds it to
// state.usage and returns the usage parsed from this message alone.
//
// It returns the zero Usage without touching state when state is nil, message
// is empty, eventType is not one that carries usage (see shouldParseUsage), or
// the message has no response.usage field. When usage is present but is not a
// JSON object, or a required field is missing or non-numeric, the failure is
// counted, onParseFailure (if set) receives the event type and raw usage JSON,
// and nothing is accumulated.
func parseUsageAndAccumulate(
	state *relayState,
	message []byte,
	eventType string,
	onParseFailure func(eventType string, usageRaw string),
) Usage {
	if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
		return Usage{}
	}
	usageResult := gjson.GetBytes(message, "response.usage")
	if !usageResult.Exists() {
		return Usage{}
	}
	usageRaw := strings.TrimSpace(usageResult.Raw)

	// reportFailure records the failure once and yields the zero Usage.
	reportFailure := func() Usage {
		recordUsageParseFailure()
		if onParseFailure != nil {
			onParseFailure(eventType, usageRaw)
		}
		return Usage{}
	}

	// An empty raw value also fails this prefix test.
	if !strings.HasPrefix(usageRaw, "{") {
		return reportFailure()
	}

	// Query the already-located usage object instead of re-scanning the whole
	// message three more times with absolute paths.
	inputTokens, inputOK := parseUsageIntField(usageResult.Get("input_tokens"), true)
	outputTokens, outputOK := parseUsageIntField(usageResult.Get("output_tokens"), true)
	cachedTokens, cachedOK := parseUsageIntField(usageResult.Get("input_tokens_details.cached_tokens"), false)
	if !inputOK || !outputOK || !cachedOK {
		// On failure no field is accumulated, so billing usage never ends up in
		// a "half valid" state.
		return reportFailure()
	}

	parsedUsage := Usage{
		InputTokens:          inputTokens,
		OutputTokens:         outputTokens,
		CacheReadInputTokens: cachedTokens,
	}
	state.usage.InputTokens += parsedUsage.InputTokens
	state.usage.OutputTokens += parsedUsage.OutputTokens
	state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
	return parsedUsage
}
// parseUsageIntField converts a gjson value to an int.
//
// A missing value yields (0, true) when the field is optional and (0, false)
// when it is required. A value that exists but is not a JSON number yields
// (0, false).
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
	switch {
	case !value.Exists():
		return 0, !required
	case value.Type == gjson.Number:
		return int(value.Int()), true
	default:
		return 0, false
	}
}
// enrichResult copies the duration and, when state is non-nil, the observed
// model, usage, last response ID, terminal event type and first-token delay
// into result. A nil result is ignored.
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
	if result == nil {
		return
	}
	result.Duration = duration
	if state != nil {
		result.RequestModel, result.Usage = state.requestModel, state.usage
		result.RequestID, result.TerminalEventType = state.lastResponseID, state.terminalEventType
		result.FirstTokenMs = state.firstTokenMs
	}
}
// isDisconnectError reports whether err looks like an ordinary connection
// shutdown rather than a real failure.
//
// It accepts io.EOF, net.ErrClosed and context.Canceled (including wrapped
// forms), WebSocket close statuses NormalClosure, GoingAway, NoStatusRcvd and
// AbnormalClosure, and, as a last resort, a set of well-known substrings in
// the lower-cased error text.
func isDisconnectError(err error) bool {
	if err == nil {
		return false
	}

	for _, sentinel := range [...]error{io.EOF, net.ErrClosed, context.Canceled} {
		if errors.Is(err, sentinel) {
			return true
		}
	}

	switch coderws.CloseStatus(err) {
	case coderws.StatusNormalClosure,
		coderws.StatusGoingAway,
		coderws.StatusNoStatusRcvd,
		coderws.StatusAbnormalClosure:
		return true
	}

	// Fallback for errors that only expose the condition through their text.
	text := strings.ToLower(strings.TrimSpace(err.Error()))
	for _, marker := range [...]string{
		"failed to read frame header: eof",
		"unexpected eof",
		"use of closed network connection",
		"connection reset by peer",
		"broken pipe",
	} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// isTerminalEvent reports whether eventType ends a response: completed, done,
// failed, incomplete, or cancelled (either spelling).
func isTerminalEvent(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed" ||
		eventType == "response.incomplete" ||
		eventType == "response.cancelled" ||
		eventType == "response.canceled"
}
// shouldParseUsage reports whether usage is parsed for eventType; that is the
// case only for response.completed, response.done and response.failed.
func shouldParseUsage(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed"
}
// isTokenEvent reports whether eventType counts as carrying output for the
// purpose of first-token timing.
//
// Lifecycle events (created, in_progress, output_item.added/done) and the
// empty string never count. Otherwise any type containing ".delta" or starting
// with "response.output" counts, as do response.completed and response.done.
func isTokenEvent(eventType string) bool {
	switch eventType {
	case "":
		return false
	case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
		return false
	case "response.completed", "response.done":
		return true
	}
	// The "response.output" prefix also covers every "response.output_text..." type.
	return strings.Contains(eventType, ".delta") || strings.HasPrefix(eventType, "response.output")
}
// minDuration returns the smaller of a and b, treating a non-positive value as
// "unset": when a is non-positive the result is b, and otherwise when b is
// non-positive the result is a.
func minDuration(a, b time.Duration) time.Duration {
	switch {
	case a <= 0:
		return b
	case b <= 0:
		return a
	case a < b:
		return a
	default:
		return b
	}
}
// waitRelayExit waits for the next exit signal on exitCh for at most timeout.
// A non-positive timeout is replaced by 200ms. The boolean is false when the
// wait timed out, in which case the returned signal is the zero value.
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
	if timeout <= 0 {
		timeout = 200 * time.Millisecond
	}

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	select {
	case signal := <-exitCh:
		return signal, true
	case <-deadline.C:
		return relayExitSignal{}, false
	}
}

View File

@@ -0,0 +1,432 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestRunEntry_DelegatesRelay checks that RunEntry runs the relay end to end:
// a single terminal upstream event must produce a nil exit and surface the
// response ID of that event as the result's RequestID.
func TestRunEntry_DelegatesRelay(t *testing.T) {
t.Parallel()
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true)
result, relayExit := RunEntry(EntryInput{
Ctx: context.Background(),
ClientConn: clientConn,
UpstreamConn: upstreamConn,
FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
})
require.Nil(t, relayExit)
require.Equal(t, "resp_entry", result.RequestID)
}
// TestRunClientToUpstream_ErrorPaths exercises the exit signals of
// runClientToUpstream: a graceful client read failure, a failing upstream
// write, and the forwarded-frame counter together with the trace callback.
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
t.Parallel()
// The client read fails right away and must be reported as a graceful read_client exit.
t.Run("read client eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.True(t, sig.graceful)
})
// A failing upstream write must end the loop with a non-graceful write_upstream exit.
t.Run("write upstream failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
func() {},
nil,
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_upstream", sig.stage)
require.False(t, sig.graceful)
})
// One frame is forwarded before the read fails: the counter must be 1 and at
// least one trace event must have been emitted.
t.Run("forwarded counter and trace callback", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
forwarded := &atomic.Int64{}
traces := make([]RelayTraceEvent, 0, 2)
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
forwarded,
func(event RelayTraceEvent) {
traces = append(traces, event)
},
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_client", sig.stage)
require.Equal(t, int64(1), forwarded.Load())
require.NotEmpty(t, traces)
})
}
// TestRunUpstreamToClient_ErrorAndDropPaths exercises the exit signals of
// runUpstreamToClient: a graceful upstream read failure, a failing client
// write, and the drain mode that drops frames and stops on a terminal event.
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
t.Parallel()
// The upstream read fails right away and must be reported as a graceful read_upstream exit.
t.Run("read upstream eof", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn(nil, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "read_upstream", sig.stage)
require.True(t, sig.graceful)
})
// A failing client write must end the loop with a write_client exit.
t.Run("write client failed", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(false)
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
}, true),
func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
nil,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "write_client", sig.stage)
})
// With the drop flag set, a terminal event is counted as dropped and ends the
// loop with a graceful drain_terminal exit.
t.Run("drop downstream and stop on terminal", func(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
drop := &atomic.Bool{}
drop.Store(true)
dropped := &atomic.Int64{}
runUpstreamToClient(
context.Background(),
newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}, true),
func(_ coderws.MessageType, _ []byte) error { return nil },
time.Now(),
time.Now,
&relayState{},
nil,
nil,
drop,
nil,
dropped,
func() {},
nil,
exitCh,
)
sig := <-exitCh
require.Equal(t, "drain_terminal", sig.stage)
require.True(t, sig.graceful)
require.Equal(t, int64(1), dropped.Load())
})
}
// TestRunIdleWatchdog_NoTimeoutWhenDisabled checks that a zero idle timeout
// disables the watchdog: no exit signal may arrive within 200ms.
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
t.Parallel()
exitCh := make(chan relayExitSignal, 1)
lastActivity := &atomic.Int64{}
lastActivity.Store(time.Now().UnixNano())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
select {
case <-exitCh:
t.Fatal("unexpected idle timeout signal")
case <-time.After(200 * time.Millisecond):
}
}
// TestHelperFunctionsCoverage covers the small pure helpers of the relay:
// message type and error rendering, disconnect detection, token event
// classification, minDuration, waitRelayExit and parseUsageIntField.
func TestHelperFunctionsCoverage(t *testing.T) {
t.Parallel()
// Message type and error string rendering.
require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")
require.Equal(t, "", relayErrorString(nil))
require.Equal(t, "x", relayErrorString(errors.New("x")))
// Disconnect detection via sentinel errors, close status and message text.
require.True(t, isDisconnectError(io.EOF))
require.True(t, isDisconnectError(net.ErrClosed))
require.True(t, isDisconnectError(context.Canceled))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
require.True(t, isDisconnectError(errors.New("broken pipe")))
require.False(t, isDisconnectError(errors.New("unrelated")))
// Token event classification.
require.True(t, isTokenEvent("response.output_text.delta"))
require.True(t, isTokenEvent("response.output_audio.delta"))
require.True(t, isTokenEvent("response.completed"))
require.False(t, isTokenEvent(""))
require.False(t, isTokenEvent("response.created"))
// minDuration treats a zero argument as unset.
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))
// waitRelayExit returns a queued signal, applies its default for a zero
// timeout, and reports false once the channel is empty.
ch := make(chan relayExitSignal, 1)
ch <- relayExitSignal{stage: "ok"}
sig, ok := waitRelayExit(ch, 10*time.Millisecond)
require.True(t, ok)
require.Equal(t, "ok", sig.stage)
ch <- relayExitSignal{stage: "ok2"}
sig, ok = waitRelayExit(ch, 0)
require.True(t, ok)
require.Equal(t, "ok2", sig.stage)
_, ok = waitRelayExit(ch, 10*time.Millisecond)
require.False(t, ok)
// parseUsageIntField: numeric, non-numeric, missing optional, missing required.
n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
require.True(t, ok)
require.Equal(t, 3, n)
_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
require.False(t, ok)
n, ok = parseUsageIntField(gjson.Result{}, false)
require.True(t, ok)
require.Equal(t, 0, n)
_, ok = parseUsageIntField(gjson.Result{}, true)
require.False(t, ok)
}
// TestParseUsageAndEnrichCoverage checks that parseUsageAndAccumulate only
// accumulates fully valid usage payloads and that enrichResult copies the
// accumulated state into the result.
func TestParseUsageAndEnrichCoverage(t *testing.T) {
t.Parallel()
state := &relayState{}
// A non-numeric input_tokens must not be accumulated.
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
require.Equal(t, 0, state.usage.InputTokens)
// One invalid field must prevent accumulation of the valid ones as well.
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
// Missing required fields must prevent accumulation of the optional one.
parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
"response.completed",
nil,
)
require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage")
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
// A fully valid payload is accumulated.
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
require.Equal(t, 2, state.usage.InputTokens)
require.Equal(t, 1, state.usage.OutputTokens)
require.Equal(t, 1, state.usage.CacheReadInputTokens)
result := &RelayResult{}
enrichResult(result, state, 5*time.Millisecond)
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
require.Equal(t, 5*time.Millisecond, result.Duration)
// Usage on an event type that is not parsed for usage is ignored.
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
require.Equal(t, 2, state.usage.InputTokens)
// A nil result must be tolerated.
enrichResult(nil, state, 0)
}
// TestEmitTurnCompleteCoverage checks when emitTurnComplete invokes its
// callback and which values it passes.
func TestEmitTurnCompleteCoverage(t *testing.T) {
t.Parallel()
// A non-terminal event must not trigger the callback.
called := 0
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: false,
eventType: "response.output_text.delta",
responseID: "resp_ignored",
usage: Usage{InputTokens: 1},
})
require.Equal(t, 0, called)
// A missing response_id must not trigger the callback.
emitTurnComplete(func(turn RelayTurnResult) {
called++
}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
})
require.Equal(t, 0, called)
// A terminal event with a response_id must trigger the callback; with a nil
// state the model is the empty string.
var got RelayTurnResult
emitTurnComplete(func(turn RelayTurnResult) {
called++
got = turn
}, nil, observedUpstreamEvent{
terminal: true,
eventType: "response.completed",
responseID: "resp_emit",
usage: Usage{InputTokens: 2, OutputTokens: 3},
})
require.Equal(t, 1, called)
require.Equal(t, "resp_emit", got.RequestID)
require.Equal(t, "response.completed", got.TerminalEventType)
require.Equal(t, 2, got.Usage.InputTokens)
require.Equal(t, 3, got.Usage.OutputTokens)
require.Equal(t, "", got.RequestModel)
}
func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
t.Parallel()
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
require.True(t, isDisconnectError(errors.New("connection reset by peer")))
require.False(t, isDisconnectError(errors.New(" ")))
}
// TestIsTokenEventCoverageBranches pins the classification isTokenEvent gives
// to a handful of representative event type names.
func TestIsTokenEventCoverageBranches(t *testing.T) {
	t.Parallel()

	expectations := []struct {
		eventType string
		want      bool
	}{
		{eventType: "response.in_progress", want: false},
		{eventType: "response.output_item.added", want: false},
		{eventType: "response.output_audio.delta", want: true},
		{eventType: "response.output", want: true},
		{eventType: "response.done", want: true},
	}
	for _, tc := range expectations {
		require.Equal(t, tc.want, isTokenEvent(tc.eventType), tc.eventType)
	}
}
// TestRelayTurnTimingHelpersCoverage exercises the per-turn timing helpers:
// nil-state handling, get-or-init returning the existing entry on a second
// call, and delete reporting whether the key was present.
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
	t.Parallel()

	base := time.Unix(100, 0)

	// A nil state yields no timing and nothing to delete.
	require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", base))
	_, found := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
	require.False(t, found)

	st := &relayState{}

	first := openAIWSRelayGetOrInitTurnTiming(st, "resp_a", base)
	require.NotNil(t, first)
	require.Equal(t, base, first.startAt)

	// A second lookup returns the same entry, so startAt keeps its first value.
	second := openAIWSRelayGetOrInitTurnTiming(st, "resp_a", base.Add(5*time.Second))
	require.NotNil(t, second)
	require.Equal(t, base, second.startAt)

	// Deleting an existing key returns the stored timing.
	removed, found := openAIWSRelayDeleteTurnTiming(st, "resp_a")
	require.True(t, found)
	require.Equal(t, base, removed.startAt)

	// Deleting it again reports that the key is gone.
	_, found = openAIWSRelayDeleteTurnTiming(st, "resp_a")
	require.False(t, found)
}
// TestObserveUpstreamMessage_ResponseIDFallbackPolicy verifies that the
// top-level "id" field is used as the response ID only for terminal events.
func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
	t.Parallel()

	st := &relayState{requestModel: "gpt-5"}
	epoch := time.Unix(0, 0)

	// Fake clock: every call advances by 5ms.
	clock := epoch
	tick := func() time.Time {
		clock = clock.Add(5 * time.Millisecond)
		return clock
	}

	// Non-terminal event with only a top-level id: the event id must not be
	// mistaken for a response_id.
	deltaEvent := []byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`)
	got := observeUpstreamMessage(st, deltaEvent, epoch, tick, nil)
	require.False(t, got.terminal)
	require.Equal(t, "", got.responseID)

	// Terminal event: falling back to the top-level id is allowed, to stay
	// compatible with a few field variants.
	completedEvent := []byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`)
	got = observeUpstreamMessage(st, completedEvent, epoch, tick, nil)
	require.True(t, got.terminal)
	require.Equal(t, "resp_fallback", got.responseID)
}

View File

@@ -0,0 +1,752 @@
package openai_ws_v2
import (
"context"
"errors"
"io"
"sync"
"sync/atomic"
"testing"
"time"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
)
// passthroughTestFrame is one WebSocket frame (message type plus payload) as
// queued for reading or recorded on write by the test connections below.
type passthroughTestFrame struct {
	msgType coderws.MessageType
	payload []byte
}
// passthroughTestFrameConn is an in-memory connection for tests: reads are
// served from readCh and every written frame is recorded in writes.
type passthroughTestFrameConn struct {
	mu     sync.Mutex              // guards writes
	writes []passthroughTestFrame  // frames recorded by WriteFrame
	readCh chan passthroughTestFrame // frames handed out by ReadFrame; a closed channel reads as io.EOF
	once   sync.Once               // makes Close close readCh at most once
}
// delayedReadFrameConn wraps another FrameConn and delays only its first
// ReadFrame call by firstDelay.
type delayedReadFrameConn struct {
	base       FrameConn     // wrapped connection; all calls are delegated to it
	firstDelay time.Duration // wait applied before the first read; values <= 0 disable it
	once       sync.Once     // ensures the delay is applied a single time
}
// closeSpyFrameConn is a test connection that counts how many times Close is
// called.
type closeSpyFrameConn struct {
	closeCalls atomic.Int32 // number of Close invocations
}
// newPassthroughTestFrameConn builds a test connection whose read queue is
// pre-loaded with copies of frames. When autoClose is true the queue is closed
// right away, so ReadFrame returns io.EOF once the queued frames are consumed.
func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn {
	// One spare slot beyond the pre-loaded frames.
	queue := make(chan passthroughTestFrame, len(frames)+1)
	for i := range frames {
		queue <- passthroughTestFrame{
			msgType: frames[i].msgType,
			// Copy so later mutation of the caller's slice cannot affect the queue.
			payload: append([]byte(nil), frames[i].payload...),
		}
	}
	if autoClose {
		close(queue)
	}
	return &passthroughTestFrameConn{readCh: queue}
}
// ReadFrame returns the next queued frame (with a copied payload), io.EOF once
// the queue has been closed and drained, or ctx.Err() if ctx is done first.
// A nil ctx is treated as context.Background().
func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	select {
	case next, open := <-c.readCh:
		if open {
			return next.msgType, append([]byte(nil), next.payload...), nil
		}
		return coderws.MessageText, nil, io.EOF
	case <-ctx.Done():
		return coderws.MessageText, nil, ctx.Err()
	}
}
// WriteFrame records a copy of the frame, or returns ctx.Err() without
// recording anything when ctx is already done. A nil ctx is treated as
// context.Background().
func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	recorded := passthroughTestFrame{
		msgType: msgType,
		payload: append([]byte(nil), payload...),
	}
	c.mu.Lock()
	c.writes = append(c.writes, recorded)
	c.mu.Unlock()
	return nil
}
// Close closes the read queue so pending and future reads observe io.EOF.
// It always returns nil and is safe to call more than once.
func (c *passthroughTestFrameConn) Close() error {
	c.once.Do(func() {
		// readCh is already closed when the conn was built with autoClose=true;
		// closing a closed channel panics, so that panic is swallowed here.
		defer func() { _ = recover() }()
		close(c.readCh)
	})
	return nil
}
// Writes returns a snapshot of the frames recorded so far. The returned slice
// is a fresh copy of the internal slice (payload bytes are not deep-copied).
func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame {
	c.mu.Lock()
	snapshot := make([]passthroughTestFrame, len(c.writes))
	copy(snapshot, c.writes)
	c.mu.Unlock()
	return snapshot
}
// ReadFrame delegates to the wrapped connection. The very first call waits
// for firstDelay (or until ctx is done, whichever comes first) before
// delegating. A nil receiver or nil base reads as io.EOF.
func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if c == nil || c.base == nil {
		return coderws.MessageText, nil, io.EOF
	}
	c.once.Do(func() {
		if c.firstDelay <= 0 {
			return
		}
		delay := time.NewTimer(c.firstDelay)
		defer delay.Stop()
		select {
		case <-delay.C:
		case <-ctx.Done():
		}
	})
	return c.base.ReadFrame(ctx)
}
// WriteFrame forwards the frame to the wrapped connection; a nil receiver or
// nil base yields io.EOF.
func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if c != nil && c.base != nil {
		return c.base.WriteFrame(ctx, msgType, payload)
	}
	return io.EOF
}
// Close closes the wrapped connection; with a nil receiver or nil base it is
// a no-op that returns nil.
func (c *delayedReadFrameConn) Close() error {
	if c != nil && c.base != nil {
		return c.base.Close()
	}
	return nil
}
// ReadFrame never produces a frame: it blocks until ctx is done and then
// returns ctx.Err(). A nil ctx is replaced by context.Background(), which is
// never cancelled, so such a call blocks indefinitely.
func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	<-ctx.Done()
	return coderws.MessageText, nil, ctx.Err()
}
// WriteFrame discards the frame. It returns ctx.Err() when ctx is already
// done and nil otherwise. A nil ctx is treated as context.Background().
func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
	if ctx == nil {
		ctx = context.Background()
	}
	// Err is nil until Done is closed, so this mirrors a non-blocking Done check.
	return ctx.Err()
}
// Close increments the close counter and always returns nil. It is a no-op on
// a nil receiver.
func (c *closeSpyFrameConn) Close() error {
	if c == nil {
		return nil
	}
	c.closeCalls.Add(1)
	return nil
}
// CloseCalls reports how many times Close has been called; it is 0 for a nil
// receiver.
func (c *closeSpyFrameConn) CloseCalls() int32 {
	if c != nil {
		return c.closeCalls.Load()
	}
	return 0
}
// TestRelay_BasicRelayAndUsage relays one client request and one terminal
// upstream event, then checks the reported model, response ID, usage and frame
// counters, and that both frames arrived on the opposite side.
func TestRelay_BasicRelayAndUsage(t *testing.T) {
	t.Parallel()

	const completedEvent = `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`

	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{msgType: coderws.MessageText, payload: []byte(completedEvent)},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	res, exit := Relay(ctx, client, upstream, request, RelayOptions{})
	require.Nil(t, exit)

	// Aggregated result.
	require.Equal(t, "gpt-5.3-codex", res.RequestModel)
	require.Equal(t, "resp_123", res.RequestID)
	require.Equal(t, "response.completed", res.TerminalEventType)
	require.Equal(t, 7, res.Usage.InputTokens)
	require.Equal(t, 3, res.Usage.OutputTokens)
	require.Equal(t, 2, res.Usage.CacheReadInputTokens)
	require.NotNil(t, res.FirstTokenMs)
	require.Equal(t, int64(1), res.ClientToUpstreamFrames)
	require.Equal(t, int64(1), res.UpstreamToClientFrames)
	require.Equal(t, int64(0), res.DroppedDownstreamFrames)

	// The request reached the upstream side.
	toUpstream := upstream.Writes()
	require.Len(t, toUpstream, 1)
	require.Equal(t, coderws.MessageText, toUpstream[0].msgType)
	require.JSONEq(t, string(request), string(toUpstream[0].payload))

	// The terminal event reached the client side.
	toClient := client.Writes()
	require.Len(t, toClient, 1)
	require.Equal(t, coderws.MessageText, toClient[0].msgType)
	require.JSONEq(t, completedEvent, string(toClient[0].payload))
}
// TestRelay_FunctionCallOutputBytesPreserved checks that a first message
// containing a function_call_output item is written upstream byte-for-byte.
func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) {
	t.Parallel()

	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	_, exit := Relay(ctx, client, upstream, request, RelayOptions{})
	require.Nil(t, exit)

	toUpstream := upstream.Writes()
	require.Len(t, toUpstream, 1)
	require.Equal(t, coderws.MessageText, toUpstream[0].msgType)
	// Exact byte equality, not JSON equivalence.
	require.Equal(t, request, toUpstream[0].payload)
}
// TestRelay_UpstreamDisconnect expects an immediate upstream EOF to end the
// relay without a RelayExit.
func TestRelay_UpstreamDisconnect(t *testing.T) {
	t.Parallel()

	// The upstream read queue is closed from the start (EOF); the client sends
	// no further frames.
	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn(nil, true)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	res, exit := Relay(ctx, client, upstream, request, RelayOptions{})

	// An upstream EOF counts as a disconnect and is reported as graceful.
	require.Nil(t, exit, "上游 EOF 应被视为 graceful disconnect")
	require.Equal(t, "gpt-4o", res.RequestModel)
}
// TestRelay_ClientDisconnect expects an immediate client EOF to surface as a
// RelayExit with stage "client_disconnected".
func TestRelay_ClientDisconnect(t *testing.T) {
	t.Parallel()

	// The client read queue is closed from the start (EOF); the upstream read
	// stays blocked until its context is cancelled.
	client := newPassthroughTestFrameConn(nil, true)
	upstream := newPassthroughTestFrameConn(nil, false)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	res, exit := Relay(ctx, client, upstream, request, RelayOptions{})
	require.NotNil(t, exit, "客户端 EOF 应返回可观测的中断状态")
	require.Equal(t, "client_disconnected", exit.Stage)
	require.Equal(t, "gpt-4o", res.RequestModel)
}
// TestRelay_ClientDisconnect_DrainCapturesLateUsage checks that, after the
// client disconnects, a terminal upstream event arriving within
// UpstreamDrainTimeout still contributes its usage to the result and is
// counted as a dropped downstream frame.
func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) {
	t.Parallel()

	client := newPassthroughTestFrameConn(nil, true)
	// The terminal event only becomes readable after an 80ms delay.
	upstream := &delayedReadFrameConn{
		base: newPassthroughTestFrameConn([]passthroughTestFrame{
			{
				msgType: coderws.MessageText,
				payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`),
			},
		}, true),
		firstDelay: 80 * time.Millisecond,
	}
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	opts := RelayOptions{UpstreamDrainTimeout: 400 * time.Millisecond}
	res, exit := Relay(ctx, client, upstream, request, opts)

	require.NotNil(t, exit)
	require.Equal(t, "client_disconnected", exit.Stage)
	require.Equal(t, "resp_drain", res.RequestID)
	require.Equal(t, "response.completed", res.TerminalEventType)
	require.Equal(t, 6, res.Usage.InputTokens)
	require.Equal(t, 4, res.Usage.OutputTokens)
	require.Equal(t, 1, res.Usage.CacheReadInputTokens)
	require.Equal(t, int64(1), res.ClientToUpstreamFrames)
	require.Equal(t, int64(0), res.UpstreamToClientFrames)
	require.Equal(t, int64(1), res.DroppedDownstreamFrames)
}
// TestRelay_IdleTimeout expects the relay to exit with stage "idle_timeout"
// when neither side sends a frame and the injected clock jumps past
// IdleTimeout.
func TestRelay_IdleTimeout(t *testing.T) {
	t.Parallel()

	// Neither the client nor the upstream sends any frame.
	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(nil, false)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Fake clock used to fast-forward into the idle timeout. The call counter
	// is atomic so the clock stays race-free if the relay calls Now from more
	// than one goroutine.
	now := time.Now()
	var nowCalls atomic.Int64
	nowFn := func() time.Time {
		// The first few calls (initialization) see the real start time; every
		// later call sees a time one hour ahead.
		if nowCalls.Add(1) <= 5 {
			return now
		}
		return now.Add(time.Hour)
	}

	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
		IdleTimeout: 2 * time.Second,
		Now:         nowFn,
	})
	require.NotNil(t, relayExit, "应因 idle timeout 退出")
	require.Equal(t, "idle_timeout", relayExit.Stage)
	require.Equal(t, "gpt-4o", result.RequestModel)
}
// TestRelay_IdleTimeoutDoesNotCloseClientOnError checks that an idle-timeout
// exit closes the upstream connection at least once but never closes the
// client connection.
func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) {
	t.Parallel()

	clientConn := &closeSpyFrameConn{}
	upstreamConn := &closeSpyFrameConn{}
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Fake clock: the first five calls return the start time, later calls jump
	// an hour ahead. The counter is atomic so the clock stays race-free if the
	// relay calls Now from more than one goroutine.
	now := time.Now()
	var nowCalls atomic.Int64
	nowFn := func() time.Time {
		if nowCalls.Add(1) <= 5 {
			return now
		}
		return now.Add(time.Hour)
	}

	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
		IdleTimeout: 2 * time.Second,
		Now:         nowFn,
	})
	require.NotNil(t, relayExit, "应因 idle timeout 退出")
	require.Equal(t, "idle_timeout", relayExit.Stage)
	// On the error path the client connection is left open so the caller can
	// choose the close code.
	require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code")
	require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1))
}
func TestRelay_NilConnections(t *testing.T) {
t.Parallel()
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx := context.Background()
t.Run("nil client conn", func(t *testing.T) {
upstreamConn := newPassthroughTestFrameConn(nil, true)
_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "relay_init", relayExit.Stage)
require.Contains(t, relayExit.Err.Error(), "nil")
})
t.Run("nil upstream conn", func(t *testing.T) {
clientConn := newPassthroughTestFrameConn(nil, true)
_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "relay_init", relayExit.Stage)
require.Contains(t, relayExit.Err.Error(), "nil")
})
}
func TestRelay_MultipleUpstreamMessages(t *testing.T) {
t.Parallel()
// 上游发送多个事件delta + completed验证多帧中继和 usage 聚合
clientConn := newPassthroughTestFrameConn(nil, false)
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`),
},
{
msgType: coderws.MessageText,
payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`),
},
}, true)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
require.Nil(t, relayExit)
require.Equal(t, "resp_multi", result.RequestID)
require.Equal(t, "response.completed", result.TerminalEventType)
require.Equal(t, 10, result.Usage.InputTokens)
require.Equal(t, 5, result.Usage.OutputTokens)
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
require.NotNil(t, result.FirstTokenMs)
// 验证所有 3 个上游帧都转发给了客户端
clientWrites := clientConn.Writes()
require.Len(t, clientWrites, 3)
}
// TestRelay_OnTurnComplete_PerTerminalEvent checks that OnTurnComplete fires
// once per terminal event with that event's own usage, while the overall
// result carries the sum across both turns.
func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) {
	t.Parallel()

	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`),
		},
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`),
		},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	completed := make([]RelayTurnResult, 0, 2)
	opts := RelayOptions{
		OnTurnComplete: func(turn RelayTurnResult) {
			completed = append(completed, turn)
		},
	}
	res, exit := Relay(ctx, client, upstream, request, opts)
	require.Nil(t, exit)
	require.Len(t, completed, 2)

	first, second := completed[0], completed[1]
	require.Equal(t, "resp_turn_1", first.RequestID)
	require.Equal(t, "response.completed", first.TerminalEventType)
	require.Equal(t, 2, first.Usage.InputTokens)
	require.Equal(t, 1, first.Usage.OutputTokens)

	require.Equal(t, "resp_turn_2", second.RequestID)
	require.Equal(t, "response.failed", second.TerminalEventType)
	require.Equal(t, 3, second.Usage.InputTokens)
	require.Equal(t, 4, second.Usage.OutputTokens)

	// Session totals: 2+3 input tokens and 1+4 output tokens.
	require.Equal(t, 5, res.Usage.InputTokens)
	require.Equal(t, 5, res.Usage.OutputTokens)
}
// TestRelay_OnTurnComplete_ProvidesTurnMetrics checks that the per-turn
// callback and the overall result both carry a first-token latency and a
// positive duration when driven by a deterministic stepping clock.
func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) {
	t.Parallel()

	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`),
		},
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`),
		},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// Stepping clock: each call returns a time 5ms later than the previous one.
	epoch := time.Unix(0, 0)
	var ticks atomic.Int64
	steppingNow := func() time.Time {
		n := ticks.Add(1)
		return epoch.Add(time.Duration(n) * 5 * time.Millisecond)
	}

	var lastTurn RelayTurnResult
	res, exit := Relay(ctx, client, upstream, request, RelayOptions{
		Now: steppingNow,
		OnTurnComplete: func(current RelayTurnResult) {
			lastTurn = current
		},
	})
	require.Nil(t, exit)

	require.Equal(t, "resp_metric", lastTurn.RequestID)
	require.Equal(t, "response.completed", lastTurn.TerminalEventType)
	require.NotNil(t, lastTurn.FirstTokenMs)
	require.GreaterOrEqual(t, *lastTurn.FirstTokenMs, 0)
	require.Greater(t, lastTurn.Duration.Milliseconds(), int64(0))

	require.NotNil(t, res.FirstTokenMs)
	require.Greater(t, res.Duration.Milliseconds(), int64(0))
}
// TestRelay_BinaryFramePassthrough checks that a binary upstream frame is
// forwarded to the client unchanged and contributes no usage.
func TestRelay_BinaryFramePassthrough(t *testing.T) {
	t.Parallel()

	raw := []byte{0x00, 0x01, 0x02, 0x03}
	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{msgType: coderws.MessageBinary, payload: raw},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	res, exit := Relay(ctx, client, upstream, request, RelayOptions{})
	require.Nil(t, exit)

	// No usage is extracted from a binary frame.
	require.Equal(t, 0, res.Usage.InputTokens)

	toClient := client.Writes()
	require.Len(t, toClient, 1)
	require.Equal(t, coderws.MessageBinary, toClient[0].msgType)
	require.Equal(t, raw, toClient[0].payload)
}
// TestRelay_BinaryJSONFrameSkipsObservation checks that a binary frame is not
// inspected even when its payload is a terminal-event JSON document: no usage,
// response ID or terminal event type is recorded, yet the frame is forwarded.
func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) {
	t.Parallel()

	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{
			msgType: coderws.MessageBinary,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`),
		},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	res, exit := Relay(ctx, client, upstream, request, RelayOptions{})
	require.Nil(t, exit)
	require.Equal(t, 0, res.Usage.InputTokens)
	require.Equal(t, "", res.RequestID)
	require.Equal(t, "", res.TerminalEventType)

	toClient := client.Writes()
	require.Len(t, toClient, 1)
	require.Equal(t, coderws.MessageBinary, toClient[0].msgType)
}
// TestRelay_UpstreamErrorEventPassthroughRaw checks that an upstream "error"
// event reaches the client byte-for-byte and does not produce a RelayExit.
func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) {
	t.Parallel()

	upstreamError := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`)
	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{msgType: coderws.MessageText, payload: upstreamError},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	_, exit := Relay(ctx, client, upstream, request, RelayOptions{})
	require.Nil(t, exit)

	toClient := client.Writes()
	require.Len(t, toClient, 1)
	require.Equal(t, coderws.MessageText, toClient[0].msgType)
	require.Equal(t, upstreamError, toClient[0].payload)
}
// TestRelay_PreservesFirstMessageType checks that the first client message is
// written upstream with the message type given in RelayOptions.FirstMessageType.
func TestRelay_PreservesFirstMessageType(t *testing.T) {
	t.Parallel()

	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn(nil, true)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	opts := RelayOptions{FirstMessageType: coderws.MessageBinary}
	_, exit := Relay(ctx, client, upstream, request, opts)
	require.Nil(t, exit)

	toUpstream := upstream.Writes()
	require.Len(t, toUpstream, 1)
	require.Equal(t, coderws.MessageBinary, toUpstream[0].msgType)
	require.Equal(t, request, toUpstream[0].payload)
}
// TestRelay_UsageParseFailureDoesNotBlockRelay checks that a terminal event
// whose "usage" field cannot be parsed is still forwarded, leaves usage at
// zero, and bumps the usage-parse-failure metric.
//
// NOTE(review): unlike its siblings this test does not call t.Parallel();
// presumably because it compares a package-level metric before and after.
func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) {
	failuresBefore := SnapshotMetrics().UsageParseFailureTotal

	// The payload is well-formed JSON, but "usage" is a string rather than an
	// object; this must not interfere with passthrough.
	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`),
		},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	res, exit := Relay(ctx, client, upstream, request, RelayOptions{})
	require.Nil(t, exit)

	// Usage stays at zero, but the terminal event is still recognized.
	require.Equal(t, 0, res.Usage.InputTokens)
	require.Equal(t, "response.completed", res.TerminalEventType)

	// The frame is forwarded regardless.
	require.Len(t, client.Writes(), 1)

	require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, failuresBefore+1)
}
func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) {
t.Parallel()
// 上游连接立即关闭,首包写入失败
upstreamConn := newPassthroughTestFrameConn(nil, true)
_ = upstreamConn.Close()
// 覆盖 WriteFrame 使其返回错误
errConn := &errorOnWriteFrameConn{}
clientConn := newPassthroughTestFrameConn(nil, false)
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{})
require.NotNil(t, relayExit)
require.Equal(t, "write_upstream", relayExit.Stage)
}
// TestRelay_ContextCanceled expects Relay to report a RelayExit when it is
// started with an already-cancelled context.
func TestRelay_ContextCanceled(t *testing.T) {
	t.Parallel()

	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn(nil, false)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	// Cancel before Relay even starts.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	// With the context cancelled, writing the first message fails.
	_, exit := Relay(ctx, client, upstream, request, RelayOptions{})
	require.NotNil(t, exit)
}
// TestRelay_TraceEvents_ContainsLifecycleStages checks that a successful relay
// reports the expected lifecycle stages through OnTrace.
func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) {
	t.Parallel()

	client := newPassthroughTestFrameConn(nil, false)
	upstream := newPassthroughTestFrameConn([]passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}, true)
	request := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// Stages are collected under a mutex so the callback is safe to invoke
	// from any goroutine.
	var (
		traceMu sync.Mutex
		traced  = make([]string, 0, 8)
	)
	opts := RelayOptions{
		OnTrace: func(event RelayTraceEvent) {
			traceMu.Lock()
			traced = append(traced, event.Stage)
			traceMu.Unlock()
		},
	}
	_, exit := Relay(ctx, client, upstream, request, opts)
	require.Nil(t, exit)

	traceMu.Lock()
	seen := append([]string(nil), traced...)
	traceMu.Unlock()

	for _, stage := range []string{"relay_start", "write_first_message_ok", "first_exit", "relay_complete"} {
		require.Contains(t, seen, stage)
	}
}
// TestRelay_TraceEvents_IdleTimeout checks that an idle-timeout exit reports
// the "idle_timeout_triggered" and "relay_exit" stages through OnTrace.
func TestRelay_TraceEvents_IdleTimeout(t *testing.T) {
	t.Parallel()

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(nil, false)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Fake clock: the first five calls return the start time, later calls jump
	// an hour ahead. The counter is atomic so the clock stays race-free if the
	// relay calls Now from more than one goroutine.
	now := time.Now()
	var nowCalls atomic.Int64
	nowFn := func() time.Time {
		if nowCalls.Add(1) <= 5 {
			return now
		}
		return now.Add(time.Hour)
	}

	// Stages are collected under a mutex so the callback is safe to invoke
	// from any goroutine.
	stages := make([]string, 0, 8)
	var stagesMu sync.Mutex
	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
		IdleTimeout: 2 * time.Second,
		Now:         nowFn,
		OnTrace: func(event RelayTraceEvent) {
			stagesMu.Lock()
			stages = append(stages, event.Stage)
			stagesMu.Unlock()
		},
	})
	require.NotNil(t, relayExit)
	require.Equal(t, "idle_timeout", relayExit.Stage)

	stagesMu.Lock()
	capturedStages := append([]string(nil), stages...)
	stagesMu.Unlock()
	require.Contains(t, capturedStages, "idle_timeout_triggered")
	require.Contains(t, capturedStages, "relay_exit")
}
// errorOnWriteFrameConn is a FrameConn implementation whose writes always
// fail; it is used to test a failing first-message write.
type errorOnWriteFrameConn struct{}
// ReadFrame never yields a frame: it blocks until ctx is done and then
// returns ctx.Err(). ctx must be non-nil.
func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	<-ctx.Done()
	return coderws.MessageText, nil, ctx.Err()
}
// WriteFrame ignores its arguments and always returns a fixed
// "connection refused" error.
func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error {
	return errors.New("write failed: connection refused")
}
// Close is a no-op that always returns nil.
func (c *errorOnWriteFrameConn) Close() error {
	return nil
}