fix(openai): fail over before responses stream output

This commit is contained in:
AyeSt0
2026-04-25 15:09:40 +08:00
parent 641e61073f
commit 5b63a9b02d
2 changed files with 428 additions and 18 deletions

View File

@@ -3147,6 +3147,113 @@ type openaiStreamingResultPassthrough struct {
firstTokenMs *int
}
// openAIStreamClientOutputStarted reports whether any bytes have already been
// sent to the downstream client, either per the locally tracked flag or per
// the gin response writer's own written state.
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
	if localStarted {
		return true
	}
	if c == nil || c.Writer == nil {
		return false
	}
	return c.Writer.Written()
}
// openAIStreamEventIsPreamble reports whether eventType is one of the
// lifecycle preamble events ("response.created" / "response.in_progress")
// that precede any client-visible output in an OpenAI responses stream.
func openAIStreamEventIsPreamble(eventType string) bool {
	t := strings.TrimSpace(eventType)
	return t == "response.created" || t == "response.in_progress"
}
// openAIStreamDataStartsClientOutput reports whether an SSE data payload
// should count as the first client-visible output of the stream. Empty
// payloads, "response.failed" events, and lifecycle preamble events
// ("response.created" / "response.in_progress") do not start client output.
func openAIStreamDataStartsClientOutput(data, eventType string) bool {
	if strings.TrimSpace(data) == "" {
		return false
	}
	switch strings.TrimSpace(eventType) {
	case "response.failed", "response.created", "response.in_progress":
		return false
	}
	return true
}
// openAIStreamFailedEventShouldFailover decides whether a "response.failed"
// event should trigger account failover. Failures that look like client-side
// or policy problems (invalid request, content policy / safety blocks, flagged
// activity) are passed through instead, since retrying on another account
// cannot help; anything else — including a failure with no diagnostics at
// all — is treated as retryable.
func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool {
	// pick reads the primary gjson path, falling back to the secondary when
	// the primary is blank, and normalizes the result for matching.
	pick := func(primary, fallback string) string {
		v := gjson.GetBytes(payload, primary).String()
		if strings.TrimSpace(v) == "" {
			v = gjson.GetBytes(payload, fallback).String()
		}
		return strings.ToLower(strings.TrimSpace(v))
	}
	code := pick("response.error.code", "error.code")
	errType := pick("response.error.type", "error.type")
	combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType))
	if combined == "" {
		// No diagnostic at all: assume a transient upstream fault.
		return true
	}
	for _, marker := range []string{
		"invalid_request",
		"content_policy",
		"policy",
		"safety",
		"high-risk cyber",
		"not allowed",
		"violat",
	} {
		if strings.Contains(combined, marker) {
			return false
		}
	}
	return true
}
// newOpenAIStreamFailoverError builds an UpstreamFailoverError (HTTP 502) for
// a stream that failed before any client-visible output, records the failure
// in the ops upstream-error log when a gin context is available, and
// synthesizes an OpenAI-style JSON error body for the replacement response.
//
// payload, when non-empty, is attached as detail subject to the configured
// LogUpstreamErrorBody settings; message is sanitized and falls back to a
// generic disconnect description when blank.
func (s *OpenAIGatewayService) newOpenAIStreamFailoverError(
	c *gin.Context,
	account *Account,
	passthrough bool,
	upstreamRequestID string,
	payload []byte,
	message string,
) *UpstreamFailoverError {
	msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
	if msg == "" {
		msg = "OpenAI stream disconnected before completion"
	}
	var detail string
	if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		limit := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if limit <= 0 {
			// Default cap mirrors the config default when unset/invalid.
			limit = 2048
		}
		detail = truncateString(string(payload), limit)
	}
	if c != nil {
		setOpsUpstreamError(c, http.StatusBadGateway, msg, detail)
		event := OpsUpstreamErrorEvent{
			Platform:           PlatformOpenAI,
			UpstreamStatusCode: http.StatusBadGateway,
			UpstreamRequestID:  strings.TrimSpace(upstreamRequestID),
			Passthrough:        passthrough,
			Kind:               "failover",
			Message:            msg,
			Detail:             detail,
		}
		if account != nil {
			// Prefer the concrete account's platform/identity over the default.
			event.Platform = account.Platform
			event.AccountID = account.ID
			event.AccountName = account.Name
		}
		appendOpsUpstreamError(c, event)
	}
	body, _ := json.Marshal(gin.H{
		"error": gin.H{
			"type":    "upstream_error",
			"message": msg,
		},
	})
	return &UpstreamFailoverError{
		StatusCode:   http.StatusBadGateway,
		ResponseBody: body,
	}
}
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
@@ -3178,7 +3285,22 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
clientDisconnected := false
sawDone := false
sawTerminalEvent := false
sawFailedEvent := false
failedMessage := ""
clientOutputStarted := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
pendingLines := make([]string, 0, 8)
writePendingLines := func() bool {
for _, pending := range pendingLines {
if _, err := fmt.Fprintln(w, pending); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
return false
}
}
pendingLines = pendingLines[:0]
return true
}
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -3193,6 +3315,8 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
for scanner.Scan() {
line := scanner.Text()
lineStartsClientOutput := false
forceFlushFailedEvent := false
if data, ok := extractOpenAISSEDataLine(line); ok {
dataBytes := []byte(data)
trimmedData := strings.TrimSpace(data)
@@ -3203,13 +3327,24 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
trimmedData = strings.TrimSpace(replacedData)
}
}
eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String())
if eventType == "response.failed" {
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
}
forceFlushFailedEvent = true
sawFailedEvent = true
}
if trimmedData == "[DONE]" {
sawDone = true
}
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
@@ -3217,20 +3352,30 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
if !clientDisconnected {
if !clientOutputStarted && !lineStartsClientOutput {
pendingLines = append(pendingLines, line)
continue
}
if !clientOutputStarted && len(pendingLines) > 0 {
if !writePendingLines() {
continue
}
}
if _, err := fmt.Fprintln(w, line); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else {
clientOutputStarted = true
flusher.Flush()
}
}
}
if err := scanner.Err(); err != nil {
if sawTerminalEvent {
if sawTerminalEvent && !sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
if sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
@@ -3239,6 +3384,17 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
}
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
msg := "OpenAI stream disconnected before completion"
if errText := strings.TrimSpace(err.Error()); errText != "" {
msg += ": " + errText
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
}
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
}
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
account.ID,
@@ -3247,12 +3403,19 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
if sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
}
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
}
@@ -3854,6 +4017,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent := false
sawFailedEvent := false
failedMessage := ""
clientOutputStarted := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
var streamFailoverErr error
sendErrorEvent := func(reason string) {
if errorEventSent || clientDisconnected {
return
@@ -3870,7 +4038,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
if err := flushBuffered(); err != nil {
clientDisconnected = true
return
}
clientOutputStarted = true
}
needModelReplace := originalModel != mappedModel
@@ -3878,43 +4048,72 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
}
finalizeStream := func() (*openaiStreamingResult, error) {
if !sawTerminalEvent {
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
return resultWithUsage(), s.newOpenAIStreamFailoverError(
c,
account,
false,
upstreamRequestID,
nil,
"OpenAI stream ended before a terminal event",
)
}
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
if sawFailedEvent {
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
}
if !clientDisconnected {
hadBufferedData := bufferedWriter.Buffered() > 0
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
} else if hadBufferedData {
clientOutputStarted = true
}
}
if !sawTerminalEvent {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
return resultWithUsage(), nil
}
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
if scanErr == nil {
return nil, nil, false
}
if sawTerminalEvent {
if sawTerminalEvent && !sawFailedEvent {
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
return resultWithUsage(), nil, true
}
if sawFailedEvent {
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage), true
}
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event避免下游 SDK 解析失败。
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
}
if errors.Is(scanErr, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
sendErrorEvent("response_too_large")
return resultWithUsage(), scanErr, true
}
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
msg := "OpenAI stream disconnected before completion"
if errText := strings.TrimSpace(scanErr.Error()); errText != "" {
msg += ": " + errText
}
return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, nil, msg), true
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
}
sendErrorEvent("stream_read_error")
return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
}
processSSELine := func(line string, queueDrained bool) {
if streamFailoverErr != nil {
return
}
lastDataAt = time.Now()
// Extract data from SSE line (supports both "data: " and "data:" formats)
@@ -3930,18 +4129,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
}
eventType := strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
forceFlushFailedEvent := false
if eventType == "response.failed" {
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
sawFailedEvent = true
streamFailoverErr = s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, dataBytes, failedMessage)
return
}
forceFlushFailedEvent = true
sawFailedEvent = true
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
dataBytes = correctedData
data = string(correctedData)
line = "data: " + data
eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
}
startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType)
// 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
shouldFlush := queueDrained
if firstTokenMs == nil && data != "" && data != "[DONE]" {
shouldFlush := queueDrained && (clientOutputStarted || startsClientOutput)
if firstTokenMs == nil && startsClientOutput {
// 保证首个 token 事件尽快出站,避免影响 TTFT。
shouldFlush = true
}
@@ -3955,12 +4168,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
} else {
clientOutputStarted = true
}
}
}
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
if firstTokenMs == nil && startsClientOutput {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
@@ -3976,10 +4191,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
} else if queueDrained {
} else if queueDrained && clientOutputStarted {
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
} else {
clientOutputStarted = true
}
}
}
@@ -3990,6 +4207,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
defer putSSEScannerBuf64K(scanBuf)
for scanner.Scan() {
processSSELine(scanner.Text(), true)
if streamFailoverErr != nil {
return resultWithUsage(), streamFailoverErr
}
}
if result, err, done := handleScanErr(scanner.Err()); done {
return result, err
@@ -4039,6 +4259,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return result, err
}
processSSELine(ev.line, len(events) == 0)
if streamFailoverErr != nil {
return resultWithUsage(), streamFailoverErr
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))

