feat(openai-ws): 合并 WS v2 透传模式与前端 ws mode
新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层, 支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。 同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough, 并补充 i18n 文案与对应单测。 新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置, 用于宿主机场景下的反向代理与服务编排。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
807
backend/internal/service/openai_ws_v2/passthrough_relay.go
Normal file
807
backend/internal/service/openai_ws_v2/passthrough_relay.go
Normal file
@@ -0,0 +1,807 @@
|
||||
package openai_ws_v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type FrameConn interface {
|
||||
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
|
||||
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheCreationInputTokens int
|
||||
CacheReadInputTokens int
|
||||
}
|
||||
|
||||
type RelayResult struct {
|
||||
RequestModel string
|
||||
Usage Usage
|
||||
RequestID string
|
||||
TerminalEventType string
|
||||
FirstTokenMs *int
|
||||
Duration time.Duration
|
||||
ClientToUpstreamFrames int64
|
||||
UpstreamToClientFrames int64
|
||||
DroppedDownstreamFrames int64
|
||||
}
|
||||
|
||||
type RelayTurnResult struct {
|
||||
RequestModel string
|
||||
Usage Usage
|
||||
RequestID string
|
||||
TerminalEventType string
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
}
|
||||
|
||||
type RelayExit struct {
|
||||
Stage string
|
||||
Err error
|
||||
WroteDownstream bool
|
||||
}
|
||||
|
||||
type RelayOptions struct {
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
UpstreamDrainTimeout time.Duration
|
||||
FirstMessageType coderws.MessageType
|
||||
OnUsageParseFailure func(eventType string, usageRaw string)
|
||||
OnTurnComplete func(turn RelayTurnResult)
|
||||
OnTrace func(event RelayTraceEvent)
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
type RelayTraceEvent struct {
|
||||
Stage string
|
||||
Direction string
|
||||
MessageType string
|
||||
PayloadBytes int
|
||||
Graceful bool
|
||||
WroteDownstream bool
|
||||
Error string
|
||||
}
|
||||
|
||||
type relayState struct {
|
||||
usage Usage
|
||||
requestModel string
|
||||
lastResponseID string
|
||||
terminalEventType string
|
||||
firstTokenMs *int
|
||||
turnTimingByID map[string]*relayTurnTiming
|
||||
}
|
||||
|
||||
type relayExitSignal struct {
|
||||
stage string
|
||||
err error
|
||||
graceful bool
|
||||
wroteDownstream bool
|
||||
}
|
||||
|
||||
type observedUpstreamEvent struct {
|
||||
terminal bool
|
||||
eventType string
|
||||
responseID string
|
||||
usage Usage
|
||||
duration time.Duration
|
||||
firstToken *int
|
||||
}
|
||||
|
||||
type relayTurnTiming struct {
|
||||
startAt time.Time
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func Relay(
|
||||
ctx context.Context,
|
||||
clientConn FrameConn,
|
||||
upstreamConn FrameConn,
|
||||
firstClientMessage []byte,
|
||||
options RelayOptions,
|
||||
) (RelayResult, *RelayExit) {
|
||||
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
|
||||
if clientConn == nil || upstreamConn == nil {
|
||||
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
nowFn := options.Now
|
||||
if nowFn == nil {
|
||||
nowFn = time.Now
|
||||
}
|
||||
writeTimeout := options.WriteTimeout
|
||||
if writeTimeout <= 0 {
|
||||
writeTimeout = 2 * time.Minute
|
||||
}
|
||||
drainTimeout := options.UpstreamDrainTimeout
|
||||
if drainTimeout <= 0 {
|
||||
drainTimeout = 1200 * time.Millisecond
|
||||
}
|
||||
firstMessageType := options.FirstMessageType
|
||||
if firstMessageType != coderws.MessageBinary {
|
||||
firstMessageType = coderws.MessageText
|
||||
}
|
||||
startAt := nowFn()
|
||||
state := &relayState{requestModel: result.RequestModel}
|
||||
onTrace := options.OnTrace
|
||||
|
||||
relayCtx, relayCancel := context.WithCancel(ctx)
|
||||
defer relayCancel()
|
||||
|
||||
lastActivity := atomic.Int64{}
|
||||
lastActivity.Store(nowFn().UnixNano())
|
||||
markActivity := func() {
|
||||
lastActivity.Store(nowFn().UnixNano())
|
||||
}
|
||||
|
||||
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
|
||||
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||||
defer cancel()
|
||||
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
|
||||
}
|
||||
writeClient := func(msgType coderws.MessageType, payload []byte) error {
|
||||
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||||
defer cancel()
|
||||
return clientConn.WriteFrame(writeCtx, msgType, payload)
|
||||
}
|
||||
|
||||
clientToUpstreamFrames := &atomic.Int64{}
|
||||
upstreamToClientFrames := &atomic.Int64{}
|
||||
droppedDownstreamFrames := &atomic.Int64{}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_start",
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
})
|
||||
|
||||
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
||||
result.Duration = nowFn().Sub(startAt)
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_first_message_failed",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
Error: err.Error(),
|
||||
})
|
||||
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
||||
}
|
||||
clientToUpstreamFrames.Add(1)
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_first_message_ok",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
})
|
||||
markActivity()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 3)
|
||||
dropDownstreamWrites := atomic.Bool{}
|
||||
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
||||
go runUpstreamToClient(
|
||||
relayCtx,
|
||||
upstreamConn,
|
||||
writeClient,
|
||||
startAt,
|
||||
nowFn,
|
||||
state,
|
||||
options.OnUsageParseFailure,
|
||||
options.OnTurnComplete,
|
||||
&dropDownstreamWrites,
|
||||
upstreamToClientFrames,
|
||||
droppedDownstreamFrames,
|
||||
markActivity,
|
||||
onTrace,
|
||||
exitCh,
|
||||
)
|
||||
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
|
||||
|
||||
firstExit := <-exitCh
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "first_exit",
|
||||
Direction: relayDirectionFromStage(firstExit.stage),
|
||||
Graceful: firstExit.graceful,
|
||||
WroteDownstream: firstExit.wroteDownstream,
|
||||
Error: relayErrorString(firstExit.err),
|
||||
})
|
||||
combinedWroteDownstream := firstExit.wroteDownstream
|
||||
secondExit := relayExitSignal{graceful: true}
|
||||
hasSecondExit := false
|
||||
|
||||
// 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。
|
||||
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||
dropDownstreamWrites.Store(true)
|
||||
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
|
||||
} else {
|
||||
relayCancel()
|
||||
_ = upstreamConn.Close()
|
||||
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
||||
}
|
||||
if hasSecondExit {
|
||||
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "second_exit",
|
||||
Direction: relayDirectionFromStage(secondExit.stage),
|
||||
Graceful: secondExit.graceful,
|
||||
WroteDownstream: secondExit.wroteDownstream,
|
||||
Error: relayErrorString(secondExit.err),
|
||||
})
|
||||
}
|
||||
|
||||
relayCancel()
|
||||
_ = upstreamConn.Close()
|
||||
|
||||
enrichResult(&result, state, nowFn().Sub(startAt))
|
||||
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
||||
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
||||
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
||||
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||
stage := "client_disconnected"
|
||||
exitErr := firstExit.err
|
||||
if hasSecondExit && !secondExit.graceful {
|
||||
stage = secondExit.stage
|
||||
exitErr = secondExit.err
|
||||
}
|
||||
if exitErr == nil {
|
||||
exitErr = io.EOF
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_exit",
|
||||
Direction: relayDirectionFromStage(stage),
|
||||
Graceful: false,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
Error: relayErrorString(exitErr),
|
||||
})
|
||||
return result, &RelayExit{
|
||||
Stage: stage,
|
||||
Err: exitErr,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
}
|
||||
}
|
||||
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_complete",
|
||||
Graceful: true,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
})
|
||||
_ = clientConn.Close()
|
||||
return result, nil
|
||||
}
|
||||
if !firstExit.graceful {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_exit",
|
||||
Direction: relayDirectionFromStage(firstExit.stage),
|
||||
Graceful: false,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
Error: relayErrorString(firstExit.err),
|
||||
})
|
||||
return result, &RelayExit{
|
||||
Stage: firstExit.stage,
|
||||
Err: firstExit.err,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
}
|
||||
}
|
||||
if hasSecondExit && !secondExit.graceful {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_exit",
|
||||
Direction: relayDirectionFromStage(secondExit.stage),
|
||||
Graceful: false,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
Error: relayErrorString(secondExit.err),
|
||||
})
|
||||
return result, &RelayExit{
|
||||
Stage: secondExit.stage,
|
||||
Err: secondExit.err,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
}
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_complete",
|
||||
Graceful: true,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
})
|
||||
_ = clientConn.Close()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func runClientToUpstream(
|
||||
ctx context.Context,
|
||||
clientConn FrameConn,
|
||||
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
||||
markActivity func(),
|
||||
forwardedFrames *atomic.Int64,
|
||||
onTrace func(event RelayTraceEvent),
|
||||
exitCh chan<- relayExitSignal,
|
||||
) {
|
||||
for {
|
||||
msgType, payload, err := clientConn.ReadFrame(ctx)
|
||||
if err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "read_client_failed",
|
||||
Direction: "client_to_upstream",
|
||||
Error: err.Error(),
|
||||
Graceful: isDisconnectError(err),
|
||||
})
|
||||
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
|
||||
return
|
||||
}
|
||||
markActivity()
|
||||
if err := writeUpstream(msgType, payload); err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_upstream_failed",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(msgType),
|
||||
PayloadBytes: len(payload),
|
||||
Error: err.Error(),
|
||||
})
|
||||
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
|
||||
return
|
||||
}
|
||||
if forwardedFrames != nil {
|
||||
forwardedFrames.Add(1)
|
||||
}
|
||||
markActivity()
|
||||
}
|
||||
}
|
||||
|
||||
func runUpstreamToClient(
|
||||
ctx context.Context,
|
||||
upstreamConn FrameConn,
|
||||
writeClient func(msgType coderws.MessageType, payload []byte) error,
|
||||
startAt time.Time,
|
||||
nowFn func() time.Time,
|
||||
state *relayState,
|
||||
onUsageParseFailure func(eventType string, usageRaw string),
|
||||
onTurnComplete func(turn RelayTurnResult),
|
||||
dropDownstreamWrites *atomic.Bool,
|
||||
forwardedFrames *atomic.Int64,
|
||||
droppedFrames *atomic.Int64,
|
||||
markActivity func(),
|
||||
onTrace func(event RelayTraceEvent),
|
||||
exitCh chan<- relayExitSignal,
|
||||
) {
|
||||
wroteDownstream := false
|
||||
for {
|
||||
msgType, payload, err := upstreamConn.ReadFrame(ctx)
|
||||
if err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "read_upstream_failed",
|
||||
Direction: "upstream_to_client",
|
||||
Error: err.Error(),
|
||||
Graceful: isDisconnectError(err),
|
||||
WroteDownstream: wroteDownstream,
|
||||
})
|
||||
exitCh <- relayExitSignal{
|
||||
stage: "read_upstream",
|
||||
err: err,
|
||||
graceful: isDisconnectError(err),
|
||||
wroteDownstream: wroteDownstream,
|
||||
}
|
||||
return
|
||||
}
|
||||
markActivity()
|
||||
observedEvent := observedUpstreamEvent{}
|
||||
switch msgType {
|
||||
case coderws.MessageText:
|
||||
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
|
||||
case coderws.MessageBinary:
|
||||
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
|
||||
}
|
||||
emitTurnComplete(onTurnComplete, state, observedEvent)
|
||||
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
|
||||
if droppedFrames != nil {
|
||||
droppedFrames.Add(1)
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "drop_downstream_frame",
|
||||
Direction: "upstream_to_client",
|
||||
MessageType: relayMessageTypeString(msgType),
|
||||
PayloadBytes: len(payload),
|
||||
WroteDownstream: wroteDownstream,
|
||||
})
|
||||
if observedEvent.terminal {
|
||||
exitCh <- relayExitSignal{
|
||||
stage: "drain_terminal",
|
||||
graceful: true,
|
||||
wroteDownstream: wroteDownstream,
|
||||
}
|
||||
return
|
||||
}
|
||||
markActivity()
|
||||
continue
|
||||
}
|
||||
if err := writeClient(msgType, payload); err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_client_failed",
|
||||
Direction: "upstream_to_client",
|
||||
MessageType: relayMessageTypeString(msgType),
|
||||
PayloadBytes: len(payload),
|
||||
WroteDownstream: wroteDownstream,
|
||||
Error: err.Error(),
|
||||
})
|
||||
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
|
||||
return
|
||||
}
|
||||
wroteDownstream = true
|
||||
if forwardedFrames != nil {
|
||||
forwardedFrames.Add(1)
|
||||
}
|
||||
markActivity()
|
||||
}
|
||||
}
|
||||
|
||||
func runIdleWatchdog(
|
||||
ctx context.Context,
|
||||
nowFn func() time.Time,
|
||||
idleTimeout time.Duration,
|
||||
lastActivity *atomic.Int64,
|
||||
onTrace func(event RelayTraceEvent),
|
||||
exitCh chan<- relayExitSignal,
|
||||
) {
|
||||
if idleTimeout <= 0 {
|
||||
return
|
||||
}
|
||||
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
|
||||
if checkInterval < time.Second {
|
||||
checkInterval = time.Second
|
||||
}
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
last := time.Unix(0, lastActivity.Load())
|
||||
if nowFn().Sub(last) < idleTimeout {
|
||||
continue
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "idle_timeout_triggered",
|
||||
Direction: "watchdog",
|
||||
Error: context.DeadlineExceeded.Error(),
|
||||
})
|
||||
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
|
||||
if onTrace == nil {
|
||||
return
|
||||
}
|
||||
onTrace(event)
|
||||
}
|
||||
|
||||
func relayMessageTypeString(msgType coderws.MessageType) string {
|
||||
switch msgType {
|
||||
case coderws.MessageText:
|
||||
return "text"
|
||||
case coderws.MessageBinary:
|
||||
return "binary"
|
||||
default:
|
||||
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
|
||||
}
|
||||
}
|
||||
|
||||
func relayDirectionFromStage(stage string) string {
|
||||
switch stage {
|
||||
case "read_client", "write_upstream":
|
||||
return "client_to_upstream"
|
||||
case "read_upstream", "write_client", "drain_terminal":
|
||||
return "upstream_to_client"
|
||||
case "idle_timeout":
|
||||
return "watchdog"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func relayErrorString(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func observeUpstreamMessage(
|
||||
state *relayState,
|
||||
message []byte,
|
||||
startAt time.Time,
|
||||
nowFn func() time.Time,
|
||||
onUsageParseFailure func(eventType string, usageRaw string),
|
||||
) observedUpstreamEvent {
|
||||
if state == nil || len(message) == 0 {
|
||||
return observedUpstreamEvent{}
|
||||
}
|
||||
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
|
||||
eventType := strings.TrimSpace(values[0].String())
|
||||
if eventType == "" {
|
||||
return observedUpstreamEvent{}
|
||||
}
|
||||
responseID := strings.TrimSpace(values[1].String())
|
||||
if responseID == "" {
|
||||
responseID = strings.TrimSpace(values[2].String())
|
||||
}
|
||||
// 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。
|
||||
if responseID == "" && isTerminalEvent(eventType) {
|
||||
responseID = strings.TrimSpace(values[3].String())
|
||||
}
|
||||
now := nowFn()
|
||||
|
||||
if state.firstTokenMs == nil && isTokenEvent(eventType) {
|
||||
ms := int(now.Sub(startAt).Milliseconds())
|
||||
if ms >= 0 {
|
||||
state.firstTokenMs = &ms
|
||||
}
|
||||
}
|
||||
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
|
||||
observed := observedUpstreamEvent{
|
||||
eventType: eventType,
|
||||
responseID: responseID,
|
||||
usage: parsedUsage,
|
||||
}
|
||||
if responseID != "" {
|
||||
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
|
||||
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
|
||||
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
|
||||
if ms >= 0 {
|
||||
turnTiming.firstTokenMs = &ms
|
||||
}
|
||||
}
|
||||
}
|
||||
if !isTerminalEvent(eventType) {
|
||||
return observed
|
||||
}
|
||||
observed.terminal = true
|
||||
state.terminalEventType = eventType
|
||||
if responseID != "" {
|
||||
state.lastResponseID = responseID
|
||||
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
|
||||
duration := now.Sub(turnTiming.startAt)
|
||||
if duration < 0 {
|
||||
duration = 0
|
||||
}
|
||||
observed.duration = duration
|
||||
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
|
||||
}
|
||||
}
|
||||
return observed
|
||||
}
|
||||
|
||||
func emitTurnComplete(
|
||||
onTurnComplete func(turn RelayTurnResult),
|
||||
state *relayState,
|
||||
observed observedUpstreamEvent,
|
||||
) {
|
||||
if onTurnComplete == nil || !observed.terminal {
|
||||
return
|
||||
}
|
||||
responseID := strings.TrimSpace(observed.responseID)
|
||||
if responseID == "" {
|
||||
return
|
||||
}
|
||||
requestModel := ""
|
||||
if state != nil {
|
||||
requestModel = state.requestModel
|
||||
}
|
||||
onTurnComplete(RelayTurnResult{
|
||||
RequestModel: requestModel,
|
||||
Usage: observed.usage,
|
||||
RequestID: responseID,
|
||||
TerminalEventType: observed.eventType,
|
||||
Duration: observed.duration,
|
||||
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
|
||||
})
|
||||
}
|
||||
|
||||
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
if state.turnTimingByID == nil {
|
||||
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
|
||||
}
|
||||
timing, ok := state.turnTimingByID[responseID]
|
||||
if !ok || timing == nil || timing.startAt.IsZero() {
|
||||
timing = &relayTurnTiming{startAt: now}
|
||||
state.turnTimingByID[responseID] = timing
|
||||
return timing
|
||||
}
|
||||
return timing
|
||||
}
|
||||
|
||||
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
|
||||
if state == nil || state.turnTimingByID == nil {
|
||||
return relayTurnTiming{}, false
|
||||
}
|
||||
timing, ok := state.turnTimingByID[responseID]
|
||||
if !ok || timing == nil {
|
||||
return relayTurnTiming{}, false
|
||||
}
|
||||
delete(state.turnTimingByID, responseID)
|
||||
return *timing, true
|
||||
}
|
||||
|
||||
func openAIWSRelayCloneIntPtr(v *int) *int {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *v
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func parseUsageAndAccumulate(
|
||||
state *relayState,
|
||||
message []byte,
|
||||
eventType string,
|
||||
onParseFailure func(eventType string, usageRaw string),
|
||||
) Usage {
|
||||
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
|
||||
return Usage{}
|
||||
}
|
||||
usageResult := gjson.GetBytes(message, "response.usage")
|
||||
if !usageResult.Exists() {
|
||||
return Usage{}
|
||||
}
|
||||
usageRaw := strings.TrimSpace(usageResult.Raw)
|
||||
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
|
||||
recordUsageParseFailure()
|
||||
if onParseFailure != nil {
|
||||
onParseFailure(eventType, usageRaw)
|
||||
}
|
||||
return Usage{}
|
||||
}
|
||||
|
||||
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
||||
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
||||
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
||||
|
||||
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
||||
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
||||
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
|
||||
if !inputOK || !outputOK || !cachedOK {
|
||||
recordUsageParseFailure()
|
||||
if onParseFailure != nil {
|
||||
onParseFailure(eventType, usageRaw)
|
||||
}
|
||||
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
|
||||
return Usage{}
|
||||
}
|
||||
parsedUsage := Usage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
CacheReadInputTokens: cachedTokens,
|
||||
}
|
||||
|
||||
state.usage.InputTokens += parsedUsage.InputTokens
|
||||
state.usage.OutputTokens += parsedUsage.OutputTokens
|
||||
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
||||
return parsedUsage
|
||||
}
|
||||
|
||||
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
|
||||
if !value.Exists() {
|
||||
return 0, !required
|
||||
}
|
||||
if value.Type != gjson.Number {
|
||||
return 0, false
|
||||
}
|
||||
return int(value.Int()), true
|
||||
}
|
||||
|
||||
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
result.Duration = duration
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
result.RequestModel = state.requestModel
|
||||
result.Usage = state.usage
|
||||
result.RequestID = state.lastResponseID
|
||||
result.TerminalEventType = state.terminalEventType
|
||||
result.FirstTokenMs = state.firstTokenMs
|
||||
}
|
||||
|
||||
func isDisconnectError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
switch coderws.CloseStatus(err) {
|
||||
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
|
||||
return true
|
||||
}
|
||||
message := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
if message == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(message, "failed to read frame header: eof") ||
|
||||
strings.Contains(message, "unexpected eof") ||
|
||||
strings.Contains(message, "use of closed network connection") ||
|
||||
strings.Contains(message, "connection reset by peer") ||
|
||||
strings.Contains(message, "broken pipe")
|
||||
}
|
||||
|
||||
func isTerminalEvent(eventType string) bool {
|
||||
switch eventType {
|
||||
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func shouldParseUsage(eventType string) bool {
|
||||
switch eventType {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isTokenEvent(eventType string) bool {
|
||||
if eventType == "" {
|
||||
return false
|
||||
}
|
||||
switch eventType {
|
||||
case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
|
||||
return false
|
||||
}
|
||||
if strings.Contains(eventType, ".delta") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(eventType, "response.output_text") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(eventType, "response.output") {
|
||||
return true
|
||||
}
|
||||
return eventType == "response.completed" || eventType == "response.done"
|
||||
}
|
||||
|
||||
func minDuration(a, b time.Duration) time.Duration {
|
||||
if a <= 0 {
|
||||
return b
|
||||
}
|
||||
if b <= 0 {
|
||||
return a
|
||||
}
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
|
||||
if timeout <= 0 {
|
||||
timeout = 200 * time.Millisecond
|
||||
}
|
||||
select {
|
||||
case sig := <-exitCh:
|
||||
return sig, true
|
||||
case <-time.After(timeout):
|
||||
return relayExitSignal{}, false
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user