From 1236fa8fe42f8fc736c4450836bc5b93b05e0f85 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Sat, 3 May 2025 19:19:19 +0800 Subject: [PATCH] add OaiResponsesStreamHandler --- dto/openai_response.go | 8 +++- relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/relay-openai.go | 62 ++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/dto/openai_response.go b/dto/openai_response.go index 2f858d26..02befd79 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -214,7 +214,7 @@ type OpenAIResponsesResponse struct { Tools []interface{} `json:"tools"` TopP float64 `json:"top_p"` Truncation string `json:"truncation"` - Usage Usage `json:"usage"` + Usage *Usage `json:"usage"` User json.RawMessage `json:"user"` Metadata json.RawMessage `json:"metadata"` } @@ -236,3 +236,9 @@ type ResponsesOutputContent struct { Text string `json:"text"` Annotations []interface{} `json:"annotations"` } + +// ResponsesStreamResponse 用于处理 /v1/responses 流式响应 +type ResponsesStreamResponse struct { + Type string `json:"type"` + Response *OpenAIResponsesResponse `json:"response"` +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index dc5098c4..7740c498 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -427,7 +427,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = common_handler.RerankHandler(c, info, resp) case constant.RelayModeResponses: if info.IsStream { - err, usage = OaiStreamHandler(c, resp, info) + err, usage = OaiResponsesStreamHandler(c, resp, info) } else { err, usage = OpenaiResponsesHandler(c, resp, info) } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 269a76f7..bfeed2cf 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -694,3 +694,65 @@ func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycomm usage.TotalTokens = responsesResponse.Usage.TotalTokens return nil, &usage } + +func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + if resp == nil || resp.Body == nil { + common.LogError(c, "invalid response or response body") + return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil + } + + var usage = &dto.Usage{} + var streamItems []string // 存储流式数据项 + // var responseTextBuilder strings.Builder + // var toolCount int + var forceFormat bool + + if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { + forceFormat = forceFmt + } + + var lastStreamData string + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + if lastStreamData != "" { + // 处理上一条数据 + sendResponsesStreamData(c, lastStreamData, forceFormat) + } + lastStreamData = data + streamItems = append(streamItems, data) + + // 检查当前数据是否包含 completed 状态和 usage 信息 + var streamResponse dto.ResponsesStreamResponse + if err := common.DecodeJsonStr(data, &streamResponse); err == nil { + if streamResponse.Type == "response.completed" { + // 处理 completed 状态 + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + } + } + return true + }) + + // 处理最后一条数据 + sendResponsesStreamData(c, lastStreamData, forceFormat) + + // 处理token计算 + // if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { + // common.SysError("error processing tokens: " + err.Error()) + // } + + return nil, usage +} + +func sendResponsesStreamData(c *gin.Context, data string, forceFormat bool) error { + if data == "" { + return nil + } + + if forceFormat { + return helper.ObjectData(c, data) + } else { + return helper.StringData(c, data) + } +}