This update standardizes how HTTP response bodies are closed across the stream handlers by routing every close through a single helper, `common.CloseResponseBodyGracefully`. Centralizing the close logic improves resource cleanup, and any error raised while closing a body is handled inside the helper instead of propagating and disrupting a request that has already been served.
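The helper's implementation is not part of this file. As a rough illustration only, a minimal sketch of what such a method could look like, assuming it lives in the `common` package alongside the `SysError` logger used elsewhere in this diff:

package common

import (
	"net/http"
)

// CloseResponseBodyGracefully closes resp.Body and logs (rather than returns)
// any close error, so callers never need to branch on cleanup failures.
// Sketch only: the actual helper in this repository may differ.
func CloseResponseBodyGracefully(resp *http.Response) {
	if resp == nil || resp.Body == nil {
		return
	}
	if err := resp.Body.Close(); err != nil {
		SysError("failed to close response body: " + err.Error())
	}
}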
163 lines · 5.1 KiB · Go
package palm

import (
	"encoding/json"
	"github.com/gin-gonic/gin"
	"io"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	"one-api/relay/helper"
	"one-api/service"
)

// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body

// requestOpenAI2PaLM converts an OpenAI-style chat request into a PaLM chat request.
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
	palmRequest := PaLMChatRequest{
		Prompt: PaLMPrompt{
			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
		},
		Temperature:    textRequest.Temperature,
		CandidateCount: textRequest.N,
		TopP:           textRequest.TopP,
		TopK:           textRequest.MaxTokens,
	}
	for _, message := range textRequest.Messages {
		palmMessage := PaLMChatMessage{
			Content: message.StringContent(),
		}
		// PaLM identifies speakers by author id: "0" for the user, "1" for the assistant.
		if message.Role == "user" {
			palmMessage.Author = "0"
		} else {
			palmMessage.Author = "1"
		}
		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
	}
	return &palmRequest
}

// responsePaLM2OpenAI converts a PaLM chat response into an OpenAI-style text response,
// mapping each candidate to a choice.
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
	fullTextResponse := dto.OpenAITextResponse{
		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
	}
	for i, candidate := range response.Candidates {
		choice := dto.OpenAITextResponseChoice{
			Index: i,
			Message: dto.Message{
				Role:    "assistant",
				Content: candidate.Content,
			},
			FinishReason: "stop",
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}

// streamResponsePaLM2OpenAI wraps the first PaLM candidate in an OpenAI-style
// chat.completion.chunk for streaming clients.
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
	var choice dto.ChatCompletionsStreamResponseChoice
	if len(palmResponse.Candidates) > 0 {
		choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
	}
	choice.FinishReason = &constant.FinishReasonStop
	var response dto.ChatCompletionsStreamResponse
	response.Object = "chat.completion.chunk"
	response.Model = "palm2"
	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
	return &response
}

// palmStreamHandler reads the full PaLM response in a goroutine and replays it to
// the client as a single SSE data event followed by a [DONE] event. It returns the
// response text for downstream token accounting.
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
	responseText := ""
	responseId := helper.GetResponseID(c)
	createdTime := common.GetTimestamp()
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		responseBody, err := io.ReadAll(resp.Body)
		if err != nil {
			common.SysError("error reading stream response: " + err.Error())
			stopChan <- true
			return
		}
		common.CloseResponseBodyGracefully(resp)
		var palmResponse PaLMChatResponse
		err = json.Unmarshal(responseBody, &palmResponse)
		if err != nil {
			common.SysError("error unmarshalling stream response: " + err.Error())
			stopChan <- true
			return
		}
		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
		fullTextResponse.Id = responseId
		fullTextResponse.Created = createdTime
		if len(palmResponse.Candidates) > 0 {
			responseText = palmResponse.Candidates[0].Content
		}
		jsonResponse, err := json.Marshal(fullTextResponse)
		if err != nil {
			common.SysError("error marshalling stream response: " + err.Error())
			stopChan <- true
			return
		}
		dataChan <- string(jsonResponse)
		stopChan <- true
	}()
	helper.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			c.Render(-1, common.CustomEvent{Data: "data: " + data})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	common.CloseResponseBodyGracefully(resp)
	return nil, responseText
}

// palmHandler handles a non-streaming PaLM response: it decodes the upstream body,
// surfaces upstream errors in OpenAI format, and writes the converted response
// together with the computed token usage.
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	common.CloseResponseBodyGracefully(resp)
	var palmResponse PaLMChatResponse
	err = json.Unmarshal(responseBody, &palmResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
		return &dto.OpenAIErrorWithStatusCode{
			Error: dto.OpenAIError{
				Message: palmResponse.Error.Message,
				Type:    palmResponse.Error.Status,
				Param:   "",
				Code:    palmResponse.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
	completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
	usage := dto.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &usage
}