🐛 fix: refactor OaiStreamHandler to improve last response handling and streamline response body closure

This commit is contained in:
CaIon
2025-06-27 22:44:20 +08:00
parent bfb6fbbac9
commit 9e6bc518cc

View File

@@ -111,12 +111,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
} }
containStreamUsage := false defer common.CloseResponseBodyGracefully(resp)
model := info.UpstreamModelName
var responseId string var responseId string
var createAt int64 = 0 var createAt int64 = 0
var systemFingerprint string var systemFingerprint string
model := info.UpstreamModelName var containStreamUsage bool
var responseTextBuilder strings.Builder var responseTextBuilder strings.Builder
var toolCount int var toolCount int
var usage = &dto.Usage{} var usage = &dto.Usage{}
@@ -148,31 +149,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return true return true
}) })
// 处理最后的响应
shouldSendLastResp := true shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse) &containStreamUsage, info, &shouldSendLastResp); err != nil {
if err == nil { common.SysError("error handling last response: " + err.Error())
responseId = lastStreamResponse.Id
createAt = lastStreamResponse.Created
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
containStreamUsage = true
usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
shouldSendLastResp = false
}
}
for _, choice := range lastStreamResponse.Choices {
if choice.FinishReason != nil {
shouldSendLastResp = true
}
}
} }
if shouldSendLastResp { if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) _ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
} }
// 处理token计算 // 处理token计算
@@ -490,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage = &dto.RealtimeUsage{} localUsage = &dto.RealtimeUsage{}
// print now usage // print now usage
} }
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session realtimeSession := realtimeEvent.Session