From 18b3300ff1d431a46a7cadf9a48db5c3c6ac0519 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 5 May 2025 00:40:16 +0800 Subject: [PATCH] feat: implement OpenAI responses handling and streaming support with built-in tool tracking --- dto/openai_response.go | 11 ++- relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/helper.go | 7 ++ relay/channel/openai/relay-openai.go | 99 -------------------- relay/channel/openai/relay_responses.go | 114 ++++++++++++++++++++++++ relay/channel/vertex/adaptor.go | 2 +- relay/common/relay_info.go | 37 ++++++++ relay/relay-responses.go | 10 +-- 8 files changed, 173 insertions(+), 109 deletions(-) create mode 100644 relay/channel/openai/relay_responses.go diff --git a/dto/openai_response.go b/dto/openai_response.go index 1508d1f6..c8f61b9d 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -238,8 +238,12 @@ type ResponsesOutputContent struct { } const ( - BuildInTools_WebSearch = "web_search_preview" - BuildInTools_FileSearch = "file_search" + BuildInToolWebSearchPreview = "web_search_preview" + BuildInToolFileSearch = "file_search" +) + +const ( + BuildInCallWebSearchCall = "web_search_call" ) const ( @@ -250,6 +254,7 @@ const ( // ResponsesStreamResponse 用于处理 /v1/responses 流式响应 type ResponsesStreamResponse struct { Type string `json:"type"` - Response *OpenAIResponsesResponse `json:"response"` + Response *OpenAIResponsesResponse `json:"response,omitempty"` Delta string `json:"delta,omitempty"` + Item *ResponsesOutput `json:"item,omitempty"` } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 7740c498..eb12a22a 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -429,7 +429,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = OaiResponsesStreamHandler(c, resp, info) } else { - err, usage = OpenaiResponsesHandler(c, resp, info) + err, usage = OaiResponsesHandler(c, resp, info) } default: if info.IsStream { diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index e7ba2e7b..a068c544 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -187,3 +187,10 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } } } + +func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { + if data == "" { + return + } + helper.ResponseChunkData(c, streamResponse, data) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index ef660564..b9ed94e2 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -644,102 +644,3 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm } return nil, &usageResp.Usage } - -func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - // read response body - var responsesResponse dto.OpenAIResponsesResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = common.DecodeJson(responseBody, &responsesResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if responsesResponse.Error != nil { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: responsesResponse.Error.Message, - Type: "openai_error", - Code: responsesResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - - // reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - // copy response body - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - common.SysError("error copying response body: " + err.Error()) - } - resp.Body.Close() - // compute usage - usage := dto.Usage{} - usage.PromptTokens = responsesResponse.Usage.InputTokens - usage.CompletionTokens = responsesResponse.Usage.OutputTokens - usage.TotalTokens = responsesResponse.Usage.TotalTokens - return nil, &usage -} - -func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") - return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil - } - - var usage = &dto.Usage{} - var responseTextBuilder strings.Builder - - helper.StreamScannerHandler(c, resp, info, func(data string) bool { - - // 检查当前数据是否包含 completed 状态和 usage 信息 - var streamResponse dto.ResponsesStreamResponse - if err := common.DecodeJsonStr(data, &streamResponse); err == nil { - sendResponsesStreamData(c, streamResponse, data) - switch streamResponse.Type { - case "response.completed": - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens - case "response.output_text.delta": - // 处理输出文本 - responseTextBuilder.WriteString(streamResponse.Delta) - - } - } - return true - }) - - if usage.CompletionTokens == 0 { - // 计算输出文本的 token 数量 - tempStr := responseTextBuilder.String() - if len(tempStr) > 0 { - // 非正常结束,使用输出文本的 token 数量 - completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) - usage.CompletionTokens = completionTokens - } - } - - return nil, usage -} - -func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { - if data == "" { - return - } - helper.ResponseChunkData(c, streamResponse, data) -} diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go new file mode 100644 index 00000000..6af8c676 --- /dev/null +++ b/relay/channel/openai/relay_responses.go @@ -0,0 +1,114 @@ +package openai + +import ( + "bytes" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "strings" +) + +func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + // read response body + var responsesResponse dto.OpenAIResponsesResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = common.DecodeJson(responseBody, &responsesResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if responsesResponse.Error != nil { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: responsesResponse.Error.Message, + Type: "openai_error", + Code: responsesResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + // reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + // copy response body + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + common.SysError("error copying response body: " + err.Error()) + } + resp.Body.Close() + // compute usage + usage := dto.Usage{} + usage.PromptTokens = responsesResponse.Usage.InputTokens + usage.CompletionTokens = responsesResponse.Usage.OutputTokens + usage.TotalTokens = responsesResponse.Usage.TotalTokens + return nil, &usage +} + +func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + if resp == nil || resp.Body == nil { + common.LogError(c, "invalid response or response body") + return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil + } + + var usage = &dto.Usage{} + var responseTextBuilder strings.Builder + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + + // 检查当前数据是否包含 completed 状态和 usage 信息 + var streamResponse dto.ResponsesStreamResponse + if err := common.DecodeJsonStr(data, &streamResponse); err == nil { + sendResponsesStreamData(c, streamResponse, data) + switch streamResponse.Type { + case "response.completed": + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + case "response.output_text.delta": + // 处理输出文本 + responseTextBuilder.WriteString(streamResponse.Delta) + case dto.ResponsesOutputTypeItemDone: + // 函数调用处理 + if streamResponse.Item != nil { + switch streamResponse.Item.Type { + case dto.BuildInCallWebSearchCall: + info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++ + } + } + } + } + return true + }) + + if usage.CompletionTokens == 0 { + // 计算输出文本的 token 数量 + tempStr := responseTextBuilder.String() + if len(tempStr) > 0 { + // 非正常结束,使用输出文本的 token 数量 + completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) + usage.CompletionTokens = completionTokens + } + } + + return nil, usage +} diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index c1b64f11..7daf9a61 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -11,8 +11,8 @@ import ( "one-api/relay/channel/claude" "one-api/relay/channel/gemini" "one-api/relay/channel/openai" - "one-api/setting/model_setting" relaycommon "one-api/relay/common" + "one-api/setting/model_setting" "strings" "github.com/gin-gonic/gin" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 915474e1..99c6d12b 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -36,6 +36,7 @@ type ClaudeConvertInfo struct { const ( RelayFormatOpenAI = "openai" RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" ) type RerankerInfo struct { @@ -43,6 +44,16 @@ type RerankerInfo struct { ReturnDocuments bool } +type BuildInToolInfo struct { + ToolName string + CallCount int + SearchContextSize string +} + +type ResponsesUsageInfo struct { + BuiltInTools map[string]*BuildInToolInfo +} + type RelayInfo struct { ChannelType int ChannelId int @@ -90,6 +101,7 @@ type RelayInfo struct { ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo + *ResponsesUsageInfo } // 定义支持流式选项的通道类型 @@ -134,6 +146,31 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { return info } +func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { + info := GenRelayInfo(c) + info.RelayMode = relayconstant.RelayModeResponses + info.ResponsesUsageInfo = &ResponsesUsageInfo{ + BuiltInTools: make(map[string]*BuildInToolInfo), + } + if len(req.Tools) > 0 { + for _, tool := range req.Tools { + info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{ + ToolName: tool.Type, + CallCount: 0, + } + switch tool.Type { + case dto.BuildInToolWebSearchPreview: + if tool.SearchContextSize == "" { + tool.SearchContextSize = "medium" + } + info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize + } + } + } + info.IsStream = req.Stream + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") diff --git a/relay/relay-responses.go b/relay/relay-responses.go index cdb37ae7..fd3ddb5a 100644 --- a/relay/relay-responses.go +++ b/relay/relay-responses.go @@ -19,7 +19,7 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.OpenAIResponsesRequest, error) { +func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { request := &dto.OpenAIResponsesRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { @@ -31,13 +31,11 @@ func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.Relay if len(request.Input) == 0 { return nil, errors.New("input is required") } - relayInfo.IsStream = request.Stream return request, nil } func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { - sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) return sensitiveWords, err } @@ -49,12 +47,14 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo } func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) - req, err := getAndValidateResponsesRequest(c, relayInfo) + req, err := getAndValidateResponsesRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest) } + + relayInfo := relaycommon.GenRelayInfoResponses(c, req) + if setting.ShouldCheckPromptSensitive() { sensitiveWords, err := checkInputSensitive(req, relayInfo) if err != nil {