feat: support /images/edit
(cherry picked from commit 1c0a1238787d490f02dd9269b616580a16604180)
This commit is contained in:
@@ -236,11 +236,64 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
return request, nil
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesEdits:
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return nil, errors.New("get request body fail")
|
||||
}
|
||||
return bytes.NewReader(body), nil
|
||||
|
||||
/*var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
writer.WriteField("model", request.Model)
|
||||
// 获取所有表单字段
|
||||
formData := c.Request.PostForm
|
||||
// 遍历表单字段并打印输出
|
||||
for key, values := range formData {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
writer.WriteField(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加文件字段
|
||||
imageFiles := c.Request.MultipartForm.File["image[]"]
|
||||
for _, file := range imageFiles {
|
||||
part, err := writer.CreateFormFile("image[]", file.Filename)
|
||||
if err != nil {
|
||||
return nil, errors.New("create form file failed")
|
||||
}
|
||||
// 打开文件
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("open file failed")
|
||||
}
|
||||
// 将文件数据写入 form part
|
||||
_, err = io.Copy(part, src)
|
||||
if err != nil {
|
||||
return nil, errors.New("copy file failed")
|
||||
}
|
||||
src.Close()
|
||||
}
|
||||
|
||||
// 关闭 multipart 编写器以设置分界线
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return bytes.NewReader(requestBody.Bytes()), nil*/
|
||||
|
||||
default:
|
||||
return request, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription ||
|
||||
info.RelayMode == constant.RelayModeAudioTranslation ||
|
||||
info.RelayMode == constant.RelayModeImagesEdits {
|
||||
return channel.DoFormRequest(a, c, info, requestBody)
|
||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||
return channel.DoWssRequest(a, c, info, requestBody)
|
||||
@@ -259,8 +312,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||
default:
|
||||
|
||||
@@ -595,3 +595,52 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||
return err
|
||||
}
|
||||
|
||||
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
// reset content length
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = json.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
// format
|
||||
if usageResp.InputTokens > 0 {
|
||||
usageResp.PromptTokens += usageResp.InputTokens
|
||||
}
|
||||
if usageResp.OutputTokens > 0 {
|
||||
usageResp.CompletionTokens += usageResp.OutputTokens
|
||||
}
|
||||
if usageResp.InputTokensDetails != nil {
|
||||
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
||||
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
||||
}
|
||||
return nil, &usageResp.Usage
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user