Merge pull request #1943 from AyeSt0/fix/openai-responses-preoutput-failover
fix(openai): 修复 Responses 流式失败前置事件导致无法 failover
This commit is contained in:
@@ -3147,6 +3147,113 @@ type openaiStreamingResultPassthrough struct {
|
|||||||
firstTokenMs *int
|
firstTokenMs *int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
|
||||||
|
if localStarted {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return c != nil && c.Writer != nil && c.Writer.Written()
|
||||||
|
}
|
||||||
|
|
||||||
|
// openAIStreamEventIsPreamble reports whether eventType (after trimming
// surrounding whitespace) is one of the events treated as stream preamble:
// "response.created" or "response.in_progress".
func openAIStreamEventIsPreamble(eventType string) bool {
	switch strings.TrimSpace(eventType) {
	case "response.created", "response.in_progress":
		return true
	}
	return false
}
|
||||||
|
|
||||||
|
func openAIStreamDataStartsClientOutput(data, eventType string) bool {
|
||||||
|
trimmed := strings.TrimSpace(data)
|
||||||
|
if trimmed == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(eventType) == "response.failed" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !openAIStreamEventIsPreamble(eventType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool {
|
||||||
|
code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String()))
|
||||||
|
if code == "" {
|
||||||
|
code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String()))
|
||||||
|
}
|
||||||
|
errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.type").String()))
|
||||||
|
if errType == "" {
|
||||||
|
errType = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.type").String()))
|
||||||
|
}
|
||||||
|
combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType))
|
||||||
|
if combined == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
nonRetryableMarkers := []string{
|
||||||
|
"invalid_request",
|
||||||
|
"content_policy",
|
||||||
|
"policy",
|
||||||
|
"safety",
|
||||||
|
"high-risk cyber",
|
||||||
|
"not allowed",
|
||||||
|
"violat",
|
||||||
|
}
|
||||||
|
for _, marker := range nonRetryableMarkers {
|
||||||
|
if strings.Contains(combined, marker) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) newOpenAIStreamFailoverError(
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
passthrough bool,
|
||||||
|
upstreamRequestID string,
|
||||||
|
payload []byte,
|
||||||
|
message string,
|
||||||
|
) *UpstreamFailoverError {
|
||||||
|
message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
|
||||||
|
if message == "" {
|
||||||
|
message = "OpenAI stream disconnected before completion"
|
||||||
|
}
|
||||||
|
detail := ""
|
||||||
|
if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
detail = truncateString(string(payload), maxBytes)
|
||||||
|
}
|
||||||
|
if c != nil {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, message, detail)
|
||||||
|
event := OpsUpstreamErrorEvent{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
UpstreamStatusCode: http.StatusBadGateway,
|
||||||
|
UpstreamRequestID: strings.TrimSpace(upstreamRequestID),
|
||||||
|
Passthrough: passthrough,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: message,
|
||||||
|
Detail: detail,
|
||||||
|
}
|
||||||
|
if account != nil {
|
||||||
|
event.Platform = account.Platform
|
||||||
|
event.AccountID = account.ID
|
||||||
|
event.AccountName = account.Name
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, event)
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
ResponseBody: body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
@@ -3178,7 +3285,22 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
|||||||
clientDisconnected := false
|
clientDisconnected := false
|
||||||
sawDone := false
|
sawDone := false
|
||||||
sawTerminalEvent := false
|
sawTerminalEvent := false
|
||||||
|
sawFailedEvent := false
|
||||||
|
failedMessage := ""
|
||||||
|
clientOutputStarted := false
|
||||||
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||||
|
pendingLines := make([]string, 0, 8)
|
||||||
|
writePendingLines := func() bool {
|
||||||
|
for _, pending := range pendingLines {
|
||||||
|
if _, err := fmt.Fprintln(w, pending); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pendingLines = pendingLines[:0]
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
@@ -3193,6 +3315,8 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
|||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
|
lineStartsClientOutput := false
|
||||||
|
forceFlushFailedEvent := false
|
||||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||||
dataBytes := []byte(data)
|
dataBytes := []byte(data)
|
||||||
trimmedData := strings.TrimSpace(data)
|
trimmedData := strings.TrimSpace(data)
|
||||||
@@ -3203,13 +3327,24 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
|||||||
trimmedData = strings.TrimSpace(replacedData)
|
trimmedData = strings.TrimSpace(replacedData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String())
|
||||||
|
if eventType == "response.failed" {
|
||||||
|
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
|
||||||
|
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
|
||||||
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||||
|
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
|
||||||
|
}
|
||||||
|
forceFlushFailedEvent = true
|
||||||
|
sawFailedEvent = true
|
||||||
|
}
|
||||||
if trimmedData == "[DONE]" {
|
if trimmedData == "[DONE]" {
|
||||||
sawDone = true
|
sawDone = true
|
||||||
}
|
}
|
||||||
if openAIStreamEventIsTerminal(trimmedData) {
|
if openAIStreamEventIsTerminal(trimmedData) {
|
||||||
sawTerminalEvent = true
|
sawTerminalEvent = true
|
||||||
}
|
}
|
||||||
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
|
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
|
||||||
|
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
|
||||||
ms := int(time.Since(startTime).Milliseconds())
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
firstTokenMs = &ms
|
firstTokenMs = &ms
|
||||||
}
|
}
|
||||||
@@ -3217,20 +3352,30 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !clientDisconnected {
|
if !clientDisconnected {
|
||||||
|
if !clientOutputStarted && !lineStartsClientOutput {
|
||||||
|
pendingLines = append(pendingLines, line)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !clientOutputStarted && len(pendingLines) > 0 {
|
||||||
|
if !writePendingLines() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, err := fmt.Fprintln(w, line); err != nil {
|
if _, err := fmt.Fprintln(w, line); err != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||||
} else {
|
} else {
|
||||||
|
clientOutputStarted = true
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
if sawTerminalEvent {
|
if sawTerminalEvent && !sawFailedEvent {
|
||||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
}
|
}
|
||||||
if clientDisconnected {
|
if sawFailedEvent {
|
||||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||||
}
|
}
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||||
@@ -3239,6 +3384,17 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
|||||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||||
}
|
}
|
||||||
|
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||||
|
msg := "OpenAI stream disconnected before completion"
|
||||||
|
if errText := strings.TrimSpace(err.Error()); errText != "" {
|
||||||
|
msg += ": " + errText
|
||||||
|
}
|
||||||
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||||
|
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
|
||||||
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||||
|
}
|
||||||
logger.LegacyPrintf("service.openai_gateway",
|
logger.LegacyPrintf("service.openai_gateway",
|
||||||
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
|
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
|
||||||
account.ID,
|
account.ID,
|
||||||
@@ -3247,12 +3403,19 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
|||||||
)
|
)
|
||||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||||
}
|
}
|
||||||
|
if sawFailedEvent {
|
||||||
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||||
|
}
|
||||||
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
||||||
logger.FromContext(ctx).With(
|
logger.FromContext(ctx).With(
|
||||||
zap.String("component", "service.openai_gateway"),
|
zap.String("component", "service.openai_gateway"),
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.String("upstream_request_id", upstreamRequestID),
|
zap.String("upstream_request_id", upstreamRequestID),
|
||||||
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||||
|
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||||
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
|
||||||
|
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
|
||||||
|
}
|
||||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3854,6 +4017,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||||
sawTerminalEvent := false
|
sawTerminalEvent := false
|
||||||
|
sawFailedEvent := false
|
||||||
|
failedMessage := ""
|
||||||
|
clientOutputStarted := false
|
||||||
|
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||||
|
var streamFailoverErr error
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason string) {
|
||||||
if errorEventSent || clientDisconnected {
|
if errorEventSent || clientDisconnected {
|
||||||
return
|
return
|
||||||
@@ -3870,7 +4038,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
}
|
}
|
||||||
if err := flushBuffered(); err != nil {
|
if err := flushBuffered(); err != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
clientOutputStarted = true
|
||||||
}
|
}
|
||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
@@ -3878,43 +4048,72 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
|
||||||
}
|
}
|
||||||
finalizeStream := func() (*openaiStreamingResult, error) {
|
finalizeStream := func() (*openaiStreamingResult, error) {
|
||||||
|
if !sawTerminalEvent {
|
||||||
|
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||||
|
return resultWithUsage(), s.newOpenAIStreamFailoverError(
|
||||||
|
c,
|
||||||
|
account,
|
||||||
|
false,
|
||||||
|
upstreamRequestID,
|
||||||
|
nil,
|
||||||
|
"OpenAI stream ended before a terminal event",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||||
|
}
|
||||||
|
if sawFailedEvent {
|
||||||
|
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
|
||||||
|
}
|
||||||
if !clientDisconnected {
|
if !clientDisconnected {
|
||||||
|
hadBufferedData := bufferedWriter.Buffered() > 0
|
||||||
if err := flushBuffered(); err != nil {
|
if err := flushBuffered(); err != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||||
|
} else if hadBufferedData {
|
||||||
|
clientOutputStarted = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !sawTerminalEvent {
|
|
||||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
|
||||||
}
|
|
||||||
return resultWithUsage(), nil
|
return resultWithUsage(), nil
|
||||||
}
|
}
|
||||||
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
|
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
|
||||||
if scanErr == nil {
|
if scanErr == nil {
|
||||||
return nil, nil, false
|
return nil, nil, false
|
||||||
}
|
}
|
||||||
if sawTerminalEvent {
|
if sawTerminalEvent && !sawFailedEvent {
|
||||||
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
|
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
|
||||||
return resultWithUsage(), nil, true
|
return resultWithUsage(), nil, true
|
||||||
}
|
}
|
||||||
|
if sawFailedEvent {
|
||||||
|
return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage), true
|
||||||
|
}
|
||||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||||
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
|
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
|
||||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
|
||||||
}
|
}
|
||||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
|
||||||
if clientDisconnected {
|
|
||||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
|
||||||
}
|
|
||||||
if errors.Is(scanErr, bufio.ErrTooLong) {
|
if errors.Is(scanErr, bufio.ErrTooLong) {
|
||||||
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
|
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
|
||||||
sendErrorEvent("response_too_large")
|
sendErrorEvent("response_too_large")
|
||||||
return resultWithUsage(), scanErr, true
|
return resultWithUsage(), scanErr, true
|
||||||
}
|
}
|
||||||
|
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
|
||||||
|
msg := "OpenAI stream disconnected before completion"
|
||||||
|
if errText := strings.TrimSpace(scanErr.Error()); errText != "" {
|
||||||
|
msg += ": " + errText
|
||||||
|
}
|
||||||
|
return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, nil, msg), true
|
||||||
|
}
|
||||||
|
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||||
|
if clientDisconnected {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
||||||
|
}
|
||||||
sendErrorEvent("stream_read_error")
|
sendErrorEvent("stream_read_error")
|
||||||
return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
|
return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
|
||||||
}
|
}
|
||||||
processSSELine := func(line string, queueDrained bool) {
|
processSSELine := func(line string, queueDrained bool) {
|
||||||
|
if streamFailoverErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
lastDataAt = time.Now()
|
lastDataAt = time.Now()
|
||||||
|
|
||||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||||
@@ -3930,18 +4129,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if openAIStreamEventIsTerminal(data) {
|
if openAIStreamEventIsTerminal(data) {
|
||||||
sawTerminalEvent = true
|
sawTerminalEvent = true
|
||||||
}
|
}
|
||||||
|
eventType := strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
|
||||||
|
forceFlushFailedEvent := false
|
||||||
|
if eventType == "response.failed" {
|
||||||
|
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
|
||||||
|
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
|
||||||
|
sawFailedEvent = true
|
||||||
|
streamFailoverErr = s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, dataBytes, failedMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
forceFlushFailedEvent = true
|
||||||
|
sawFailedEvent = true
|
||||||
|
}
|
||||||
|
|
||||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
||||||
dataBytes = correctedData
|
dataBytes = correctedData
|
||||||
data = string(correctedData)
|
data = string(correctedData)
|
||||||
line = "data: " + data
|
line = "data: " + data
|
||||||
|
eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
|
||||||
}
|
}
|
||||||
|
startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType)
|
||||||
|
|
||||||
// 写入客户端(客户端断开后继续 drain 上游)
|
// 写入客户端(客户端断开后继续 drain 上游)
|
||||||
if !clientDisconnected {
|
if !clientDisconnected {
|
||||||
shouldFlush := queueDrained
|
shouldFlush := queueDrained && (clientOutputStarted || startsClientOutput)
|
||||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
if firstTokenMs == nil && startsClientOutput {
|
||||||
// 保证首个 token 事件尽快出站,避免影响 TTFT。
|
// 保证首个 token 事件尽快出站,避免影响 TTFT。
|
||||||
shouldFlush = true
|
shouldFlush = true
|
||||||
}
|
}
|
||||||
@@ -3955,12 +4168,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if err := flushBuffered(); err != nil {
|
if err := flushBuffered(); err != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
clientOutputStarted = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record first token time
|
// Record first token time
|
||||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
if firstTokenMs == nil && startsClientOutput {
|
||||||
ms := int(time.Since(startTime).Milliseconds())
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
firstTokenMs = &ms
|
firstTokenMs = &ms
|
||||||
}
|
}
|
||||||
@@ -3976,10 +4191,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
|
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
} else if queueDrained {
|
} else if queueDrained && clientOutputStarted {
|
||||||
if err := flushBuffered(); err != nil {
|
if err := flushBuffered(); err != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
clientOutputStarted = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3990,6 +4207,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
defer putSSEScannerBuf64K(scanBuf)
|
defer putSSEScannerBuf64K(scanBuf)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
processSSELine(scanner.Text(), true)
|
processSSELine(scanner.Text(), true)
|
||||||
|
if streamFailoverErr != nil {
|
||||||
|
return resultWithUsage(), streamFailoverErr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if result, err, done := handleScanErr(scanner.Err()); done {
|
if result, err, done := handleScanErr(scanner.Err()); done {
|
||||||
return result, err
|
return result, err
|
||||||
@@ -4039,6 +4259,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
processSSELine(ev.line, len(events) == 0)
|
processSSELine(ev.line, len(events) == 0)
|
||||||
|
if streamFailoverErr != nil {
|
||||||
|
return resultWithUsage(), streamFailoverErr
|
||||||
|
}
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
|
|||||||
@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
|
|||||||
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
||||||
func (c cancelReadCloser) Close() error { return nil }
|
func (c cancelReadCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
// errReadCloser is a test double for a response body: every Read fails with
// the configured error and Close always succeeds.
type errReadCloser struct {
	err error
}

// Read ignores the destination buffer, reads nothing, and returns r.err.
func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }

// Close is a no-op.
func (r errReadCloser) Close() error { return nil }
|
||||||
|
|
||||||
type failingGinWriter struct {
|
type failingGinWriter struct {
|
||||||
gin.ResponseWriter
|
gin.ResponseWriter
|
||||||
failAfter int
|
failAfter int
|
||||||
@@ -1003,6 +1010,150 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: errReadCloser{err: io.ErrUnexpectedEOF},
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||||
|
require.Error(t, err)
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||||
|
require.False(t, c.Writer.Written())
|
||||||
|
require.Empty(t, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
"event: response.created",
|
||||||
|
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
|
||||||
|
"",
|
||||||
|
"event: response.in_progress",
|
||||||
|
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
|
||||||
|
"",
|
||||||
|
"event: response.failed",
|
||||||
|
`data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||||
|
require.Error(t, err)
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||||
|
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
|
||||||
|
require.False(t, c.Writer.Written())
|
||||||
|
require.Empty(t, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
"event: response.created",
|
||||||
|
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
|
||||||
|
"",
|
||||||
|
"event: response.in_progress",
|
||||||
|
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||||
|
require.Error(t, err)
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.False(t, c.Writer.Written())
|
||||||
|
require.Empty(t, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
"event: response.created",
|
||||||
|
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
|
||||||
|
"",
|
||||||
|
"event: response.failed",
|
||||||
|
`data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||||
|
require.Error(t, err)
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.False(t, errors.As(err, &failoverErr))
|
||||||
|
require.True(t, c.Writer.Written())
|
||||||
|
require.Contains(t, rec.Body.String(), "response.failed")
|
||||||
|
require.Contains(t, rec.Body.String(), "high-risk cyber activity")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
@@ -1072,7 +1223,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() { _ = pw.Close() }()
|
defer func() { _ = pw.Close() }()
|
||||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||||
@@ -1104,7 +1255,7 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() { _ = pw.Close() }()
|
defer func() { _ = pw.Close() }()
|
||||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
|
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
|
||||||
@@ -1114,6 +1265,42 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
"event: response.created",
|
||||||
|
`data: {"type":"response.created","response":{"id":"resp_1"}}`,
|
||||||
|
"",
|
||||||
|
"event: response.failed",
|
||||||
|
`data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
|
||||||
|
require.Error(t, err)
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
|
||||||
|
require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
|
||||||
|
require.False(t, c.Writer.Written())
|
||||||
|
require.Empty(t, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
|
|||||||
Reference in New Issue
Block a user