diff --git a/controller/relay.go b/controller/relay.go index 3660e8be..d4b5fd18 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -56,7 +56,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { userGroup := c.GetString("group") channelId := c.GetInt("channel_id") other := make(map[string]interface{}) - other["error_type"] = err.ErrorType + other["error_type"] = err.GetErrorType() other["error_code"] = err.GetErrorCode() other["status_code"] = err.StatusCode other["channel_id"] = channelId diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index b9e304fc..8fd1e1bf 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -17,10 +17,13 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + openaiAdaptor := openai.Adaptor{} + openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request) + if err != nil { + return nil, err + } + return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest)) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -37,6 +40,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + if info.RelayFormat == relaycommon.RelayFormatClaude { + return info.BaseUrl + "/v1/chat/completions", nil + } switch info.RelayMode { case relayconstant.RelayModeEmbeddings: return info.BaseUrl + "/api/embed", nil @@ -55,7 +61,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - return requestOpenAI2Ollama(*request) + return requestOpenAI2Ollama(request) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -85,6 +91,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom usage, err = openai.OpenaiHandler(c, info, resp) } } + switch info.RelayMode { + case relayconstant.RelayModeEmbeddings: + usage, err = ollamaEmbeddingHandler(c, info, resp) + default: + if info.IsStream { + usage, err = openai.OaiStreamHandler(c, info, resp) + } else { + usage, err = openai.OpenaiHandler(c, info, resp) + } + } return } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index cd899b83..f98dfc73 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -14,7 +14,7 @@ import ( "github.com/gin-gonic/gin" ) -func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) { +func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { if !message.IsStringContent() { @@ -92,15 +92,15 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h var ollamaEmbeddingResponse OllamaEmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if ollamaEmbeddingResponse.Error != "" { - return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) @@ -121,7 +121,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h } doResponseBody, err := common.Marshal(embeddingResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.IOCopyBytesGracefully(c, resp, doResponseBody) return usage, nil diff --git a/service/error.go b/service/error.go index a0713b55..83979add 100644 --- a/service/error.go +++ b/service/error.go @@ -80,10 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude } func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { - newApiErr = &types.NewAPIError{ - StatusCode: resp.StatusCode, - ErrorType: types.ErrorTypeOpenAIError, - } + newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -105,8 +102,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t // General format error (OpenAI, Anthropic, Gemini, etc.) newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode) } else { - newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode) - newApiErr.ErrorType = types.ErrorTypeOpenAIError + newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode) } return } diff --git a/types/error.go b/types/error.go index c301e59c..4ffae2d7 100644 --- a/types/error.go +++ b/types/error.go @@ -75,7 +75,7 @@ const ( type NewAPIError struct { Err error RelayError any - ErrorType ErrorType + errorType ErrorType errorCode ErrorCode StatusCode int } @@ -87,6 +87,13 @@ func (e *NewAPIError) GetErrorCode() ErrorCode { return e.errorCode } +func (e *NewAPIError) GetErrorType() ErrorType { + if e == nil { + return "" + } + return e.errorType +} + func (e *NewAPIError) Error() string { if e == nil { return "" @@ -103,7 +110,7 @@ func (e *NewAPIError) SetMessage(message string) { } func (e *NewAPIError) ToOpenAIError() OpenAIError { - switch e.ErrorType { + switch e.errorType { case ErrorTypeOpenAIError: if openAIError, ok := e.RelayError.(OpenAIError); ok { return openAIError @@ -120,14 +127,14 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError { } return OpenAIError{ Message: e.Error(), - Type: string(e.ErrorType), + Type: string(e.errorType), Param: "", Code: e.errorCode, } } func (e *NewAPIError) ToClaudeError() ClaudeError { - switch e.ErrorType { + switch e.errorType { case ErrorTypeOpenAIError: openAIError := e.RelayError.(OpenAIError) return ClaudeError{ @@ -139,7 +146,7 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { default: return ClaudeError{ Message: e.Error(), - Type: string(e.ErrorType), + Type: string(e.errorType), } } } @@ -148,7 +155,7 @@ func NewError(err error, errorCode ErrorCode) *NewAPIError { return &NewAPIError{ Err: err, RelayError: nil, - ErrorType: ErrorTypeNewAPIError, + errorType: ErrorTypeNewAPIError, StatusCode: http.StatusInternalServerError, errorCode: errorCode, } @@ -162,6 +169,13 @@ func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError return WithOpenAIError(openaiError, statusCode) } +func InitOpenAIError(errorCode ErrorCode, statusCode int) *NewAPIError { + openaiError := OpenAIError{ + Type: string(errorCode), + } + return WithOpenAIError(openaiError, statusCode) +} + func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError { return &NewAPIError{ Err: err, @@ -169,7 +183,7 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *New Message: err.Error(), Type: string(errorCode), }, - ErrorType: ErrorTypeNewAPIError, + errorType: ErrorTypeNewAPIError, StatusCode: statusCode, errorCode: errorCode, } @@ -182,7 +196,7 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { } return &NewAPIError{ RelayError: openAIError, - ErrorType: ErrorTypeOpenAIError, + errorType: ErrorTypeOpenAIError, StatusCode: statusCode, Err: errors.New(openAIError.Message), errorCode: ErrorCode(code), @@ -192,7 +206,7 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError { return &NewAPIError{ RelayError: claudeError, - ErrorType: ErrorTypeClaudeError, + errorType: ErrorTypeClaudeError, StatusCode: statusCode, Err: errors.New(claudeError.Message), errorCode: ErrorCode(claudeError.Type), @@ -211,5 +225,5 @@ func IsLocalError(err *NewAPIError) bool { return false } - return err.ErrorType == ErrorTypeNewAPIError + return err.errorType == ErrorTypeNewAPIError }