View File

@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
// Read always fails with context.Canceled, simulating an upstream body whose
// request context was canceled before any bytes were delivered.
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
// Close is a no-op so the fake body satisfies io.ReadCloser.
func (c cancelReadCloser) Close() error { return nil }
type errReadCloser struct {
err error
}
func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }
func (r errReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
@@ -1003,6 +1010,150 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
}
}
// TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover verifies that an
// upstream read error occurring before any bytes reach the client is turned
// into a 502 UpstreamFailoverError and that nothing is written downstream.
func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &OpenAIGatewayService{cfg: &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}}
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       errReadCloser{err: io.ErrUnexpectedEOF},
		Header:     http.Header{"X-Request-Id": []string{"rid-disconnect"}},
	}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}
	_, err := svc.handleStreamingResponse(c.Request.Context(), upstream, c, account, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	require.False(t, c.Writer.Written())
	require.Empty(t, recorder.Body.String())
}
// TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover verifies that a
// retryable "response.failed" event arriving before any client output triggers
// failover, surfaces the upstream message in the synthesized body, and leaves
// the downstream response untouched.
func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &OpenAIGatewayService{cfg: &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}}
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	sse := strings.Join([]string{
		"event: response.created",
		`data: {"type":"response.created","response":{"id":"resp_1"}}`,
		"",
		"event: response.in_progress",
		`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
		"",
		"event: response.failed",
		`data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
		"",
	}, "\n")
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(sse)),
		Header:     http.Header{"X-Request-Id": []string{"rid-failed"}},
	}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}
	_, err := svc.handleStreamingResponse(c.Request.Context(), upstream, c, account, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
	require.False(t, c.Writer.Written())
	require.Empty(t, recorder.Body.String())
}
// TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover verifies that
// a stream consisting solely of preamble events (created / in_progress) that
// ends without a terminal event fails over instead of emitting a partial
// response to the client.
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &OpenAIGatewayService{cfg: &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}}
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	sse := strings.Join([]string{
		"event: response.created",
		`data: {"type":"response.created","response":{"id":"resp_1"}}`,
		"",
		"event: response.in_progress",
		`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
		"",
	}, "\n")
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(sse)),
		Header:     http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
	}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}
	_, err := svc.handleStreamingResponse(c.Request.Context(), upstream, c, account, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.False(t, c.Writer.Written())
	require.Empty(t, recorder.Body.String())
}
// TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough verifies
// that a non-retryable policy/safety "response.failed" event is passed through
// to the client verbatim rather than converted into a failover, since another
// account would be refused for the same reason.
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &OpenAIGatewayService{cfg: &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}}
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	sse := strings.Join([]string{
		"event: response.created",
		`data: {"type":"response.created","response":{"id":"resp_1"}}`,
		"",
		"event: response.failed",
		`data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
		"",
	}, "\n")
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(sse)),
		Header:     http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
	}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}
	_, err := svc.handleStreamingResponse(c.Request.Context(), upstream, c, account, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.False(t, errors.As(err, &failoverErr))
	require.True(t, c.Writer.Written())
	require.Contains(t, recorder.Body.String(), "response.failed")
	require.Contains(t, recorder.Body.String(), "high-risk cyber activity")
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1072,7 +1223,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
@@ -1104,7 +1255,7 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
@@ -1114,6 +1265,42 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
}
}
// TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover
// verifies the passthrough streaming path also converts a pre-output
// "response.failed" event into a 502 UpstreamFailoverError carrying the
// upstream message, with nothing written to the client.
func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	svc := &OpenAIGatewayService{cfg: &config.Config{
		Gateway: config.GatewayConfig{
			MaxLineSize: defaultMaxLineSize,
		},
	}}
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	sse := strings.Join([]string{
		"event: response.created",
		`data: {"type":"response.created","response":{"id":"resp_1"}}`,
		"",
		"event: response.failed",
		`data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
		"",
	}, "\n")
	upstream := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(sse)),
		Header:     http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
	}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}
	_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), upstream, c, account, time.Now(), "", "")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
	require.False(t, c.Writer.Written())
	require.Empty(t, recorder.Body.String())
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{