fix(openai): fail over before responses stream output
This commit is contained in:
@@ -3147,6 +3147,113 @@ type openaiStreamingResultPassthrough struct {
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
|
||||
if localStarted {
|
||||
return true
|
||||
}
|
||||
return c != nil && c.Writer != nil && c.Writer.Written()
|
||||
}
|
||||
|
||||
// openAIStreamEventIsPreamble reports whether the SSE event type is one of the
// protocol preamble events ("response.created" / "response.in_progress") that
// precede any client-visible output.
func openAIStreamEventIsPreamble(eventType string) bool {
	t := strings.TrimSpace(eventType)
	return t == "response.created" || t == "response.in_progress"
}
|
||||
|
||||
func openAIStreamDataStartsClientOutput(data, eventType string) bool {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(eventType) == "response.failed" {
|
||||
return false
|
||||
}
|
||||
return !openAIStreamEventIsPreamble(eventType)
|
||||
}
|
||||
|
||||
func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool {
|
||||
code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String()))
|
||||
if code == "" {
|
||||
code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String()))
|
||||
}
|
||||
errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.type").String()))
|
||||
if errType == "" {
|
||||
errType = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.type").String()))
|
||||
}
|
||||
combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType))
|
||||
if combined == "" {
|
||||
return true
|
||||
}
|
||||
nonRetryableMarkers := []string{
|
||||
"invalid_request",
|
||||
"content_policy",
|
||||
"policy",
|
||||
"safety",
|
||||
"high-risk cyber",
|
||||
"not allowed",
|
||||
"violat",
|
||||
}
|
||||
for _, marker := range nonRetryableMarkers {
|
||||
if strings.Contains(combined, marker) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// newOpenAIStreamFailoverError builds the UpstreamFailoverError returned when
// an OpenAI stream fails before any output reached the client, allowing the
// gateway to retry the request on another account.
//
// Side effects: when c is non-nil, the upstream error is recorded on the gin
// context (setOpsUpstreamError) and an ops event is appended
// (appendOpsUpstreamError) for observability.
//
// payload is the raw SSE data of the failed event (may be nil); it is only
// included in the recorded detail when Gateway.LogUpstreamErrorBody is
// enabled, truncated to LogUpstreamErrorBodyMaxBytes (default 2048 bytes).
func (s *OpenAIGatewayService) newOpenAIStreamFailoverError(
	c *gin.Context,
	account *Account,
	passthrough bool,
	upstreamRequestID string,
	payload []byte,
	message string,
) *UpstreamFailoverError {
	// Sanitize and default the message so the client never sees an empty error.
	message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
	if message == "" {
		message = "OpenAI stream disconnected before completion"
	}
	detail := ""
	if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if maxBytes <= 0 {
			// Fallback cap when the configured limit is unset or invalid.
			maxBytes = 2048
		}
		detail = truncateString(string(payload), maxBytes)
	}
	if c != nil {
		setOpsUpstreamError(c, http.StatusBadGateway, message, detail)
		event := OpsUpstreamErrorEvent{
			Platform:           PlatformOpenAI,
			UpstreamStatusCode: http.StatusBadGateway,
			UpstreamRequestID:  strings.TrimSpace(upstreamRequestID),
			Passthrough:        passthrough,
			Kind:               "failover",
			Message:            message,
			Detail:             detail,
		}
		if account != nil {
			// Prefer the concrete account's platform/identity when known.
			event.Platform = account.Platform
			event.AccountID = account.ID
			event.AccountName = account.Name
		}
		appendOpsUpstreamError(c, event)
	}
	// Best-effort marshal; gin.H with plain string values cannot fail to encode.
	body, _ := json.Marshal(gin.H{
		"error": gin.H{
			"type":    "upstream_error",
			"message": message,
		},
	})
	return &UpstreamFailoverError{
		StatusCode:   http.StatusBadGateway,
		ResponseBody: body,
	}
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
@@ -3178,7 +3285,22 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
sawTerminalEvent := false
|
||||
sawFailedEvent := false
|
||||
failedMessage := ""
|
||||
clientOutputStarted := false
|
||||
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||
pendingLines := make([]string, 0, 8)
|
||||
writePendingLines := func() bool {
|
||||
for _, pending := range pendingLines {
|
||||
if _, err := fmt.Fprintln(w, pending); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
pendingLines = pendingLines[:0]
|
||||
return true
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -3193,6 +3315,8 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
lineStartsClientOutput := false
|
||||
forceFlushFailedEvent := false
|
||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||
dataBytes := []byte(data)
|
||||
trimmedData := strings.TrimSpace(data)
|
||||
@@ -3203,13 +3327,24 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
trimmedData = strings.TrimSpace(replacedData)
|
||||
}
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String())
|
||||
if eventType == "response.failed" {
|
||||
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
|
||||
}
|
||||
forceFlushFailedEvent = true
|
||||
sawFailedEvent = true
|
||||
}
|
||||
if trimmedData == "[DONE]" {
|
||||
sawDone = true
|
||||
}
|
||||
if openAIStreamEventIsTerminal(trimmedData) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
|
||||
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
|
||||
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
@@ -3217,20 +3352,30 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
if !clientOutputStarted && !lineStartsClientOutput {
|
||||
pendingLines = append(pendingLines, line)
|
||||
continue
|
||||
}
|
||||
if !clientOutputStarted && len(pendingLines) > 0 {
|
||||
if !writePendingLines() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprintln(w, line); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||
} else {
|
||||
clientOutputStarted = true
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if sawTerminalEvent {
|
||||
if sawTerminalEvent && !sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
if sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||
@@ -3239,6 +3384,17 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||
msg := "OpenAI stream disconnected before completion"
|
||||
if errText := strings.TrimSpace(err.Error()); errText != "" {
|
||||
msg += ": " + errText
|
||||
}
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
|
||||
account.ID,
|
||||
@@ -3247,12 +3403,19 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if sawFailedEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
||||
logger.FromContext(ctx).With(
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("upstream_request_id", upstreamRequestID),
|
||||
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
|
||||
}
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
@@ -3854,6 +4017,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
errorEventSent := false
|
||||
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||
sawTerminalEvent := false
|
||||
sawFailedEvent := false
|
||||
failedMessage := ""
|
||||
clientOutputStarted := false
|
||||
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||
var streamFailoverErr error
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent || clientDisconnected {
|
||||
return
|
||||
@@ -3870,7 +4038,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
if err := flushBuffered(); err != nil {
|
||||
clientDisconnected = true
|
||||
return
|
||||
}
|
||||
clientOutputStarted = true
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
@@ -3878,43 +4048,72 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||
}
|
||||
finalizeStream := func() (*openaiStreamingResult, error) {
|
||||
if !sawTerminalEvent {
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||
return resultWithUsage(), s.newOpenAIStreamFailoverError(
|
||||
c,
|
||||
account,
|
||||
false,
|
||||
upstreamRequestID,
|
||||
nil,
|
||||
"OpenAI stream ended before a terminal event",
|
||||
)
|
||||
}
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
if sawFailedEvent {
|
||||
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||
}
|
||||
if !clientDisconnected {
|
||||
hadBufferedData := bufferedWriter.Buffered() > 0
|
||||
if err := flushBuffered(); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||
} else if hadBufferedData {
|
||||
clientOutputStarted = true
|
||||
}
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
|
||||
if scanErr == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sawTerminalEvent {
|
||||
if sawTerminalEvent && !sawFailedEvent {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
}
|
||||
if sawFailedEvent {
|
||||
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage), true
|
||||
}
|
||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
||||
}
|
||||
if errors.Is(scanErr, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
|
||||
sendErrorEvent("response_too_large")
|
||||
return resultWithUsage(), scanErr, true
|
||||
}
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||
msg := "OpenAI stream disconnected before completion"
|
||||
if errText := strings.TrimSpace(scanErr.Error()); errText != "" {
|
||||
msg += ": " + errText
|
||||
}
|
||||
return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, nil, msg), true
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
|
||||
}
|
||||
processSSELine := func(line string, queueDrained bool) {
|
||||
if streamFailoverErr != nil {
|
||||
return
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
@@ -3930,18 +4129,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if openAIStreamEventIsTerminal(data) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
|
||||
forceFlushFailedEvent := false
|
||||
if eventType == "response.failed" {
|
||||
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
|
||||
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
|
||||
sawFailedEvent = true
|
||||
streamFailoverErr = s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, dataBytes, failedMessage)
|
||||
return
|
||||
}
|
||||
forceFlushFailedEvent = true
|
||||
sawFailedEvent = true
|
||||
}
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
||||
dataBytes = correctedData
|
||||
data = string(correctedData)
|
||||
line = "data: " + data
|
||||
eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
|
||||
}
|
||||
startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType)
|
||||
|
||||
// 写入客户端(客户端断开后继续 drain 上游)
|
||||
if !clientDisconnected {
|
||||
shouldFlush := queueDrained
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
shouldFlush := queueDrained && (clientOutputStarted || startsClientOutput)
|
||||
if firstTokenMs == nil && startsClientOutput {
|
||||
// 保证首个 token 事件尽快出站,避免影响 TTFT。
|
||||
shouldFlush = true
|
||||
}
|
||||
@@ -3955,12 +4168,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if err := flushBuffered(); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||
} else {
|
||||
clientOutputStarted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
if firstTokenMs == nil && startsClientOutput {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
@@ -3976,10 +4191,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
||||
} else if queueDrained {
|
||||
} else if queueDrained && clientOutputStarted {
|
||||
if err := flushBuffered(); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||
} else {
|
||||
clientOutputStarted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3990,6 +4207,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
for scanner.Scan() {
|
||||
processSSELine(scanner.Text(), true)
|
||||
if streamFailoverErr != nil {
|
||||
return resultWithUsage(), streamFailoverErr
|
||||
}
|
||||
}
|
||||
if result, err, done := handleScanErr(scanner.Err()); done {
|
||||
return result, err
|
||||
@@ -4039,6 +4259,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
return result, err
|
||||
}
|
||||
processSSELine(ev.line, len(events) == 0)
|
||||
if streamFailoverErr != nil {
|
||||
return resultWithUsage(), streamFailoverErr
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
|
||||
@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
|
||||
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
||||
func (c cancelReadCloser) Close() error { return nil }
|
||||
|
||||
// errReadCloser is a test io.ReadCloser whose Read always fails with the
// configured error, simulating an upstream body that breaks before any data
// is delivered.
type errReadCloser struct {
	err error
}

// Read always returns the configured error without yielding any bytes.
func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }

// Close is a no-op.
func (r errReadCloser) Close() error { return nil }
|
||||
|
||||
type failingGinWriter struct {
|
||||
gin.ResponseWriter
|
||||
failAfter int
|
||||
@@ -1003,6 +1010,150 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover verifies that when
// the upstream body fails before any bytes reach the client, the handler
// returns an UpstreamFailoverError (HTTP 502) and writes nothing downstream,
// so the gateway can retry the request on another account.
func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// Upstream body that errors immediately without delivering any data.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       errReadCloser{err: io.ErrUnexpectedEOF},
		Header:     http.Header{"X-Request-Id": []string{"rid-disconnect"}},
	}

	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	// Nothing must have been written to the client before the failover.
	require.False(t, c.Writer.Written())
	require.Empty(t, rec.Body.String())
}
|
||||
|
||||
// TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover verifies that a
// generic (retryable) "response.failed" event arriving before any client
// output triggers a failover: the preamble events are withheld, nothing is
// written downstream, and the failed event's message is surfaced in the
// failover response body.
func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// Preamble events followed by a failure, with no client-visible output.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			"event: response.created",
			`data: {"type":"response.created","response":{"id":"resp_1"}}`,
			"",
			"event: response.in_progress",
			`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
			"",
			"event: response.failed",
			`data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
			"",
		}, "\n"))),
		Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
	}

	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
	// The buffered preamble must never have been flushed to the client.
	require.False(t, c.Writer.Written())
	require.Empty(t, rec.Body.String())
}
|
||||
|
||||
// TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover verifies that a
// stream containing only preamble events (created / in_progress) and no
// terminal event is treated as a failed upstream: the handler returns a
// failover error and the buffered preamble is never written to the client.
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// Preamble only; the stream ends without response.completed/failed/[DONE].
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			"event: response.created",
			`data: {"type":"response.created","response":{"id":"resp_1"}}`,
			"",
			"event: response.in_progress",
			`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
			"",
		}, "\n"))),
		Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
	}

	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.False(t, c.Writer.Written())
	require.Empty(t, rec.Body.String())
}
|
||||
|
||||
// TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough verifies
// that a non-retryable (safety/policy) "response.failed" event is NOT turned
// into a failover — another account would reject the request the same way —
// and is instead passed through to the client along with the preamble.
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			StreamDataIntervalTimeout: 0,
			StreamKeepaliveInterval:   0,
			MaxLineSize:               defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// A safety rejection: the error type/message match non-retryable markers.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			"event: response.created",
			`data: {"type":"response.created","response":{"id":"resp_1"}}`,
			"",
			"event: response.failed",
			`data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
			"",
		}, "\n"))),
		Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
	}

	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
	require.Error(t, err)
	// Must be a plain error, not a failover request.
	var failoverErr *UpstreamFailoverError
	require.False(t, errors.As(err, &failoverErr))
	// The failed event (and buffered preamble) must reach the client verbatim.
	require.True(t, c.Writer.Written())
	require.Contains(t, rec.Body.String(), "response.failed")
	require.Contains(t, rec.Body.String(), "high-risk cyber activity")
}
|
||||
|
||||
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1072,7 +1223,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
@@ -1104,7 +1255,7 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
|
||||
@@ -1114,6 +1265,42 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover is
// the passthrough-path analogue of the response.failed failover test: a
// retryable failed event arriving before any client output must yield an
// UpstreamFailoverError (502) carrying the upstream message, with nothing
// written downstream.
func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Gateway: config.GatewayConfig{
			MaxLineSize: defaultMaxLineSize,
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)

	// Preamble followed by a retryable failure; no client-visible output.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(strings.Join([]string{
			"event: response.created",
			`data: {"type":"response.created","response":{"id":"resp_1"}}`,
			"",
			"event: response.failed",
			`data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
			"",
		}, "\n"))),
		Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
	}

	_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
	require.Error(t, err)
	var failoverErr *UpstreamFailoverError
	require.ErrorAs(t, err, &failoverErr)
	require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
	require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
	require.False(t, c.Writer.Written())
	require.Empty(t, rec.Body.String())
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
|
||||
Reference in New Issue
Block a user