feat(adaptor): refactor response handlers to return usage first and improve error handling

This commit is contained in:
CaIon
2025-07-12 21:12:46 +08:00
parent 20607b0b5c
commit 52a5e58f0c
5 changed files with 28 additions and 28 deletions

View File

@@ -74,12 +74,12 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info) usage, err = cohereRerankHandler(c, resp, info)
} else { } else {
if info.IsStream { if info.IsStream {
err, usage = cohereStreamHandler(c, info, resp) usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
} else { } else {
err, usage = cohereHandler(c, info, resp) usage, err = cohereHandler(c, info, resp)
} }
} }
return return

View File

@@ -78,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
} }
} }
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseId := helper.GetResponseID(c) responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
usage := &dto.Usage{} usage := &dto.Usage{}
@@ -166,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
if usage.PromptTokens == 0 { if usage.PromptTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} }
return nil, usage return usage, nil
} }
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp) err = json.Unmarshal(responseBody, &cohereResp)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
usage := dto.Usage{} usage := dto.Usage{}
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
@@ -203,24 +203,24 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
jsonResponse, err := json.Marshal(openaiResp) jsonResponse, err := json.Marshal(openaiResp)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse) _, _ = c.Writer.Write(jsonResponse)
return nil, &usage return &usage, nil
} }
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp) err = json.Unmarshal(responseBody, &cohereResp)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
usage := dto.Usage{} usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 { if cohereResp.Meta.BilledUnits.InputTokens == 0 {
@@ -239,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
jsonResponse, err := json.Marshal(rerankResp) jsonResponse, err := json.Marshal(rerankResp)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
return nil, &usage return &usage, nil
} }

View File

@@ -98,9 +98,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
// DoResponse implements channel.Adaptor. // DoResponse implements channel.Adaptor.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream { if info.IsStream {
err, usage = cozeChatStreamHandler(c, info, resp) usage, err = cozeChatStreamHandler(c, info, resp)
} else { } else {
err, usage = cozeChatHandler(c, info, resp) usage, err = cozeChatHandler(c, info, resp)
} }
return return
} }

View File

@@ -44,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
return cozeRequest return cozeRequest
} }
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
// convert coze response to openai response // convert coze response to openai response
@@ -56,10 +56,10 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
response.Model = info.UpstreamModelName response.Model = info.UpstreamModelName
err = json.Unmarshal(responseBody, &cozeResponse) err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
if cozeResponse.Code != 0 { if cozeResponse.Code != 0 {
return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
} }
// 从上下文获取 usage // 从上下文获取 usage
var usage dto.Usage var usage dto.Usage
@@ -86,16 +86,16 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
} }
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse) _, _ = c.Writer.Write(jsonResponse)
return nil, &usage return &usage, nil
} }
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) { func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
@@ -136,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
helper.Done(c) helper.Done(c)
@@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count")) usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
} }
return nil, usage return usage, nil
} }
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {

View File

@@ -96,7 +96,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
if ollamaEmbeddingResponse.Error != "" { if ollamaEmbeddingResponse.Error != "" {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
} }
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding) flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)