feat: enhance OaiResponsesStreamHandler to handle output text and improve response streaming

This commit is contained in:
CaIon
2025-05-04 17:09:37 +08:00
parent 1236fa8fe4
commit fe3232bf23
3 changed files with 40 additions and 32 deletions

View File

@@ -237,8 +237,19 @@ type ResponsesOutputContent struct {
Annotations []interface{} `json:"annotations"` Annotations []interface{} `json:"annotations"`
} }
const (
BuildInTools_WebSearch = "web_search_preview"
BuildInTools_FileSearch = "file_search"
)
const (
ResponsesOutputTypeItemAdded = "response.output_item.added"
ResponsesOutputTypeItemDone = "response.output_item.done"
)
// ResponsesStreamResponse 用于处理 /v1/responses 流式响应 // ResponsesStreamResponse 用于处理 /v1/responses 流式响应
type ResponsesStreamResponse struct { type ResponsesStreamResponse struct {
Type string `json:"type"` Type string `json:"type"`
Response *OpenAIResponsesResponse `json:"response"` Response *OpenAIResponsesResponse `json:"response"`
Delta string `json:"delta,omitempty"`
} }

View File

@@ -702,57 +702,46 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
} }
var usage = &dto.Usage{} var usage = &dto.Usage{}
var streamItems []string // 存储流式数据项 var responseTextBuilder strings.Builder
// var responseTextBuilder strings.Builder
// var toolCount int
var forceFormat bool
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
}
var lastStreamData string
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
// 处理上一条数据
sendResponsesStreamData(c, lastStreamData, forceFormat)
}
lastStreamData = data
streamItems = append(streamItems, data)
// 检查当前数据是否包含 completed 状态和 usage 信息 // 检查当前数据是否包含 completed 状态和 usage 信息
var streamResponse dto.ResponsesStreamResponse var streamResponse dto.ResponsesStreamResponse
if err := common.DecodeJsonStr(data, &streamResponse); err == nil { if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
if streamResponse.Type == "response.completed" { sendResponsesStreamData(c, streamResponse, data)
// 处理 completed 状态 switch streamResponse.Type {
case "response.completed":
usage.PromptTokens = streamResponse.Response.Usage.InputTokens usage.PromptTokens = streamResponse.Response.Usage.InputTokens
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
case "response.output_text.delta":
// 处理输出文本
responseTextBuilder.WriteString(streamResponse.Delta)
} }
} }
return true return true
}) })
// 处理最后一条数据 helper.Done(c)
sendResponsesStreamData(c, lastStreamData, forceFormat)
// 处理token计算 if usage.CompletionTokens == 0 {
// if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { // 计算输出文本的 token 数量
// common.SysError("error processing tokens: " + err.Error()) tempStr := responseTextBuilder.String()
// } if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens
}
}
return nil, usage return nil, usage
} }
func sendResponsesStreamData(c *gin.Context, data string, forceFormat bool) error { func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
if data == "" { if data == "" {
return nil return
}
if forceFormat {
return helper.ObjectData(c, data)
} else {
return helper.StringData(c, data)
} }
helper.ResponseChunkData(c, streamResponse, data)
} }

View File

@@ -43,6 +43,14 @@ func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
} }
} }
func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
func StringData(c *gin.Context, str string) error { func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ") //str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r") //str = strings.TrimSuffix(str, "\r")