feat(adaptor): refactor response handlers to return usage first and improve error handling

This commit is contained in:
CaIon
2025-07-12 21:12:46 +08:00
parent 20607b0b5c
commit 52a5e58f0c
5 changed files with 28 additions and 28 deletions

View File

@@ -74,12 +74,12 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
usage, err = cohereRerankHandler(c, resp, info)
} else {
if info.IsStream {
err, usage = cohereStreamHandler(c, info, resp)
usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
} else {
err, usage = cohereHandler(c, info, resp)
usage, err = cohereHandler(c, info, resp)
}
}
return

View File

@@ -78,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
}
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
@@ -166,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
if usage.PromptTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
return usage, nil
}
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
createdTime := common.GetTimestamp()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := dto.Usage{}
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
@@ -203,24 +203,24 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
jsonResponse, err := json.Marshal(openaiResp)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
_, _ = c.Writer.Write(jsonResponse)
return &usage, nil
}
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
@@ -239,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
return &usage, nil
}

View File

@@ -98,9 +98,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
// DoResponse implements channel.Adaptor.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = cozeChatStreamHandler(c, info, resp)
usage, err = cozeChatStreamHandler(c, info, resp)
} else {
err, usage = cozeChatHandler(c, info, resp)
usage, err = cozeChatHandler(c, info, resp)
}
return
}

View File

@@ -44,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
return cozeRequest
}
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
// convert coze response to openai response
@@ -56,10 +56,10 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
response.Model = info.UpstreamModelName
err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if cozeResponse.Code != 0 {
return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
}
// 从上下文获取 usage
var usage dto.Usage
@@ -86,16 +86,16 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
}
jsonResponse, err := json.Marshal(response)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
return &usage, nil
}
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
@@ -136,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
}
if err := scanner.Err(); err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
helper.Done(c)
@@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
}
return nil, usage
return usage, nil
}
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {

View File

@@ -96,7 +96,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if ollamaEmbeddingResponse.Error != "" {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
}
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)