diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 8f6aab39..e13a7ad2 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" ) type Adaptor struct { @@ -35,7 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil + baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl) + switch info.RelayMode { + case relayconstant.RelayModeEmbeddings: + return fmt.Sprintf("%s/embeddings", baseUrl), nil + default: + return fmt.Sprintf("%s/chat/completions", baseUrl), nil + } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -60,8 +67,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index c55cee51..271dda8f 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -1,17 +1,9 @@ package zhipu_4v import ( - "bufio" - "bytes" - "encoding/json" - "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" - "io" - "net/http" "one-api/common" "one-api/dto" - "one-api/relay/helper" - "one-api/service" "strings" "sync" "time" @@ -119,163 +111,3 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq ToolChoice: request.ToolChoice, } } - -//func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse { -// fullTextResponse := dto.OpenAITextResponse{ -// Id: response.Id, -// Object: "chat.completion", -// Created: common.GetTimestamp(), -// Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)), -// Usage: response.Usage, -// } -// for i, choice := range response.TextResponseChoices { -// content, _ := json.Marshal(strings.Trim(choice.Content, "\"")) -// openaiChoice := dto.OpenAITextResponseChoice{ -// Index: i, -// Message: dto.Message{ -// Role: choice.Role, -// Content: content, -// }, -// FinishReason: "", -// } -// if i == len(response.TextResponseChoices)-1 { -// openaiChoice.FinishReason = "stop" -// } -// fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) -// } -// return &fullTextResponse -//} - -func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse { - var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content - choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role - choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls - choice.Index = zhipuResponse.Choices[0].Index - choice.FinishReason = zhipuResponse.Choices[0].FinishReason - response := dto.ChatCompletionsStreamResponse{ - Id: zhipuResponse.Id, - Object: "chat.completion.chunk", - Created: zhipuResponse.Created, - Model: "glm-4v", - Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, - } - return &response -} - -func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) { - response := streamResponseZhipu2OpenAI(zhipuResponse) - return response, &zhipuResponse.Usage -} - -func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - var usage *dto.Usage - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - if data[:6] != "data: " && data[:6] != "[DONE]" { - continue - } - dataChan <- data - } - stopChan <- true - }() - helper.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - - var streamResponse ZhipuV4StreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - } - var response *dto.ChatCompletionsStreamResponse - if strings.Contains(data, "prompt_tokens") { - response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse) - } else { - response = streamResponseZhipu2OpenAI(&streamResponse) - } - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - return false - } - }) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - return nil, usage -} - -func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - var textResponse ZhipuV4Response - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if textResponse.Error.Type != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: textResponse.Error, - StatusCode: resp.StatusCode, - }, nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the HTTPClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - - return nil, &textResponse.Usage -}