fix(openai): fail over before responses stream output

This commit is contained in:
AyeSt0
2026-04-25 15:09:40 +08:00
parent 641e61073f
commit 5b63a9b02d
2 changed files with 428 additions and 18 deletions

View File

@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
type errReadCloser struct {
err error
}
func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }
func (r errReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
@@ -1003,6 +1010,150 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
}
}
func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: errReadCloser{err: io.ErrUnexpectedEOF},
Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.in_progress",
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.in_progress",
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.False(t, errors.As(err, &failoverErr))
require.True(t, c.Writer.Written())
require.Contains(t, rec.Body.String(), "response.failed")
require.Contains(t, rec.Body.String(), "high-risk cyber activity")
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1072,7 +1223,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
@@ -1104,7 +1255,7 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
@@ -1114,6 +1265,42 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
}
}
func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
"event: response.created",
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
"",
"event: response.failed",
`data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
"",
}, "\n"))),
Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
}
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
require.False(t, c.Writer.Written())
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{