新增 OpenAI WebSocket v2 passthrough relay 数据面与服务适配层, 支持按账号 ws mode 在 ctx_pool 与 passthrough 间路由。 同步调整前端 OpenAI ws mode 选项为 off/ctx_pool/passthrough, 并补充 i18n 文案与对应单测。 新增 Caddyfile.dmit 与 docker-compose-aicodex.yml 部署配置, 用于宿主机场景下的反向代理与服务编排。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
808 lines
22 KiB
Go
808 lines
22 KiB
Go
package openai_ws_v2
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"io"
|
||
"net"
|
||
"strconv"
|
||
"strings"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
coderws "github.com/coder/websocket"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
type FrameConn interface {
|
||
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
|
||
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
|
||
Close() error
|
||
}
|
||
|
||
// Usage holds token counters, either accumulated over a whole relay session
// or parsed from a single upstream event.
type Usage struct {
	// InputTokens is taken from response.usage.input_tokens.
	InputTokens int

	// OutputTokens is taken from response.usage.output_tokens.
	OutputTokens int

	// CacheCreationInputTokens is not populated by the usage parser in this file.
	CacheCreationInputTokens int

	// CacheReadInputTokens is taken from
	// response.usage.input_tokens_details.cached_tokens.
	CacheReadInputTokens int
}
|
||
|
||
type RelayResult struct {
|
||
RequestModel string
|
||
Usage Usage
|
||
RequestID string
|
||
TerminalEventType string
|
||
FirstTokenMs *int
|
||
Duration time.Duration
|
||
ClientToUpstreamFrames int64
|
||
UpstreamToClientFrames int64
|
||
DroppedDownstreamFrames int64
|
||
}
|
||
|
||
type RelayTurnResult struct {
|
||
RequestModel string
|
||
Usage Usage
|
||
RequestID string
|
||
TerminalEventType string
|
||
Duration time.Duration
|
||
FirstTokenMs *int
|
||
}
|
||
|
||
// RelayExit describes why Relay stopped when it did not finish cleanly.
type RelayExit struct {
	// Stage names the step that ended the relay, e.g. "relay_init",
	// "write_upstream", "read_client", "read_upstream", "write_client",
	// "idle_timeout" or "client_disconnected".
	Stage string

	// Err is the error associated with Stage.
	Err error

	// WroteDownstream reports whether any exit signal received by Relay
	// indicated that an upstream frame had been written to the client.
	WroteDownstream bool
}
|
||
|
||
type RelayOptions struct {
|
||
WriteTimeout time.Duration
|
||
IdleTimeout time.Duration
|
||
UpstreamDrainTimeout time.Duration
|
||
FirstMessageType coderws.MessageType
|
||
OnUsageParseFailure func(eventType string, usageRaw string)
|
||
OnTurnComplete func(turn RelayTurnResult)
|
||
OnTrace func(event RelayTraceEvent)
|
||
Now func() time.Time
|
||
}
|
||
|
||
// RelayTraceEvent is one diagnostic event delivered to RelayOptions.OnTrace.
// Fields that do not apply to a given stage are left at their zero value.
type RelayTraceEvent struct {
	// Stage names the point in the relay that produced the event,
	// e.g. "relay_start", "first_exit" or "relay_exit".
	Stage string

	// Direction is "client_to_upstream", "upstream_to_client", "watchdog" or empty.
	Direction string

	// MessageType is the frame type as text: "text", "binary" or "unknown(N)".
	MessageType string

	// PayloadBytes is the length of the frame payload involved.
	PayloadBytes int

	// Graceful reports whether the condition is classified as a normal disconnect.
	Graceful bool

	// WroteDownstream reports whether a frame had been written to the client so far.
	WroteDownstream bool

	// Error is the error text, empty when there is no error.
	Error string
}
|
||
|
||
type relayState struct {
|
||
usage Usage
|
||
requestModel string
|
||
lastResponseID string
|
||
terminalEventType string
|
||
firstTokenMs *int
|
||
turnTimingByID map[string]*relayTurnTiming
|
||
}
|
||
|
||
// relayExitSignal is what a relay goroutine sends on the exit channel when it stops.
type relayExitSignal struct {
	// stage names the step that stopped, e.g. "read_client" or "idle_timeout".
	stage string

	// err is the error that caused the stop, if any.
	err error

	// graceful marks stops that are classified as a normal disconnect.
	graceful bool

	// wroteDownstream reports whether the sender had written a frame to the client.
	wroteDownstream bool
}
|
||
|
||
type observedUpstreamEvent struct {
|
||
terminal bool
|
||
eventType string
|
||
responseID string
|
||
usage Usage
|
||
duration time.Duration
|
||
firstToken *int
|
||
}
|
||
|
||
// relayTurnTiming records timing for a single turn (one response ID).
type relayTurnTiming struct {
	// startAt is when the first event carrying the response ID was observed.
	startAt time.Time

	// firstTokenMs is the number of milliseconds from startAt to the first
	// token event; nil until one is seen.
	firstTokenMs *int
}
|
||
|
||
func Relay(
|
||
ctx context.Context,
|
||
clientConn FrameConn,
|
||
upstreamConn FrameConn,
|
||
firstClientMessage []byte,
|
||
options RelayOptions,
|
||
) (RelayResult, *RelayExit) {
|
||
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
|
||
if clientConn == nil || upstreamConn == nil {
|
||
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
|
||
}
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
|
||
nowFn := options.Now
|
||
if nowFn == nil {
|
||
nowFn = time.Now
|
||
}
|
||
writeTimeout := options.WriteTimeout
|
||
if writeTimeout <= 0 {
|
||
writeTimeout = 2 * time.Minute
|
||
}
|
||
drainTimeout := options.UpstreamDrainTimeout
|
||
if drainTimeout <= 0 {
|
||
drainTimeout = 1200 * time.Millisecond
|
||
}
|
||
firstMessageType := options.FirstMessageType
|
||
if firstMessageType != coderws.MessageBinary {
|
||
firstMessageType = coderws.MessageText
|
||
}
|
||
startAt := nowFn()
|
||
state := &relayState{requestModel: result.RequestModel}
|
||
onTrace := options.OnTrace
|
||
|
||
relayCtx, relayCancel := context.WithCancel(ctx)
|
||
defer relayCancel()
|
||
|
||
lastActivity := atomic.Int64{}
|
||
lastActivity.Store(nowFn().UnixNano())
|
||
markActivity := func() {
|
||
lastActivity.Store(nowFn().UnixNano())
|
||
}
|
||
|
||
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
|
||
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||
defer cancel()
|
||
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
|
||
}
|
||
writeClient := func(msgType coderws.MessageType, payload []byte) error {
|
||
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
||
defer cancel()
|
||
return clientConn.WriteFrame(writeCtx, msgType, payload)
|
||
}
|
||
|
||
clientToUpstreamFrames := &atomic.Int64{}
|
||
upstreamToClientFrames := &atomic.Int64{}
|
||
droppedDownstreamFrames := &atomic.Int64{}
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "relay_start",
|
||
PayloadBytes: len(firstClientMessage),
|
||
MessageType: relayMessageTypeString(firstMessageType),
|
||
})
|
||
|
||
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
||
result.Duration = nowFn().Sub(startAt)
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "write_first_message_failed",
|
||
Direction: "client_to_upstream",
|
||
MessageType: relayMessageTypeString(firstMessageType),
|
||
PayloadBytes: len(firstClientMessage),
|
||
Error: err.Error(),
|
||
})
|
||
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
||
}
|
||
clientToUpstreamFrames.Add(1)
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "write_first_message_ok",
|
||
Direction: "client_to_upstream",
|
||
MessageType: relayMessageTypeString(firstMessageType),
|
||
PayloadBytes: len(firstClientMessage),
|
||
})
|
||
markActivity()
|
||
|
||
exitCh := make(chan relayExitSignal, 3)
|
||
dropDownstreamWrites := atomic.Bool{}
|
||
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
||
go runUpstreamToClient(
|
||
relayCtx,
|
||
upstreamConn,
|
||
writeClient,
|
||
startAt,
|
||
nowFn,
|
||
state,
|
||
options.OnUsageParseFailure,
|
||
options.OnTurnComplete,
|
||
&dropDownstreamWrites,
|
||
upstreamToClientFrames,
|
||
droppedDownstreamFrames,
|
||
markActivity,
|
||
onTrace,
|
||
exitCh,
|
||
)
|
||
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
|
||
|
||
firstExit := <-exitCh
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "first_exit",
|
||
Direction: relayDirectionFromStage(firstExit.stage),
|
||
Graceful: firstExit.graceful,
|
||
WroteDownstream: firstExit.wroteDownstream,
|
||
Error: relayErrorString(firstExit.err),
|
||
})
|
||
combinedWroteDownstream := firstExit.wroteDownstream
|
||
secondExit := relayExitSignal{graceful: true}
|
||
hasSecondExit := false
|
||
|
||
// 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。
|
||
if firstExit.stage == "read_client" && firstExit.graceful {
|
||
dropDownstreamWrites.Store(true)
|
||
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
|
||
} else {
|
||
relayCancel()
|
||
_ = upstreamConn.Close()
|
||
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
||
}
|
||
if hasSecondExit {
|
||
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "second_exit",
|
||
Direction: relayDirectionFromStage(secondExit.stage),
|
||
Graceful: secondExit.graceful,
|
||
WroteDownstream: secondExit.wroteDownstream,
|
||
Error: relayErrorString(secondExit.err),
|
||
})
|
||
}
|
||
|
||
relayCancel()
|
||
_ = upstreamConn.Close()
|
||
|
||
enrichResult(&result, state, nowFn().Sub(startAt))
|
||
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
||
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
||
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
||
if firstExit.stage == "read_client" && firstExit.graceful {
|
||
stage := "client_disconnected"
|
||
exitErr := firstExit.err
|
||
if hasSecondExit && !secondExit.graceful {
|
||
stage = secondExit.stage
|
||
exitErr = secondExit.err
|
||
}
|
||
if exitErr == nil {
|
||
exitErr = io.EOF
|
||
}
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "relay_exit",
|
||
Direction: relayDirectionFromStage(stage),
|
||
Graceful: false,
|
||
WroteDownstream: combinedWroteDownstream,
|
||
Error: relayErrorString(exitErr),
|
||
})
|
||
return result, &RelayExit{
|
||
Stage: stage,
|
||
Err: exitErr,
|
||
WroteDownstream: combinedWroteDownstream,
|
||
}
|
||
}
|
||
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "relay_complete",
|
||
Graceful: true,
|
||
WroteDownstream: combinedWroteDownstream,
|
||
})
|
||
_ = clientConn.Close()
|
||
return result, nil
|
||
}
|
||
if !firstExit.graceful {
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "relay_exit",
|
||
Direction: relayDirectionFromStage(firstExit.stage),
|
||
Graceful: false,
|
||
WroteDownstream: combinedWroteDownstream,
|
||
Error: relayErrorString(firstExit.err),
|
||
})
|
||
return result, &RelayExit{
|
||
Stage: firstExit.stage,
|
||
Err: firstExit.err,
|
||
WroteDownstream: combinedWroteDownstream,
|
||
}
|
||
}
|
||
if hasSecondExit && !secondExit.graceful {
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "relay_exit",
|
||
Direction: relayDirectionFromStage(secondExit.stage),
|
||
Graceful: false,
|
||
WroteDownstream: combinedWroteDownstream,
|
||
Error: relayErrorString(secondExit.err),
|
||
})
|
||
return result, &RelayExit{
|
||
Stage: secondExit.stage,
|
||
Err: secondExit.err,
|
||
WroteDownstream: combinedWroteDownstream,
|
||
}
|
||
}
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "relay_complete",
|
||
Graceful: true,
|
||
WroteDownstream: combinedWroteDownstream,
|
||
})
|
||
_ = clientConn.Close()
|
||
return result, nil
|
||
}
|
||
|
||
func runClientToUpstream(
|
||
ctx context.Context,
|
||
clientConn FrameConn,
|
||
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
||
markActivity func(),
|
||
forwardedFrames *atomic.Int64,
|
||
onTrace func(event RelayTraceEvent),
|
||
exitCh chan<- relayExitSignal,
|
||
) {
|
||
for {
|
||
msgType, payload, err := clientConn.ReadFrame(ctx)
|
||
if err != nil {
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "read_client_failed",
|
||
Direction: "client_to_upstream",
|
||
Error: err.Error(),
|
||
Graceful: isDisconnectError(err),
|
||
})
|
||
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
|
||
return
|
||
}
|
||
markActivity()
|
||
if err := writeUpstream(msgType, payload); err != nil {
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "write_upstream_failed",
|
||
Direction: "client_to_upstream",
|
||
MessageType: relayMessageTypeString(msgType),
|
||
PayloadBytes: len(payload),
|
||
Error: err.Error(),
|
||
})
|
||
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
|
||
return
|
||
}
|
||
if forwardedFrames != nil {
|
||
forwardedFrames.Add(1)
|
||
}
|
||
markActivity()
|
||
}
|
||
}
|
||
|
||
func runUpstreamToClient(
|
||
ctx context.Context,
|
||
upstreamConn FrameConn,
|
||
writeClient func(msgType coderws.MessageType, payload []byte) error,
|
||
startAt time.Time,
|
||
nowFn func() time.Time,
|
||
state *relayState,
|
||
onUsageParseFailure func(eventType string, usageRaw string),
|
||
onTurnComplete func(turn RelayTurnResult),
|
||
dropDownstreamWrites *atomic.Bool,
|
||
forwardedFrames *atomic.Int64,
|
||
droppedFrames *atomic.Int64,
|
||
markActivity func(),
|
||
onTrace func(event RelayTraceEvent),
|
||
exitCh chan<- relayExitSignal,
|
||
) {
|
||
wroteDownstream := false
|
||
for {
|
||
msgType, payload, err := upstreamConn.ReadFrame(ctx)
|
||
if err != nil {
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "read_upstream_failed",
|
||
Direction: "upstream_to_client",
|
||
Error: err.Error(),
|
||
Graceful: isDisconnectError(err),
|
||
WroteDownstream: wroteDownstream,
|
||
})
|
||
exitCh <- relayExitSignal{
|
||
stage: "read_upstream",
|
||
err: err,
|
||
graceful: isDisconnectError(err),
|
||
wroteDownstream: wroteDownstream,
|
||
}
|
||
return
|
||
}
|
||
markActivity()
|
||
observedEvent := observedUpstreamEvent{}
|
||
switch msgType {
|
||
case coderws.MessageText:
|
||
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
|
||
case coderws.MessageBinary:
|
||
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
|
||
}
|
||
emitTurnComplete(onTurnComplete, state, observedEvent)
|
||
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
|
||
if droppedFrames != nil {
|
||
droppedFrames.Add(1)
|
||
}
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "drop_downstream_frame",
|
||
Direction: "upstream_to_client",
|
||
MessageType: relayMessageTypeString(msgType),
|
||
PayloadBytes: len(payload),
|
||
WroteDownstream: wroteDownstream,
|
||
})
|
||
if observedEvent.terminal {
|
||
exitCh <- relayExitSignal{
|
||
stage: "drain_terminal",
|
||
graceful: true,
|
||
wroteDownstream: wroteDownstream,
|
||
}
|
||
return
|
||
}
|
||
markActivity()
|
||
continue
|
||
}
|
||
if err := writeClient(msgType, payload); err != nil {
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "write_client_failed",
|
||
Direction: "upstream_to_client",
|
||
MessageType: relayMessageTypeString(msgType),
|
||
PayloadBytes: len(payload),
|
||
WroteDownstream: wroteDownstream,
|
||
Error: err.Error(),
|
||
})
|
||
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
|
||
return
|
||
}
|
||
wroteDownstream = true
|
||
if forwardedFrames != nil {
|
||
forwardedFrames.Add(1)
|
||
}
|
||
markActivity()
|
||
}
|
||
}
|
||
|
||
func runIdleWatchdog(
|
||
ctx context.Context,
|
||
nowFn func() time.Time,
|
||
idleTimeout time.Duration,
|
||
lastActivity *atomic.Int64,
|
||
onTrace func(event RelayTraceEvent),
|
||
exitCh chan<- relayExitSignal,
|
||
) {
|
||
if idleTimeout <= 0 {
|
||
return
|
||
}
|
||
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
|
||
if checkInterval < time.Second {
|
||
checkInterval = time.Second
|
||
}
|
||
ticker := time.NewTicker(checkInterval)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
case <-ticker.C:
|
||
last := time.Unix(0, lastActivity.Load())
|
||
if nowFn().Sub(last) < idleTimeout {
|
||
continue
|
||
}
|
||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||
Stage: "idle_timeout_triggered",
|
||
Direction: "watchdog",
|
||
Error: context.DeadlineExceeded.Error(),
|
||
})
|
||
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
|
||
if onTrace == nil {
|
||
return
|
||
}
|
||
onTrace(event)
|
||
}
|
||
|
||
func relayMessageTypeString(msgType coderws.MessageType) string {
|
||
switch msgType {
|
||
case coderws.MessageText:
|
||
return "text"
|
||
case coderws.MessageBinary:
|
||
return "binary"
|
||
default:
|
||
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
|
||
}
|
||
}
|
||
|
||
// relayDirectionFromStage maps an exit stage to the trace direction it belongs
// to. Stages it does not recognise map to the empty string.
func relayDirectionFromStage(stage string) string {
	direction := ""
	switch stage {
	case "read_client", "write_upstream":
		direction = "client_to_upstream"
	case "read_upstream", "write_client", "drain_terminal":
		direction = "upstream_to_client"
	case "idle_timeout":
		direction = "watchdog"
	}
	return direction
}
|
||
|
||
// relayErrorString returns err.Error(), or the empty string for a nil error.
func relayErrorString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
|
||
|
||
func observeUpstreamMessage(
|
||
state *relayState,
|
||
message []byte,
|
||
startAt time.Time,
|
||
nowFn func() time.Time,
|
||
onUsageParseFailure func(eventType string, usageRaw string),
|
||
) observedUpstreamEvent {
|
||
if state == nil || len(message) == 0 {
|
||
return observedUpstreamEvent{}
|
||
}
|
||
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
|
||
eventType := strings.TrimSpace(values[0].String())
|
||
if eventType == "" {
|
||
return observedUpstreamEvent{}
|
||
}
|
||
responseID := strings.TrimSpace(values[1].String())
|
||
if responseID == "" {
|
||
responseID = strings.TrimSpace(values[2].String())
|
||
}
|
||
// 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。
|
||
if responseID == "" && isTerminalEvent(eventType) {
|
||
responseID = strings.TrimSpace(values[3].String())
|
||
}
|
||
now := nowFn()
|
||
|
||
if state.firstTokenMs == nil && isTokenEvent(eventType) {
|
||
ms := int(now.Sub(startAt).Milliseconds())
|
||
if ms >= 0 {
|
||
state.firstTokenMs = &ms
|
||
}
|
||
}
|
||
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
|
||
observed := observedUpstreamEvent{
|
||
eventType: eventType,
|
||
responseID: responseID,
|
||
usage: parsedUsage,
|
||
}
|
||
if responseID != "" {
|
||
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
|
||
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
|
||
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
|
||
if ms >= 0 {
|
||
turnTiming.firstTokenMs = &ms
|
||
}
|
||
}
|
||
}
|
||
if !isTerminalEvent(eventType) {
|
||
return observed
|
||
}
|
||
observed.terminal = true
|
||
state.terminalEventType = eventType
|
||
if responseID != "" {
|
||
state.lastResponseID = responseID
|
||
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
|
||
duration := now.Sub(turnTiming.startAt)
|
||
if duration < 0 {
|
||
duration = 0
|
||
}
|
||
observed.duration = duration
|
||
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
|
||
}
|
||
}
|
||
return observed
|
||
}
|
||
|
||
func emitTurnComplete(
|
||
onTurnComplete func(turn RelayTurnResult),
|
||
state *relayState,
|
||
observed observedUpstreamEvent,
|
||
) {
|
||
if onTurnComplete == nil || !observed.terminal {
|
||
return
|
||
}
|
||
responseID := strings.TrimSpace(observed.responseID)
|
||
if responseID == "" {
|
||
return
|
||
}
|
||
requestModel := ""
|
||
if state != nil {
|
||
requestModel = state.requestModel
|
||
}
|
||
onTurnComplete(RelayTurnResult{
|
||
RequestModel: requestModel,
|
||
Usage: observed.usage,
|
||
RequestID: responseID,
|
||
TerminalEventType: observed.eventType,
|
||
Duration: observed.duration,
|
||
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
|
||
})
|
||
}
|
||
|
||
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
|
||
if state == nil {
|
||
return nil
|
||
}
|
||
if state.turnTimingByID == nil {
|
||
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
|
||
}
|
||
timing, ok := state.turnTimingByID[responseID]
|
||
if !ok || timing == nil || timing.startAt.IsZero() {
|
||
timing = &relayTurnTiming{startAt: now}
|
||
state.turnTimingByID[responseID] = timing
|
||
return timing
|
||
}
|
||
return timing
|
||
}
|
||
|
||
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
|
||
if state == nil || state.turnTimingByID == nil {
|
||
return relayTurnTiming{}, false
|
||
}
|
||
timing, ok := state.turnTimingByID[responseID]
|
||
if !ok || timing == nil {
|
||
return relayTurnTiming{}, false
|
||
}
|
||
delete(state.turnTimingByID, responseID)
|
||
return *timing, true
|
||
}
|
||
|
||
// openAIWSRelayCloneIntPtr returns a pointer to a fresh copy of *v, or nil
// when v is nil, so the caller never shares storage with the original.
func openAIWSRelayCloneIntPtr(v *int) *int {
	if v != nil {
		copied := *v
		return &copied
	}
	return nil
}
|
||
|
||
func parseUsageAndAccumulate(
|
||
state *relayState,
|
||
message []byte,
|
||
eventType string,
|
||
onParseFailure func(eventType string, usageRaw string),
|
||
) Usage {
|
||
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
|
||
return Usage{}
|
||
}
|
||
usageResult := gjson.GetBytes(message, "response.usage")
|
||
if !usageResult.Exists() {
|
||
return Usage{}
|
||
}
|
||
usageRaw := strings.TrimSpace(usageResult.Raw)
|
||
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
|
||
recordUsageParseFailure()
|
||
if onParseFailure != nil {
|
||
onParseFailure(eventType, usageRaw)
|
||
}
|
||
return Usage{}
|
||
}
|
||
|
||
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
||
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
||
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
||
|
||
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
||
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
||
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
|
||
if !inputOK || !outputOK || !cachedOK {
|
||
recordUsageParseFailure()
|
||
if onParseFailure != nil {
|
||
onParseFailure(eventType, usageRaw)
|
||
}
|
||
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
|
||
return Usage{}
|
||
}
|
||
parsedUsage := Usage{
|
||
InputTokens: inputTokens,
|
||
OutputTokens: outputTokens,
|
||
CacheReadInputTokens: cachedTokens,
|
||
}
|
||
|
||
state.usage.InputTokens += parsedUsage.InputTokens
|
||
state.usage.OutputTokens += parsedUsage.OutputTokens
|
||
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
||
return parsedUsage
|
||
}
|
||
|
||
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
|
||
if !value.Exists() {
|
||
return 0, !required
|
||
}
|
||
if value.Type != gjson.Number {
|
||
return 0, false
|
||
}
|
||
return int(value.Int()), true
|
||
}
|
||
|
||
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
|
||
if result == nil {
|
||
return
|
||
}
|
||
result.Duration = duration
|
||
if state == nil {
|
||
return
|
||
}
|
||
result.RequestModel = state.requestModel
|
||
result.Usage = state.usage
|
||
result.RequestID = state.lastResponseID
|
||
result.TerminalEventType = state.terminalEventType
|
||
result.FirstTokenMs = state.firstTokenMs
|
||
}
|
||
|
||
func isDisconnectError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
|
||
return true
|
||
}
|
||
switch coderws.CloseStatus(err) {
|
||
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
|
||
return true
|
||
}
|
||
message := strings.ToLower(strings.TrimSpace(err.Error()))
|
||
if message == "" {
|
||
return false
|
||
}
|
||
return strings.Contains(message, "failed to read frame header: eof") ||
|
||
strings.Contains(message, "unexpected eof") ||
|
||
strings.Contains(message, "use of closed network connection") ||
|
||
strings.Contains(message, "connection reset by peer") ||
|
||
strings.Contains(message, "broken pipe")
|
||
}
|
||
|
||
// isTerminalEvent reports whether eventType is one of the event types that end
// a response turn (completed, done, failed, incomplete, cancelled/canceled).
func isTerminalEvent(eventType string) bool {
	terminalTypes := [...]string{
		"response.completed",
		"response.done",
		"response.failed",
		"response.incomplete",
		"response.cancelled",
		"response.canceled",
	}
	for _, candidate := range terminalTypes {
		if eventType == candidate {
			return true
		}
	}
	return false
}
|
||
|
||
// shouldParseUsage reports whether eventType is one of the events from which
// usage is read: response.completed, response.done or response.failed.
func shouldParseUsage(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed"
}
|
||
|
||
// isTokenEvent reports whether eventType counts as output for first-token
// timing. Lifecycle events (created, in_progress, output_item.added/done) and
// the empty string never count. Otherwise any type containing ".delta" or
// starting with "response.output" counts, as do response.completed and
// response.done.
func isTokenEvent(eventType string) bool {
	switch eventType {
	case "", "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
		return false
	case "response.completed", "response.done":
		return true
	}
	// The "response.output" prefix also covers every "response.output_text…" type.
	return strings.Contains(eventType, ".delta") || strings.HasPrefix(eventType, "response.output")
}
|
||
|
||
// minDuration returns the smaller of a and b, except that a non-positive
// argument is skipped: if a <= 0 the result is b, otherwise if b <= 0 the
// result is a.
func minDuration(a, b time.Duration) time.Duration {
	switch {
	case a <= 0:
		return b
	case b <= 0, a < b:
		return a
	default:
		return b
	}
}
|
||
|
||
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
|
||
if timeout <= 0 {
|
||
timeout = 200 * time.Millisecond
|
||
}
|
||
select {
|
||
case sig := <-exitCh:
|
||
return sig, true
|
||
case <-time.After(timeout):
|
||
return relayExitSignal{}, false
|
||
}
|
||
}
|