diff --git a/dto/openai_response.go b/dto/openai_response.go
index 02befd79..1508d1f6 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -237,8 +237,21 @@ type ResponsesOutputContent struct {
 	Annotations []interface{} `json:"annotations"`
 }
 
+// Type strings for built-in tools.
+const (
+	BuildInTools_WebSearch  = "web_search_preview"
+	BuildInTools_FileSearch = "file_search"
+)
+
+// Stream event type strings for output items.
+const (
+	ResponsesOutputTypeItemAdded = "response.output_item.added"
+	ResponsesOutputTypeItemDone  = "response.output_item.done"
+)
+
 // ResponsesStreamResponse 用于处理 /v1/responses 流式响应
 type ResponsesStreamResponse struct {
 	Type     string                   `json:"type"`
 	Response *OpenAIResponsesResponse `json:"response"`
+	Delta    string                   `json:"delta,omitempty"`
 }
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index bfeed2cf..f10ebc1b 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -702,57 +702,54 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 	}
 
 	var usage = &dto.Usage{}
-	var streamItems []string // 存储流式数据项
-	// var responseTextBuilder strings.Builder
-	// var toolCount int
-	var forceFormat bool
-
-	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
-		forceFormat = forceFmt
-	}
-
-	var lastStreamData string
+	var responseTextBuilder strings.Builder
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		if lastStreamData != "" {
-			// 处理上一条数据
-			sendResponsesStreamData(c, lastStreamData, forceFormat)
-		}
-		lastStreamData = data
-		streamItems = append(streamItems, data)
 
 		// 检查当前数据是否包含 completed 状态和 usage 信息
 		var streamResponse dto.ResponsesStreamResponse
 		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
-			if streamResponse.Type == "response.completed" {
-				// 处理 completed 状态
+			sendResponsesStreamData(c, streamResponse, data)
+			switch streamResponse.Type {
+			case "response.completed":
+				// Response is a pointer; skip usage extraction when it is nil
+				// rather than dereferencing it.
+				// NOTE(review): if Usage is a pointer too, it needs the same guard.
+				if streamResponse.Response == nil {
+					break
+				}
 				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
 				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
 				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+			case "response.output_text.delta":
+				// Accumulate the streamed output text.
+				responseTextBuilder.WriteString(streamResponse.Delta)
+
 			}
 		}
 		return true
 	})
 
-	// 处理最后一条数据
-	sendResponsesStreamData(c, lastStreamData, forceFormat)
+	helper.Done(c)
 
-	// 处理token计算
-	// if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
-	// 	common.SysError("error processing tokens: " + err.Error())
-	// }
+	if usage.CompletionTokens == 0 {
+		// No completion token count was reported; derive one from the accumulated output text.
+		tempStr := responseTextBuilder.String()
+		if len(tempStr) > 0 {
+			// The stream did not end normally, so count the tokens of the text received so far.
+			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			usage.CompletionTokens = completionTokens
+		}
+	}
 
 	return nil, usage
 }
 
-func sendResponsesStreamData(c *gin.Context, data string, forceFormat bool) error {
+// sendResponsesStreamData forwards one decoded stream chunk to the client via
+// helper.ResponseChunkData; empty data is skipped.
+func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
 	if data == "" {
-		return nil
-	}
-
-	if forceFormat {
-		return helper.ObjectData(c, data)
-	} else {
-		return helper.StringData(c, data)
+		return
 	}
+	helper.ResponseChunkData(c, streamResponse, data)
 }
diff --git a/relay/helper/common.go b/relay/helper/common.go
index ebfb6d58..43e8b92c 100644
--- a/relay/helper/common.go
+++ b/relay/helper/common.go
@@ -43,6 +43,17 @@ func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
 	}
 }
 
+// ResponseChunkData renders one Responses stream chunk as an "event: <type>"
+// line followed by a "data: <payload>" line, then flushes the writer when it
+// implements http.Flusher.
+func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
+	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
+	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
 func StringData(c *gin.Context, str string) error {
 	//str = strings.TrimPrefix(str, "data: ")
 	//str = strings.TrimSuffix(str, "\r")