package palm

import (
	"encoding/json"
	"io"
	"net/http"

	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
	"one-api/types"

	"github.com/gin-gonic/gin"
)

// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body

// responsePaLM2OpenAI converts a PaLM generateMessage response into an OpenAI
// chat completion response with one assistant choice per PaLM candidate, in
// candidate order. Every choice is reported with finish reason "stop".
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
	fullTextResponse := dto.OpenAITextResponse{
		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
	}
	for i, candidate := range response.Candidates {
		choice := dto.OpenAITextResponseChoice{
			Index: i,
			Message: dto.Message{
				Role:    "assistant",
				Content: candidate.Content,
			},
			FinishReason: "stop",
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}

// streamResponsePaLM2OpenAI builds a single "chat.completion.chunk" carrying
// the content of the first PaLM candidate (if any) with finish reason "stop".
// The caller is expected to fill in Id and Created.
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
	var choice dto.ChatCompletionsStreamResponseChoice
	if len(palmResponse.Candidates) > 0 {
		choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
	}
	choice.FinishReason = &constant.FinishReasonStop
	var response dto.ChatCompletionsStreamResponse
	response.Object = "chat.completion.chunk"
	response.Model = "palm2"
	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
	return &response
}

// readPaLMResponse reads the entire upstream body, closes it, and decodes it
// into a PaLMChatResponse. Read and decode failures are returned as
// NewAPIErrors with HTTP 500 so both handlers report them identically.
func readPaLMResponse(resp *http.Response) (*PaLMChatResponse, *types.NewAPIError) {
	responseBody, err := io.ReadAll(resp.Body)
	// Close even when the read failed so the upstream connection is released.
	service.CloseResponseBodyGracefully(resp)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
	}
	var palmResponse PaLMChatResponse
	if err := json.Unmarshal(responseBody, &palmResponse); err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	return &palmResponse, nil
}

// palmStreamHandler relays a PaLM response to the client as a server-sent
// event stream. The upstream body is read in full first, then emitted as one
// chat.completion.chunk followed by "data: [DONE]".
//
// It returns the text of the first candidate (the only one sent to the
// client) so the caller can account for it. Read, decode, upstream and encode
// failures are returned as errors before any SSE output has been written, so
// the caller can still send a proper error response.
func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
	palmResponse, apiErr := readPaLMResponse(resp)
	if apiErr != nil {
		return apiErr, ""
	}
	if palmResponse.Error.Code != 0 {
		return types.WithOpenAIError(types.OpenAIError{
			Message: palmResponse.Error.Message,
			Type:    palmResponse.Error.Status,
			Param:   "",
			Code:    palmResponse.Error.Code,
		}, resp.StatusCode), ""
	}

	chunk := streamResponsePaLM2OpenAI(palmResponse)
	chunk.Id = helper.GetResponseID(c)
	chunk.Created = common.GetTimestamp()
	jsonResponse, err := json.Marshal(chunk)
	if err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody), ""
	}

	responseText := ""
	if len(palmResponse.Candidates) > 0 {
		responseText = palmResponse.Candidates[0].Content
	}

	// Written synchronously: no goroutine can be left blocked if the client
	// disconnects. Render+Flush is what gin's Context.Stream does per step.
	helper.SetEventStreamHeaders(c)
	c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
	c.Writer.Flush()
	c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
	c.Writer.Flush()
	return nil, responseText
}

// palmHandler relays a non-streaming PaLM response as an OpenAI chat
// completion. An upstream error code, or a response with no candidates, is
// returned as an OpenAI-style error carrying the upstream HTTP status.
//
// Usage reports info.PromptTokens as prompt tokens and the token count of
// every returned candidate as completion tokens.
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	palmResponse, apiErr := readPaLMResponse(resp)
	if apiErr != nil {
		return nil, apiErr
	}
	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
		message := palmResponse.Error.Message
		if message == "" && len(palmResponse.Candidates) == 0 {
			// Without this the client would receive an error with no explanation.
			message = "palm response contains no candidates"
		}
		return nil, types.WithOpenAIError(types.OpenAIError{
			Message: message,
			Type:    palmResponse.Error.Status,
			Param:   "",
			Code:    palmResponse.Error.Code,
		}, resp.StatusCode)
	}

	fullTextResponse := responsePaLM2OpenAI(palmResponse)

	// Every candidate is returned to the client as a choice, so every
	// candidate must be counted, not just the first.
	completionTokens := 0
	for _, candidate := range palmResponse.Candidates {
		completionTokens += service.CountTextToken(candidate.Content, info.UpstreamModelName)
	}
	usage := dto.Usage{
		PromptTokens:     info.PromptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      info.PromptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage

	jsonResponse, err := common.Marshal(fullTextResponse)
	if err != nil {
		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	service.IOCopyBytesGracefully(c, resp, jsonResponse)
	return &usage, nil
}