From 7fbf9c485120d2234c5ade344781d9977cfbf0d3 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 24 Aug 2025 21:52:56 +0800 Subject: [PATCH] feat: enhance image request handling and add async support --- dto/openai_image.go | 3 +++ relay/channel/ali/adaptor.go | 20 +++++++++++--- relay/channel/ali/dto.go | 27 +++++++++++-------- relay/channel/ali/image.go | 52 ++++++++++++++++++++++++++++-------- 4 files changed, 77 insertions(+), 25 deletions(-) diff --git a/dto/openai_image.go b/dto/openai_image.go index c26c4200..8833e774 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -25,6 +25,8 @@ type ImageRequest struct { PartialImages json.RawMessage `json:"partial_images,omitempty"` // Stream bool `json:"stream,omitempty"` Watermark *bool `json:"watermark,omitempty"` + // 用匿名参数接收额外参数 + Extra map[string]json.RawMessage `json:"-"` } func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { @@ -72,6 +74,7 @@ func (i *ImageRequest) SetModelName(modelName string) { type ImageResponse struct { Data []ImageData `json:"data"` Created int64 `json:"created"` + Extra any `json:"extra,omitempty"` } type ImageData struct { Url string `json:"url"` diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 5e31c753..c676badc 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -63,6 +63,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel if c.GetString("plugin") != "" { req.Set("X-DashScope-Plugin", c.GetString("plugin")) } + if info.RelayMode == constant.RelayModeImagesGenerations { + req.Set("X-DashScope-Async", "enable") + } return nil } @@ -90,7 +93,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - aliRequest := oaiImage2Ali(request) + aliRequest, err := oaiImage2Ali(request) + if err != nil { + return nil, fmt.Errorf("convert image request failed: %w", err) + } return aliRequest, nil } @@ -125,8 +131,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) } default: - adaptor := openai.Adaptor{} - return adaptor.DoResponse(c, resp, info) + switch info.RelayMode { + case constant.RelayModeImagesGenerations: + err, usage = aliImageHandler(c, resp, info) + case constant.RelayModeRerank: + err, usage = RerankHandler(c, resp, info) + default: + adaptor := openai.Adaptor{} + usage, err = adaptor.DoResponse(c, resp, info) + } + return usage, err } } diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go index dbd18968..d40e077d 100644 --- a/relay/channel/ali/dto.go +++ b/relay/channel/ali/dto.go @@ -86,20 +86,25 @@ type AliResponse struct { } type AliImageRequest struct { - Model string `json:"model"` - Input struct { - Prompt string `json:"prompt"` - NegativePrompt string `json:"negative_prompt,omitempty"` - } `json:"input"` - Parameters struct { - Size string `json:"size,omitempty"` - N int `json:"n,omitempty"` - Steps string `json:"steps,omitempty"` - Scale string `json:"scale,omitempty"` - } `json:"parameters,omitempty"` + Model string `json:"model"` + Input any `json:"input"` + Parameters any `json:"parameters,omitempty"` ResponseFormat string `json:"response_format,omitempty"` } +type AliImageParameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + Watermark *bool `json:"watermark,omitempty"` +} + +type AliImageInput struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` +} + type AliRerankParameters struct { TopN *int `json:"top_n,omitempty"` ReturnDocuments *bool `json:"return_documents,omitempty"` diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 645882bc..f90f5a3a 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -18,15 +18,41 @@ import ( "github.com/gin-gonic/gin" ) -func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { +func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) { var imageRequest AliImageRequest - imageRequest.Input.Prompt = request.Prompt imageRequest.Model = request.Model - imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) - imageRequest.Parameters.N = int(request.N) imageRequest.ResponseFormat = request.ResponseFormat - return &imageRequest + if request.Extra != nil { + if val, ok := request.Extra["parameters"]; ok { + err := common.Unmarshal(val, &imageRequest.Parameters) + if err != nil { + return nil, fmt.Errorf("invalid parameters field: %w", err) + } + } + if val, ok := request.Extra["input"]; ok { + err := common.Unmarshal(val, &imageRequest.Input) + if err != nil { + return nil, fmt.Errorf("invalid input field: %w", err) + } + } + } + + if imageRequest.Parameters == nil { + imageRequest.Parameters = AliImageParameters{ + Size: strings.Replace(request.Size, "x", "*", -1), + N: int(request.N), + Watermark: request.Watermark, + } + } + + if imageRequest.Input == nil { + imageRequest.Input = AliImageInput{ + Prompt: request.Prompt, + } + } + + return &imageRequest, nil } func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { @@ -52,7 +78,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error responseBody, err := io.ReadAll(resp.Body) var response AliResponse - err = json.Unmarshal(responseBody, &response) + err = common.Unmarshal(responseBody, &response) if err != nil { common.SysLog("updateTask NewDecoder err: " + err.Error()) return &aliResponse, err, nil @@ -61,8 +87,8 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error return &response, nil, responseBody } -func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) { - waitSeconds := 3 +func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) { + waitSeconds := 5 step := 0 maxStep := 20 @@ -70,11 +96,14 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, [] var responseBody []byte for { + logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds)) step++ rsp, err, body := updateTask(info, taskID) responseBody = body if err != nil { - return &taskResponse, responseBody, err + logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error()) + time.Sleep(time.Duration(waitSeconds) * time.Second) + continue } if rsp.Output.TaskStatus == "" { @@ -124,6 +153,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc RevisedPrompt: "", }) } + imageResponse.Extra = response return &imageResponse } @@ -146,7 +176,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil } - aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId) + aliResponse, _, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId) if err != nil { return types.NewError(err, types.ErrorCodeBadResponse), nil } @@ -161,7 +191,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat) - jsonResponse, err := json.Marshal(fullTextResponse) + jsonResponse, err := common.Marshal(fullTextResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil }