From 75b6327f4faeb27f249b7343361fd8e85695e0b4 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Tue, 9 Jan 2024 15:46:45 +0800 Subject: [PATCH] feat: support Azure dall-e --- controller/relay-audio.go | 18 +++++++++++++++++- controller/relay-image.go | 18 +++++++++++++++--- controller/relay-utils.go | 9 +++++++++ controller/relay.go | 18 ++++++++++++++---- 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index bb1f5c59..ce2c7062 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -106,13 +106,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + apiVersion := GetAPIVersion(c) + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion) + } + requestBody := c.Request.Body req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + + if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + req.Header.Set("api-key", apiKey) + req.ContentLength = c.Request.ContentLength + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) diff --git a/controller/relay-image.go b/controller/relay-image.go index fa9dc96e..a215b579 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -31,7 +31,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } if imageRequest.Model == "" { - imageRequest.Model = "dall-e" + imageRequest.Model = "dall-e-2" } if imageRequest.Size == "" { imageRequest.Size = "1024x1024" @@ -86,8 +86,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api + apiVersion := GetAPIVersion(c) + // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion) + } var requestBody io.Reader - if isModelMapped { + if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) @@ -132,8 +138,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + token := c.Request.Header.Get("Authorization") + if channelType == common.ChannelTypeAzure { // Azure authentication + token = strings.TrimPrefix(token, "Bearer ") + req.Header.Set("api-key", token) + } else { + req.Header.Set("Authorization", token) + } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 7d41a0c8..c22144f7 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -301,3 +301,12 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin } return fullRequestURL } + +func GetAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + return apiVersion +} diff --git a/controller/relay.go b/controller/relay.go index 3850b2f8..175e335e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -99,7 +99,9 @@ const ( RelayModeMidjourneyNotify RelayModeMidjourneyTaskFetch RelayModeMidjourneyTaskFetchByCondition - RelayModeAudio + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation ) // https://platform.openai.com/docs/api-reference/chat @@ -291,14 +293,22 @@ func Relay(c *gin.Context) { relayMode = RelayModeImagesGenerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { relayMode = RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - relayMode = RelayModeAudio + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + relayMode = RelayModeAudioSpeech + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + relayMode = RelayModeAudioTranscription + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + relayMode = RelayModeAudioTranslation } var err *OpenAIErrorWithStatusCode switch relayMode { case RelayModeImagesGenerations: err = relayImageHelper(c, relayMode) - case RelayModeAudio: + case RelayModeAudioSpeech: + fallthrough + case RelayModeAudioTranslation: + fallthrough + case RelayModeAudioTranscription: err = relayAudioHelper(c, relayMode) default: err = relayTextHelper(c, relayMode)