新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层, 支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。 同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough, 并补充 i18n 文案与对应单测。 新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置, 用于宿主机场景下的反向代理与服务编排。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
753 lines
24 KiB
Go
753 lines
24 KiB
Go
package openai_ws_v2
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"io"
|
||
"sync"
|
||
"sync/atomic"
|
||
"testing"
|
||
"time"
|
||
|
||
coderws "github.com/coder/websocket"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
type passthroughTestFrame struct {
|
||
msgType coderws.MessageType
|
||
payload []byte
|
||
}
|
||
|
||
type passthroughTestFrameConn struct {
|
||
mu sync.Mutex
|
||
writes []passthroughTestFrame
|
||
readCh chan passthroughTestFrame
|
||
once sync.Once
|
||
}
|
||
|
||
type delayedReadFrameConn struct {
|
||
base FrameConn
|
||
firstDelay time.Duration
|
||
once sync.Once
|
||
}
|
||
|
||
type closeSpyFrameConn struct {
|
||
closeCalls atomic.Int32
|
||
}
|
||
|
||
func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn {
|
||
c := &passthroughTestFrameConn{
|
||
readCh: make(chan passthroughTestFrame, len(frames)+1),
|
||
}
|
||
for _, frame := range frames {
|
||
copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)}
|
||
c.readCh <- copied
|
||
}
|
||
if autoClose {
|
||
close(c.readCh)
|
||
}
|
||
return c
|
||
}
|
||
|
||
func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
select {
|
||
case <-ctx.Done():
|
||
return coderws.MessageText, nil, ctx.Err()
|
||
case frame, ok := <-c.readCh:
|
||
if !ok {
|
||
return coderws.MessageText, nil, io.EOF
|
||
}
|
||
return frame.msgType, append([]byte(nil), frame.payload...), nil
|
||
}
|
||
}
|
||
|
||
// WriteFrame records a copy of the frame. If ctx is already done, the write is
// rejected with ctx.Err() and nothing is recorded. A nil ctx is treated as
// context.Background().
func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	recorded := passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)}
	c.mu.Lock()
	c.writes = append(c.writes, recorded)
	c.mu.Unlock()
	return nil
}
|
||
|
||
func (c *passthroughTestFrameConn) Close() error {
|
||
c.once.Do(func() {
|
||
defer func() { _ = recover() }()
|
||
close(c.readCh)
|
||
})
|
||
return nil
|
||
}
|
||
|
||
func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
out := make([]passthroughTestFrame, len(c.writes))
|
||
copy(out, c.writes)
|
||
return out
|
||
}
|
||
|
||
// ReadFrame forwards to base.ReadFrame. The very first call first waits for
// firstDelay, or until ctx is done, whichever happens first.
//
// A nil receiver or nil base reads as io.EOF. A nil ctx is treated as
// context.Background(). Previously it panicked on ctx.Done() whenever
// firstDelay > 0.
func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if c == nil || c.base == nil {
		return coderws.MessageText, nil, io.EOF
	}
	if ctx == nil {
		ctx = context.Background()
	}

	c.once.Do(func() {
		if c.firstDelay <= 0 {
			return
		}
		timer := time.NewTimer(c.firstDelay)
		defer timer.Stop()
		select {
		case <-ctx.Done():
		case <-timer.C:
		}
	})
	return c.base.ReadFrame(ctx)
}
|
||
|
||
// WriteFrame forwards directly to base; writes are never delayed. A nil
// receiver or nil base reports io.EOF.
func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if c != nil && c.base != nil {
		return c.base.WriteFrame(ctx, msgType, payload)
	}
	return io.EOF
}
|
||
|
||
func (c *delayedReadFrameConn) Close() error {
|
||
if c == nil || c.base == nil {
|
||
return nil
|
||
}
|
||
return c.base.Close()
|
||
}
|
||
|
||
// ReadFrame never yields a frame: it blocks until ctx is done, then returns
// ctx.Err(). A nil ctx is treated as context.Background(), which blocks
// forever.
func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	waitCtx := ctx
	if waitCtx == nil {
		waitCtx = context.Background()
	}
	<-waitCtx.Done()
	return coderws.MessageText, nil, waitCtx.Err()
}
|
||
|
||
// WriteFrame discards the frame. It returns ctx.Err() if ctx is already done
// and nil otherwise; a nil ctx always succeeds.
func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
	if ctx == nil {
		return nil
	}
	return ctx.Err()
}
|
||
|
||
func (c *closeSpyFrameConn) Close() error {
|
||
if c != nil {
|
||
c.closeCalls.Add(1)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (c *closeSpyFrameConn) CloseCalls() int32 {
|
||
if c == nil {
|
||
return 0
|
||
}
|
||
return c.closeCalls.Load()
|
||
}
|
||
|
||
func TestRelay_BasicRelayAndUsage(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||
upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
|
||
{
|
||
msgType: coderws.MessageText,
|
||
payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`),
|
||
},
|
||
}, true)
|
||
|
||
firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`)
|
||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||
defer cancel()
|
||
|
||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
|
||
require.Nil(t, relayExit)
|
||
require.Equal(t, "gpt-5.3-codex", result.RequestModel)
|
||
require.Equal(t, "resp_123", result.RequestID)
|
||
require.Equal(t, "response.completed", result.TerminalEventType)
|
||
require.Equal(t, 7, result.Usage.InputTokens)
|
||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||
require.Equal(t, 2, result.Usage.CacheReadInputTokens)
|
||
require.NotNil(t, result.FirstTokenMs)
|
||
require.Equal(t, int64(1), result.ClientToUpstreamFrames)
|
||
require.Equal(t, int64(1), result.UpstreamToClientFrames)
|
||
require.Equal(t, int64(0), result.DroppedDownstreamFrames)
|
||
|
||
upstreamWrites := upstreamConn.Writes()
|
||
require.Len(t, upstreamWrites, 1)
|
||
require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
|
||
require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload))
|
||
|
||
clientWrites := clientConn.Writes()
|
||
require.Len(t, clientWrites, 1)
|
||
require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
|
||
require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload))
|
||
}
|
||
|
||
// TestRelay_FunctionCallOutputBytesPreserved checks that the first client
// message reaches upstream byte-for-byte. It uses require.Equal rather than
// JSONEq, so the escaped JSON inside a function_call_output item must not be
// re-encoded by the relay.
func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) {
	t.Parallel()

	firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`)
	completed := passthroughTestFrame{
		msgType: coderws.MessageText,
		payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`),
	}

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{completed}, true)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.Nil(t, relayExit)

	sent := upstreamConn.Writes()
	require.Len(t, sent, 1)
	require.Equal(t, coderws.MessageText, sent[0].msgType)
	require.Equal(t, firstPayload, sent[0].payload)
}
|
||
|
||
// TestRelay_UpstreamDisconnect: the upstream reports EOF immediately and the
// client sends no further frames. An upstream EOF counts as a graceful
// disconnect, so Relay returns no exit status and still reports the model.
func TestRelay_UpstreamDisconnect(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(nil, true) // closed up front -> io.EOF on read
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect")
	require.Equal(t, "gpt-4o", result.RequestModel)
}
|
||
|
||
// TestRelay_ClientDisconnect: the client reports EOF immediately while the
// upstream has nothing to send, so its reads block. A client EOF must surface
// as an observable exit with stage "client_disconnected".
func TestRelay_ClientDisconnect(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	clientConn := newPassthroughTestFrameConn(nil, true) // closed up front -> io.EOF on read
	upstreamConn := newPassthroughTestFrameConn(nil, false)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态")
	require.Equal(t, "client_disconnected", relayExit.Stage)
	require.Equal(t, "gpt-4o", result.RequestModel)
}
|
||
|
||
// TestRelay_ClientDisconnect_DrainCapturesLateUsage: the client disconnects
// immediately, and the upstream's response.completed event only becomes
// readable after an 80ms delay. With a 400ms UpstreamDrainTimeout the relay
// still reads that late event and records its usage. The frame cannot reach
// the client, so it is counted as dropped rather than forwarded.
func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) {
	t.Parallel()

	lateCompleted := passthroughTestFrame{
		msgType: coderws.MessageText,
		payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`),
	}
	clientConn := newPassthroughTestFrameConn(nil, true)
	upstreamConn := &delayedReadFrameConn{
		base:       newPassthroughTestFrameConn([]passthroughTestFrame{lateCompleted}, true),
		firstDelay: 80 * time.Millisecond,
	}
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	opts := RelayOptions{UpstreamDrainTimeout: 400 * time.Millisecond}
	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, opts)

	require.NotNil(t, relayExit)
	require.Equal(t, "client_disconnected", relayExit.Stage)

	// Usage captured during the drain phase.
	require.Equal(t, "resp_drain", result.RequestID)
	require.Equal(t, "response.completed", result.TerminalEventType)
	require.Equal(t, 6, result.Usage.InputTokens)
	require.Equal(t, 4, result.Usage.OutputTokens)
	require.Equal(t, 1, result.Usage.CacheReadInputTokens)

	// The late upstream frame was dropped instead of delivered downstream.
	require.Equal(t, int64(1), result.ClientToUpstreamFrames)
	require.Equal(t, int64(0), result.UpstreamToClientFrames)
	require.Equal(t, int64(1), result.DroppedDownstreamFrames)
}
|
||
|
||
func TestRelay_IdleTimeout(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// 客户端和上游都不发送帧,idle timeout 应触发
|
||
clientConn := newPassthroughTestFrameConn(nil, false)
|
||
upstreamConn := newPassthroughTestFrameConn(nil, false)
|
||
|
||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
// 使用快进时间来加速 idle timeout
|
||
now := time.Now()
|
||
callCount := 0
|
||
nowFn := func() time.Time {
|
||
callCount++
|
||
// 前几次调用返回正常时间(初始化阶段),之后快进
|
||
if callCount <= 5 {
|
||
return now
|
||
}
|
||
return now.Add(time.Hour) // 快进到超时
|
||
}
|
||
|
||
result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||
IdleTimeout: 2 * time.Second,
|
||
Now: nowFn,
|
||
})
|
||
require.NotNil(t, relayExit, "应因 idle timeout 退出")
|
||
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||
require.Equal(t, "gpt-4o", result.RequestModel)
|
||
}
|
||
|
||
func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
clientConn := &closeSpyFrameConn{}
|
||
upstreamConn := &closeSpyFrameConn{}
|
||
|
||
firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
now := time.Now()
|
||
callCount := 0
|
||
nowFn := func() time.Time {
|
||
callCount++
|
||
if callCount <= 5 {
|
||
return now
|
||
}
|
||
return now.Add(time.Hour)
|
||
}
|
||
|
||
_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
|
||
IdleTimeout: 2 * time.Second,
|
||
Now: nowFn,
|
||
})
|
||
require.NotNil(t, relayExit, "应因 idle timeout 退出")
|
||
require.Equal(t, "idle_timeout", relayExit.Stage)
|
||
require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code")
|
||
require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1))
|
||
}
|
||
|
||
// TestRelay_NilConnections verifies that Relay rejects a nil client or nil
// upstream connection at the "relay_init" stage, with an error that mentions
// "nil".
func TestRelay_NilConnections(t *testing.T) {
	t.Parallel()

	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
	ctx := context.Background()

	t.Run("nil client conn", func(t *testing.T) {
		upstreamConn := newPassthroughTestFrameConn(nil, true)
		_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
		require.NotNil(t, relayExit)
		require.Equal(t, "relay_init", relayExit.Stage)
		// Assert Err is set first, so a missing error fails the test cleanly
		// instead of panicking on a nil Err.Error() call.
		require.Error(t, relayExit.Err)
		require.Contains(t, relayExit.Err.Error(), "nil")
	})

	t.Run("nil upstream conn", func(t *testing.T) {
		clientConn := newPassthroughTestFrameConn(nil, true)
		_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
		require.NotNil(t, relayExit)
		require.Equal(t, "relay_init", relayExit.Stage)
		require.Error(t, relayExit.Err)
		require.Contains(t, relayExit.Err.Error(), "nil")
	})
}
|
||
|
||
// TestRelay_MultipleUpstreamMessages: the upstream sends two text deltas and
// then response.completed. All three frames must reach the client, and usage
// must come from the terminal event.
func TestRelay_MultipleUpstreamMessages(t *testing.T) {
	t.Parallel()

	events := []string{
		`{"type":"response.output_text.delta","delta":"Hello"}`,
		`{"type":"response.output_text.delta","delta":" world"}`,
		`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`,
	}
	frames := make([]passthroughTestFrame, 0, len(events))
	for _, ev := range events {
		frames = append(frames, passthroughTestFrame{msgType: coderws.MessageText, payload: []byte(ev)})
	}

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(frames, true)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.Nil(t, relayExit)
	require.Equal(t, "resp_multi", result.RequestID)
	require.Equal(t, "response.completed", result.TerminalEventType)
	require.Equal(t, 10, result.Usage.InputTokens)
	require.Equal(t, 5, result.Usage.OutputTokens)
	require.Equal(t, 3, result.Usage.CacheReadInputTokens)
	require.NotNil(t, result.FirstTokenMs)

	// Every upstream frame was forwarded to the client.
	require.Len(t, clientConn.Writes(), len(events))
}
|
||
|
||
// TestRelay_OnTurnComplete_PerTerminalEvent: each terminal event
// (response.completed, response.failed) triggers one OnTurnComplete callback
// carrying that turn's own ID, event type and usage. The overall result sums
// usage across turns.
func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) {
	t.Parallel()

	upstreamFrames := []passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`),
		},
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`),
		},
	}
	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(upstreamFrames, true)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	turns := make([]RelayTurnResult, 0, 2)
	opts := RelayOptions{
		OnTurnComplete: func(turn RelayTurnResult) { turns = append(turns, turn) },
	}
	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, opts)
	require.Nil(t, relayExit)
	require.Len(t, turns, 2)

	first, second := turns[0], turns[1]
	require.Equal(t, "resp_turn_1", first.RequestID)
	require.Equal(t, "response.completed", first.TerminalEventType)
	require.Equal(t, 2, first.Usage.InputTokens)
	require.Equal(t, 1, first.Usage.OutputTokens)

	require.Equal(t, "resp_turn_2", second.RequestID)
	require.Equal(t, "response.failed", second.TerminalEventType)
	require.Equal(t, 3, second.Usage.InputTokens)
	require.Equal(t, 4, second.Usage.OutputTokens)

	// Aggregate usage: 2+3 input, 1+4 output.
	require.Equal(t, 5, result.Usage.InputTokens)
	require.Equal(t, 5, result.Usage.OutputTokens)
}
|
||
|
||
// TestRelay_OnTurnComplete_ProvidesTurnMetrics: with a fake clock that
// advances 5ms on every call, the per-turn callback and the overall result
// both report a first-token latency and a positive duration.
func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) {
	t.Parallel()

	upstreamFrames := []passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`),
		},
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`),
		},
	}
	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(upstreamFrames, true)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// Monotonic fake clock; atomic so concurrent Now calls stay race-free.
	epoch := time.Unix(0, 0)
	var ticks atomic.Int64
	nowFn := func() time.Time {
		return epoch.Add(time.Duration(ticks.Add(1)) * 5 * time.Millisecond)
	}

	var turn RelayTurnResult
	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
		Now:            nowFn,
		OnTurnComplete: func(current RelayTurnResult) { turn = current },
	})
	require.Nil(t, relayExit)

	require.Equal(t, "resp_metric", turn.RequestID)
	require.Equal(t, "response.completed", turn.TerminalEventType)
	require.NotNil(t, turn.FirstTokenMs)
	require.GreaterOrEqual(t, *turn.FirstTokenMs, 0)
	require.Greater(t, turn.Duration.Milliseconds(), int64(0))

	require.NotNil(t, result.FirstTokenMs)
	require.Greater(t, result.Duration.Milliseconds(), int64(0))
}
|
||
|
||
// TestRelay_BinaryFramePassthrough: a binary upstream frame is forwarded to
// the client unchanged (type and bytes) and is not parsed for usage.
func TestRelay_BinaryFramePassthrough(t *testing.T) {
	t.Parallel()

	binaryPayload := []byte{0x00, 0x01, 0x02, 0x03}
	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(
		[]passthroughTestFrame{{msgType: coderws.MessageBinary, payload: binaryPayload}},
		true,
	)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.Nil(t, relayExit)
	require.Equal(t, 0, result.Usage.InputTokens) // binary frames are not inspected

	forwarded := clientConn.Writes()
	require.Len(t, forwarded, 1)
	require.Equal(t, coderws.MessageBinary, forwarded[0].msgType)
	require.Equal(t, binaryPayload, forwarded[0].payload)
}
|
||
|
||
// TestRelay_BinaryJSONFrameSkipsObservation: even when a binary frame carries
// a valid response.completed JSON body, the relay forwards it as binary and
// does not extract usage, request ID or terminal event type from it.
func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) {
	t.Parallel()

	binaryCompleted := passthroughTestFrame{
		msgType: coderws.MessageBinary,
		payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`),
	}
	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{binaryCompleted}, true)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.Nil(t, relayExit)
	require.Equal(t, 0, result.Usage.InputTokens)
	require.Equal(t, "", result.RequestID)
	require.Equal(t, "", result.TerminalEventType)

	forwarded := clientConn.Writes()
	require.Len(t, forwarded, 1)
	require.Equal(t, coderws.MessageBinary, forwarded[0].msgType)
}
|
||
|
||
// TestRelay_UpstreamErrorEventPassthroughRaw: an upstream "error" event is
// forwarded to the client byte-for-byte and does not by itself make Relay
// report an exit status.
func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) {
	t.Parallel()

	errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`)
	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(
		[]passthroughTestFrame{{msgType: coderws.MessageText, payload: errorEvent}},
		true,
	)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.Nil(t, relayExit)

	forwarded := clientConn.Writes()
	require.Len(t, forwarded, 1)
	require.Equal(t, coderws.MessageText, forwarded[0].msgType)
	require.Equal(t, errorEvent, forwarded[0].payload)
}
|
||
|
||
// TestRelay_PreservesFirstMessageType: RelayOptions.FirstMessageType controls
// the frame type of the first message sent upstream (binary here), and the
// payload is sent unchanged.
func TestRelay_PreservesFirstMessageType(t *testing.T) {
	t.Parallel()

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(nil, true)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	opts := RelayOptions{FirstMessageType: coderws.MessageBinary}
	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, opts)
	require.Nil(t, relayExit)

	sent := upstreamConn.Writes()
	require.Len(t, sent, 1)
	require.Equal(t, coderws.MessageBinary, sent[0].msgType)
	require.Equal(t, firstPayload, sent[0].payload)
}
|
||
|
||
// TestRelay_UsageParseFailureDoesNotBlockRelay: the upstream sends valid JSON
// whose "usage" field is a string instead of an object. Usage extraction fails
// (usage stays zero and the parse-failure metric is incremented), but the
// terminal event type is still recorded and the frame is still forwarded.
//
// NOTE(review): not marked t.Parallel, presumably because it reads the
// process-wide counters returned by SnapshotMetrics.
func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) {
	failuresBefore := SnapshotMetrics().UsageParseFailureTotal

	badUsage := passthroughTestFrame{
		msgType: coderws.MessageText,
		payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`),
	}
	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{badUsage}, true)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.Nil(t, relayExit)
	require.Equal(t, 0, result.Usage.InputTokens)
	require.Equal(t, "response.completed", result.TerminalEventType)

	// The frame was still relayed despite the parse failure.
	require.Len(t, clientConn.Writes(), 1)
	require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, failuresBefore+1)
}
|
||
|
||
// TestRelay_WriteUpstreamFirstMessageFails: when writing the client's first
// message upstream fails, Relay must exit at the "write_upstream" stage.
//
// The previous version also built an auto-closed passthroughTestFrameConn,
// never used it, and called Close on it, which double-closed its channel and
// only survived through a recovered panic. That dead setup is removed.
func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) {
	t.Parallel()

	// errorOnWriteFrameConn rejects every WriteFrame, so the very first write,
	// the client's initial message, fails.
	upstreamConn := &errorOnWriteFrameConn{}
	clientConn := newPassthroughTestFrameConn(nil, false)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.NotNil(t, relayExit)
	require.Equal(t, "write_upstream", relayExit.Stage)
}
|
||
|
||
// TestRelay_ContextCanceled: with a context cancelled before Relay starts,
// Relay must report a non-nil exit. Presumably the first-message write fails,
// since the test connection's WriteFrame rejects a done context.
func TestRelay_ContextCanceled(t *testing.T) {
	t.Parallel()

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(nil, false)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancelled before Relay runs

	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.NotNil(t, relayExit)
}
|
||
|
||
// TestRelay_TraceEvents_ContainsLifecycleStages: on a normal run, OnTrace
// receives the relay_start, write_first_message_ok, first_exit and
// relay_complete stages. Stages are collected under a mutex because OnTrace
// may be invoked from several goroutines.
func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) {
	t.Parallel()

	completed := passthroughTestFrame{
		msgType: coderws.MessageText,
		payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`),
	}
	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{completed}, true)
	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	var (
		stagesMu sync.Mutex
		stages   = make([]string, 0, 8)
	)
	recordStage := func(event RelayTraceEvent) {
		stagesMu.Lock()
		defer stagesMu.Unlock()
		stages = append(stages, event.Stage)
	}

	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{OnTrace: recordStage})
	require.Nil(t, relayExit)

	stagesMu.Lock()
	capturedStages := append([]string(nil), stages...)
	stagesMu.Unlock()

	for _, want := range []string{"relay_start", "write_first_message_ok", "first_exit", "relay_complete"} {
		require.Contains(t, capturedStages, want)
	}
}
|
||
|
||
// TestRelay_TraceEvents_IdleTimeout: when the relay ends on idle timeout,
// OnTrace must see both the idle_timeout_triggered and relay_exit stages.
func TestRelay_TraceEvents_IdleTimeout(t *testing.T) {
	t.Parallel()

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn(nil, false)

	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Fake clock: the first few calls return the start time, later calls jump
	// an hour ahead to trip the idle timeout. The counter is atomic because
	// Relay may call Now from several goroutines; a plain int would race.
	now := time.Now()
	var calls atomic.Int64
	nowFn := func() time.Time {
		if calls.Add(1) <= 5 {
			return now
		}
		return now.Add(time.Hour)
	}

	stages := make([]string, 0, 8)
	var stagesMu sync.Mutex
	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
		IdleTimeout: 2 * time.Second,
		Now:         nowFn,
		OnTrace: func(event RelayTraceEvent) {
			stagesMu.Lock()
			stages = append(stages, event.Stage)
			stagesMu.Unlock()
		},
	})
	require.NotNil(t, relayExit)
	require.Equal(t, "idle_timeout", relayExit.Stage)

	stagesMu.Lock()
	capturedStages := append([]string(nil), stages...)
	stagesMu.Unlock()

	require.Contains(t, capturedStages, "idle_timeout_triggered")
	require.Contains(t, capturedStages, "relay_exit")
}
|
||
|
||
// errorOnWriteFrameConn is a connection double whose WriteFrame always fails.
// Tests use it to exercise the path where writing the first message upstream
// fails.
type errorOnWriteFrameConn struct{}
|
||
|
||
// ReadFrame never yields a frame: it blocks until ctx is done and then returns
// ctx.Err(). ctx must be non-nil.
func (*errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	<-ctx.Done()
	err := ctx.Err()
	return coderws.MessageText, nil, err
}
|
||
|
||
// WriteFrame always fails with a fresh "connection refused" error, whatever
// the context, message type or payload.
func (*errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error {
	writeErr := errors.New("write failed: connection refused")
	return writeErr
}
|
||
|
||
func (c *errorOnWriteFrameConn) Close() error {
|
||
return nil
|
||
}
|