diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 5db273b4..75a92f6e 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -3147,6 +3147,113 @@ type openaiStreamingResultPassthrough struct { firstTokenMs *int } +func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool { + if localStarted { + return true + } + return c != nil && c.Writer != nil && c.Writer.Written() +} + +func openAIStreamEventIsPreamble(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.created", "response.in_progress": + return true + default: + return false + } +} + +func openAIStreamDataStartsClientOutput(data, eventType string) bool { + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if strings.TrimSpace(eventType) == "response.failed" { + return false + } + return !openAIStreamEventIsPreamble(eventType) +} + +func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool { + code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String())) + if code == "" { + code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String())) + } + errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.type").String())) + if errType == "" { + errType = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.type").String())) + } + combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType)) + if combined == "" { + return true + } + nonRetryableMarkers := []string{ + "invalid_request", + "content_policy", + "policy", + "safety", + "high-risk cyber", + "not allowed", + "violat", + } + for _, marker := range nonRetryableMarkers { + if strings.Contains(combined, marker) { + return false + } + } + return true +} + +func (s *OpenAIGatewayService) newOpenAIStreamFailoverError( + c *gin.Context, + account *Account, + passthrough bool, + upstreamRequestID string, + payload []byte, + message string, +) *UpstreamFailoverError { + message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) + if message == "" { + message = "OpenAI stream disconnected before completion" + } + detail := "" + if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + detail = truncateString(string(payload), maxBytes) + } + if c != nil { + setOpsUpstreamError(c, http.StatusBadGateway, message, detail) + event := OpsUpstreamErrorEvent{ + Platform: PlatformOpenAI, + UpstreamStatusCode: http.StatusBadGateway, + UpstreamRequestID: strings.TrimSpace(upstreamRequestID), + Passthrough: passthrough, + Kind: "failover", + Message: message, + Detail: detail, + } + if account != nil { + event.Platform = account.Platform + event.AccountID = account.ID + event.AccountName = account.Name + } + appendOpsUpstreamError(c, event) + } + body, _ := json.Marshal(gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": message, + }, + }) + return &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: body, + } +} + func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ctx context.Context, resp *http.Response, @@ -3178,7 +3285,22 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( clientDisconnected := false sawDone := false sawTerminalEvent := false + sawFailedEvent := false + failedMessage := "" + clientOutputStarted := false upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) + pendingLines := make([]string, 0, 8) + writePendingLines := func() bool { + for _, pending := range pendingLines { + if _, err := fmt.Fprintln(w, pending); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + return false + } + } + pendingLines = pendingLines[:0] + return true + } scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -3193,6 +3315,8 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( for scanner.Scan() { line := scanner.Text() + lineStartsClientOutput := false + forceFlushFailedEvent := false if data, ok := extractOpenAISSEDataLine(line); ok { dataBytes := []byte(data) trimmedData := strings.TrimSpace(data) @@ -3203,13 +3327,24 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( trimmedData = strings.TrimSpace(replacedData) } } + eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String()) + if eventType == "response.failed" { + failedMessage = extractOpenAISSEErrorMessage(dataBytes) + if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage) + } + forceFlushFailedEvent = true + sawFailedEvent = true + } if trimmedData == "[DONE]" { sawDone = true } if openAIStreamEventIsTerminal(trimmedData) { sawTerminalEvent = true } - if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { + lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType) + if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } @@ -3217,20 +3352,30 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } if !clientDisconnected { + if !clientOutputStarted && !lineStartsClientOutput { + pendingLines = append(pendingLines, line) + continue + } + if !clientOutputStarted && len(pendingLines) > 0 { + if !writePendingLines() { + continue + } + } if _, err := fmt.Fprintln(w, line); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) } else { + clientOutputStarted = true flusher.Flush() } } } if err := scanner.Err(); err != nil { - if sawTerminalEvent { + if sawTerminalEvent && !sawFailedEvent { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } - if clientDisconnected { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + if sawFailedEvent { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) @@ -3239,6 +3384,17 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err } + if !openAIStreamClientOutputStarted(c, clientOutputStarted) { + msg := "OpenAI stream disconnected before completion" + if errText := strings.TrimSpace(err.Error()); errText != "" { + msg += ": " + errText + } + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg) + } + if clientDisconnected { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + } logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", account.ID, @@ -3247,12 +3403,19 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) } + if sawFailedEvent { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) + } if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { logger.FromContext(ctx).With( zap.String("component", "service.openai_gateway"), zap.Int64("account_id", account.ID), zap.String("upstream_request_id", upstreamRequestID), ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + if !openAIStreamClientOutputStarted(c, clientOutputStarted) { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event") + } return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") } @@ -3854,6 +4017,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp errorEventSent := false clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage sawTerminalEvent := false + sawFailedEvent := false + failedMessage := "" + clientOutputStarted := false + upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) + var streamFailoverErr error sendErrorEvent := func(reason string) { if errorEventSent || clientDisconnected { return @@ -3870,7 +4038,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } if err := flushBuffered(); err != nil { clientDisconnected = true + return } + clientOutputStarted = true } needModelReplace := originalModel != mappedModel @@ -3878,43 +4048,72 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} } finalizeStream := func() (*openaiStreamingResult, error) { + if !sawTerminalEvent { + if !openAIStreamClientOutputStarted(c, clientOutputStarted) { + return resultWithUsage(), s.newOpenAIStreamFailoverError( + c, + account, + false, + upstreamRequestID, + nil, + "OpenAI stream ended before a terminal event", + ) + } + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } + if sawFailedEvent { + return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage) + } if !clientDisconnected { + hadBufferedData := bufferedWriter.Buffered() > 0 if err := flushBuffered(); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") + } else if hadBufferedData { + clientOutputStarted = true } } - if !sawTerminalEvent { - return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") - } return resultWithUsage(), nil } handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { if scanErr == nil { return nil, nil, false } - if sawTerminalEvent { + if sawTerminalEvent && !sawFailedEvent { logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr) return resultWithUsage(), nil, true } + if sawFailedEvent { + return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage), true + } // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true } - // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage - if clientDisconnected { - return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true - } if errors.Is(scanErr, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) sendErrorEvent("response_too_large") return resultWithUsage(), scanErr, true } + if !openAIStreamClientOutputStarted(c, clientOutputStarted) { + msg := "OpenAI stream disconnected before completion" + if errText := strings.TrimSpace(scanErr.Error()); errText != "" { + msg += ": " + errText + } + return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, nil, msg), true + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true + } sendErrorEvent("stream_read_error") return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true } processSSELine := func(line string, queueDrained bool) { + if streamFailoverErr != nil { + return + } lastDataAt = time.Now() // Extract data from SSE line (supports both "data: " and "data:" formats) @@ -3930,18 +4129,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if openAIStreamEventIsTerminal(data) { sawTerminalEvent = true } + eventType := strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String()) + forceFlushFailedEvent := false + if eventType == "response.failed" { + failedMessage = extractOpenAISSEErrorMessage(dataBytes) + if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) { + sawFailedEvent = true + streamFailoverErr = s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, dataBytes, failedMessage) + return + } + forceFlushFailedEvent = true + sawFailedEvent = true + } // Correct Codex tool calls if needed (apply_patch -> edit, etc.) if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { dataBytes = correctedData data = string(correctedData) line = "data: " + data + eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String()) } + startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType) // 写入客户端(客户端断开后继续 drain 上游) if !clientDisconnected { - shouldFlush := queueDrained - if firstTokenMs == nil && data != "" && data != "[DONE]" { + shouldFlush := queueDrained && (clientOutputStarted || startsClientOutput) + if firstTokenMs == nil && startsClientOutput { // 保证首个 token 事件尽快出站,避免影响 TTFT。 shouldFlush = true } @@ -3955,12 +4168,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if err := flushBuffered(); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } else { + clientOutputStarted = true } } } // Record first token time - if firstTokenMs == nil && data != "" && data != "[DONE]" { + if firstTokenMs == nil && startsClientOutput { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } @@ -3976,10 +4191,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } else if _, err := bufferedWriter.WriteString("\n"); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") - } else if queueDrained { + } else if queueDrained && clientOutputStarted { if err := flushBuffered(); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } else { + clientOutputStarted = true } } } @@ -3990,6 +4207,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp defer putSSEScannerBuf64K(scanBuf) for scanner.Scan() { processSSELine(scanner.Text(), true) + if streamFailoverErr != nil { + return resultWithUsage(), streamFailoverErr + } } if result, err, done := handleScanErr(scanner.Err()); done { return result, err @@ -4039,6 +4259,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return result, err } processSSELine(ev.line, len(events) == 0) + if streamFailoverErr != nil { + return resultWithUsage(), streamFailoverErr + } case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 8b7945bc..0cf2392d 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -93,6 +93,13 @@ type cancelReadCloser struct{} func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled } func (c cancelReadCloser) Close() error { return nil } +type errReadCloser struct { + err error +} + +func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err } +func (r errReadCloser) Close() error { return nil } + type failingGinWriter struct { gin.ResponseWriter failAfter int @@ -1003,6 +1010,150 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr } } +func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: errReadCloser{err: io.ErrUnexpectedEOF}, + Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + +func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.in_progress", + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + "", + "event: response.failed", + `data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-failed"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request") + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + +func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.in_progress", + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + +func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.failed", + `data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.False(t, errors.As(err, &failoverErr)) + require.True(t, c.Writer.Written()) + require.Contains(t, rec.Body.String(), "response.failed") + require.Contains(t, rec.Body.String(), "high-risk cyber activity") +} + func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1072,7 +1223,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n")) }() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") @@ -1104,7 +1255,7 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n")) }() _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "") @@ -1114,6 +1265,42 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t } } +func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.failed", + `data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}}, + } + + _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed") + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{