feat: enhance image request handling and add async support

This commit is contained in:
CaIon
2025-08-24 21:52:56 +08:00
parent 808f5c481e
commit 7fbf9c4851
4 changed files with 77 additions and 25 deletions

View File

@@ -25,6 +25,8 @@ type ImageRequest struct {
PartialImages json.RawMessage `json:"partial_images,omitempty"` PartialImages json.RawMessage `json:"partial_images,omitempty"`
// Stream bool `json:"stream,omitempty"` // Stream bool `json:"stream,omitempty"`
Watermark *bool `json:"watermark,omitempty"` Watermark *bool `json:"watermark,omitempty"`
// 用匿名参数接收额外参数
Extra map[string]json.RawMessage `json:"-"`
} }
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
@@ -72,6 +74,7 @@ func (i *ImageRequest) SetModelName(modelName string) {
type ImageResponse struct { type ImageResponse struct {
Data []ImageData `json:"data"` Data []ImageData `json:"data"`
Created int64 `json:"created"` Created int64 `json:"created"`
Extra any `json:"extra,omitempty"`
} }
type ImageData struct { type ImageData struct {
Url string `json:"url"` Url string `json:"url"`

View File

@@ -63,6 +63,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
if c.GetString("plugin") != "" { if c.GetString("plugin") != "" {
req.Set("X-DashScope-Plugin", c.GetString("plugin")) req.Set("X-DashScope-Plugin", c.GetString("plugin"))
} }
if info.RelayMode == constant.RelayModeImagesGenerations {
req.Set("X-DashScope-Async", "enable")
}
return nil return nil
} }
@@ -90,7 +93,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
aliRequest := oaiImage2Ali(request) aliRequest, err := oaiImage2Ali(request)
if err != nil {
return nil, fmt.Errorf("convert image request failed: %w", err)
}
return aliRequest, nil return aliRequest, nil
} }
@@ -124,9 +130,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else { } else {
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
} }
default:
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeRerank:
err, usage = RerankHandler(c, resp, info)
default: default:
adaptor := openai.Adaptor{} adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info) usage, err = adaptor.DoResponse(c, resp, info)
}
return usage, err
} }
} }

View File

@@ -87,17 +87,22 @@ type AliResponse struct {
type AliImageRequest struct { type AliImageRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input struct { Input any `json:"input"`
Prompt string `json:"prompt"` Parameters any `json:"parameters,omitempty"`
NegativePrompt string `json:"negative_prompt,omitempty"` ResponseFormat string `json:"response_format,omitempty"`
} `json:"input"` }
Parameters struct {
type AliImageParameters struct {
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"` Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"` Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"` Watermark *bool `json:"watermark,omitempty"`
ResponseFormat string `json:"response_format,omitempty"` }
type AliImageInput struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} }
type AliRerankParameters struct { type AliRerankParameters struct {

View File

@@ -18,15 +18,41 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
var imageRequest AliImageRequest var imageRequest AliImageRequest
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = int(request.N)
imageRequest.ResponseFormat = request.ResponseFormat imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest if request.Extra != nil {
if val, ok := request.Extra["parameters"]; ok {
err := common.Unmarshal(val, &imageRequest.Parameters)
if err != nil {
return nil, fmt.Errorf("invalid parameters field: %w", err)
}
}
if val, ok := request.Extra["input"]; ok {
err := common.Unmarshal(val, &imageRequest.Input)
if err != nil {
return nil, fmt.Errorf("invalid input field: %w", err)
}
}
}
if imageRequest.Parameters == nil {
imageRequest.Parameters = AliImageParameters{
Size: strings.Replace(request.Size, "x", "*", -1),
N: int(request.N),
Watermark: request.Watermark,
}
}
if imageRequest.Input == nil {
imageRequest.Input = AliImageInput{
Prompt: request.Prompt,
}
}
return &imageRequest, nil
} }
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
@@ -52,7 +78,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
var response AliResponse var response AliResponse
err = json.Unmarshal(responseBody, &response) err = common.Unmarshal(responseBody, &response)
if err != nil { if err != nil {
common.SysLog("updateTask NewDecoder err: " + err.Error()) common.SysLog("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil return &aliResponse, err, nil
@@ -61,8 +87,8 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
return &response, nil, responseBody return &response, nil, responseBody
} }
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) { func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
waitSeconds := 3 waitSeconds := 5
step := 0 step := 0
maxStep := 20 maxStep := 20
@@ -70,11 +96,14 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []
var responseBody []byte var responseBody []byte
for { for {
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
step++ step++
rsp, err, body := updateTask(info, taskID) rsp, err, body := updateTask(info, taskID)
responseBody = body responseBody = body
if err != nil { if err != nil {
return &taskResponse, responseBody, err logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
time.Sleep(time.Duration(waitSeconds) * time.Second)
continue
} }
if rsp.Output.TaskStatus == "" { if rsp.Output.TaskStatus == "" {
@@ -124,6 +153,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
RevisedPrompt: "", RevisedPrompt: "",
}) })
} }
imageResponse.Extra = response
return &imageResponse return &imageResponse
} }
@@ -146,7 +176,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
} }
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId) aliResponse, _, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponse), nil return types.NewError(err, types.ErrorCodeBadResponse), nil
} }
@@ -161,7 +191,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
} }
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat) fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil
} }