diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 0ae8a8d1..5e31c753 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -120,15 +120,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom switch info.RelayFormat { case types.RelayFormatClaude: if info.IsStream { - err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { - err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) } default: adaptor := openai.Adaptor{} return adaptor.DoResponse(c, resp, info) } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index c5f6efcc..959327e1 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -102,9 +102,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) + return ClaudeStreamHandler(c, resp, info, a.RequestMode) } else { - err, usage = ClaudeHandler(c, resp, info, a.RequestMode) + return ClaudeHandler(c, resp, info, a.RequestMode) } return } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index ad363352..0c445bb9 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -674,7 +674,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau } } -func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { +func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) { claudeInfo := &ClaudeResponseInfo{ ResponseId: helper.GetResponseID(c), Created: common.GetTimestamp(), @@ -691,11 +691,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. return true }) if err != nil { - return err, nil + return nil, err } HandleStreamFinalResponse(c, info, claudeInfo, requestMode) - return nil, claudeInfo.Usage + return claudeInfo.Usage, nil } func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError { @@ -740,7 +740,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud return nil } -func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { +func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) { defer service.CloseResponseBodyGracefully(resp) claudeInfo := &ClaudeResponseInfo{ @@ -752,16 +752,16 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI } responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if common.DebugEnabled { println("responseBody: ", string(responseBody)) } handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode) if handleErr != nil { - return handleErr, nil + return nil, handleErr } - return nil, claudeInfo.Usage + return claudeInfo.Usage, nil } func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice { diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go index 29004d0c..e290c239 100644 --- a/relay/channel/moonshot/adaptor.go +++ b/relay/channel/moonshot/adaptor.go @@ -89,17 +89,16 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayFormat { - case types.RelayFormatOpenAI: - adaptor := openai.Adaptor{} - return adaptor.DoResponse(c, resp, info) case types.RelayFormatClaude: if info.IsStream { - err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { - err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) } + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 1359b4e9..0b6b2674 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -279,31 +279,31 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { switch a.RequestMode { case RequestModeClaude: - err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { - usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp) + return gemini.GeminiTextGenerationStreamHandler(c, info, resp) } else { - usage, err = gemini.GeminiChatStreamHandler(c, info, resp) + return gemini.GeminiChatStreamHandler(c, info, resp) } case RequestModeLlama: - usage, err = openai.OaiStreamHandler(c, info, resp) + return openai.OaiStreamHandler(c, info, resp) } } else { switch a.RequestMode { case RequestModeClaude: - err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { - usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) + return gemini.GeminiTextGenerationHandler(c, info, resp) } else { if strings.HasPrefix(info.UpstreamModelName, "imagen") { return gemini.GeminiImageHandler(c, info, resp) } - usage, err = gemini.GeminiChatHandler(c, info, resp) + return gemini.GeminiChatHandler(c, info, resp) } case RequestModeLlama: - usage, err = openai.OpenaiHandler(c, info, resp) + return openai.OpenaiHandler(c, info, resp) } } return diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 0fae3767..37c0c352 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/claude" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -23,10 +24,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + return req, nil } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -43,12 +42,16 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - baseUrl := fmt.Sprintf("%s/api/paas/v4", info.ChannelBaseUrl) - switch info.RelayMode { - case relayconstant.RelayModeEmbeddings: - return fmt.Sprintf("%s/embeddings", baseUrl), nil + switch info.RelayFormat { + case types.RelayFormatClaude: + return fmt.Sprintf("%s/api/anthropic/v1/messages", info.ChannelBaseUrl), nil default: - return fmt.Sprintf("%s/chat/completions", baseUrl), nil + switch info.RelayMode { + case relayconstant.RelayModeEmbeddings: + return fmt.Sprintf("%s/api/paas/v4/embeddings", info.ChannelBaseUrl), nil + default: + return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.ChannelBaseUrl), nil + } } } @@ -86,12 +89,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) + switch info.RelayFormat { + case types.RelayFormatClaude: + if info.IsStream { + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + } else { + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + } + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) } - return } func (a *Adaptor) GetModelList() []string {