diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 05112f84..220f1f95 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -111,12 +111,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil } - containStreamUsage := false + defer common.CloseResponseBodyGracefully(resp) + + model := info.UpstreamModelName var responseId string var createAt int64 = 0 var systemFingerprint string - model := info.UpstreamModelName - + var containStreamUsage bool var responseTextBuilder strings.Builder var toolCount int var usage = &dto.Usage{} @@ -148,31 +149,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return true }) + // 处理最后的响应 shouldSendLastResp := true - var lastStreamResponse dto.ChatCompletionsStreamResponse - err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse) - if err == nil { - responseId = lastStreamResponse.Id - createAt = lastStreamResponse.Created - systemFingerprint = lastStreamResponse.GetSystemFingerprint() - model = lastStreamResponse.Model - if service.ValidUsage(lastStreamResponse.Usage) { - containStreamUsage = true - usage = lastStreamResponse.Usage - if !info.ShouldIncludeUsage { - shouldSendLastResp = false - } - } - for _, choice := range lastStreamResponse.Choices { - if choice.FinishReason != nil { - shouldSendLastResp = true - } - } + if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage, + &containStreamUsage, info, &shouldSendLastResp); err != nil { + common.SysError("error handling last response: " + err.Error()) } - if shouldSendLastResp { - sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) - //err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent) + if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI { + _ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) } // 处理token计算 @@ -490,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage = &dto.RealtimeUsage{} // print now usage } - //common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) - //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session