diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index c63c89c1..4f3a96c3 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -74,12 +74,12 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeRerank { - err, usage = cohereRerankHandler(c, resp, info) + usage, err = cohereRerankHandler(c, resp, info) } else { if info.IsStream { - err, usage = cohereStreamHandler(c, info, resp) + usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this } else { - err, usage = cohereHandler(c, info, resp) + usage, err = cohereHandler(c, info, resp) } } return diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 7a727e20..fcfb12b7 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -78,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string { } } -func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { +func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseId := helper.GetResponseID(c) createdTime := common.GetTimestamp() usage := &dto.Usage{} @@ -166,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http if usage.PromptTokens == 0 { usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } - return nil, usage + return usage, nil } -func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { +func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { createdTime := common.GetTimestamp() responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) var cohereResp CohereResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } usage := dto.Usage{} usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens @@ -203,24 +203,24 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo jsonResponse, err := json.Marshal(openaiResp) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &usage + _, _ = c.Writer.Write(jsonResponse) + return &usage, nil } -func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { +func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) var cohereResp CohereRerankResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } usage := dto.Usage{} if cohereResp.Meta.BilledUnits.InputTokens == 0 { @@ -239,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon. jsonResponse, err := json.Marshal(rerankResp) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) _, err = c.Writer.Write(jsonResponse) - return nil, &usage + return &usage, nil } diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index cbad91ba..fe5f5f00 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -98,9 +98,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody // DoResponse implements channel.Adaptor. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = cozeChatStreamHandler(c, info, resp) + usage, err = cozeChatStreamHandler(c, info, resp) } else { - err, usage = cozeChatHandler(c, info, resp) + usage, err = cozeChatHandler(c, info, resp) } return } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 42f8503e..32cc6937 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -44,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C return cozeRequest } -func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { +func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } common.CloseResponseBodyGracefully(resp) // convert coze response to openai response @@ -56,10 +56,10 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res response.Model = info.UpstreamModelName err = json.Unmarshal(responseBody, &cozeResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if cozeResponse.Code != 0 { - return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody) } // 从上下文获取 usage var usage dto.Usage @@ -86,16 +86,16 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res } jsonResponse, err := json.Marshal(response) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) _, _ = c.Writer.Write(jsonResponse) - return nil, &usage + return &usage, nil } -func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { +func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) helper.SetEventStreamHeaders(c) @@ -136,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht } if err := scanner.Err(); err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } helper.Done(c) @@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count")) } - return nil, usage + return usage, nil } func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 974fe1bd..295349e3 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -96,7 +96,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if ollamaEmbeddingResponse.Error != "" { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody) } flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)