✨ feat(adaptor): refactor response handlers to return usage first and improve error handling
This commit is contained in:
@@ -74,12 +74,12 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = cohereRerankHandler(c, resp, info)
|
usage, err = cohereRerankHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = cohereStreamHandler(c, info, resp)
|
usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
|
||||||
} else {
|
} else {
|
||||||
err, usage = cohereHandler(c, info, resp)
|
usage, err = cohereHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
responseId := helper.GetResponseID(c)
|
responseId := helper.GetResponseID(c)
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
@@ -166,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
if usage.PromptTokens == 0 {
|
if usage.PromptTokens == 0 {
|
||||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
}
|
}
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
var cohereResp CohereResponseResult
|
var cohereResp CohereResponseResult
|
||||||
err = json.Unmarshal(responseBody, &cohereResp)
|
err = json.Unmarshal(responseBody, &cohereResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
usage := dto.Usage{}
|
usage := dto.Usage{}
|
||||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||||
@@ -203,24 +203,24 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
|
|
||||||
jsonResponse, err := json.Marshal(openaiResp)
|
jsonResponse, err := json.Marshal(openaiResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
_, _ = c.Writer.Write(jsonResponse)
|
||||||
return nil, &usage
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
var cohereResp CohereRerankResponseResult
|
var cohereResp CohereRerankResponseResult
|
||||||
err = json.Unmarshal(responseBody, &cohereResp)
|
err = json.Unmarshal(responseBody, &cohereResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
usage := dto.Usage{}
|
usage := dto.Usage{}
|
||||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
||||||
@@ -239,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
|
|
||||||
jsonResponse, err := json.Marshal(rerankResp)
|
jsonResponse, err := json.Marshal(rerankResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
return nil, &usage
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,9 +98,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
// DoResponse implements channel.Adaptor.
|
// DoResponse implements channel.Adaptor.
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = cozeChatStreamHandler(c, info, resp)
|
usage, err = cozeChatStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = cozeChatHandler(c, info, resp)
|
usage, err = cozeChatHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
|
|||||||
return cozeRequest
|
return cozeRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
// convert coze response to openai response
|
// convert coze response to openai response
|
||||||
@@ -56,10 +56,10 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
|||||||
response.Model = info.UpstreamModelName
|
response.Model = info.UpstreamModelName
|
||||||
err = json.Unmarshal(responseBody, &cozeResponse)
|
err = json.Unmarshal(responseBody, &cozeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if cozeResponse.Code != 0 {
|
if cozeResponse.Code != 0 {
|
||||||
return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
// 从上下文获取 usage
|
// 从上下文获取 usage
|
||||||
var usage dto.Usage
|
var usage dto.Usage
|
||||||
@@ -86,16 +86,16 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
|||||||
}
|
}
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, _ = c.Writer.Write(jsonResponse)
|
_, _ = c.Writer.Write(jsonResponse)
|
||||||
|
|
||||||
return nil, &usage
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
@@ -136,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
@@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
|
|||||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
|||||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if ollamaEmbeddingResponse.Error != "" {
|
if ollamaEmbeddingResponse.Error != "" {
|
||||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user