feat: enhance image request handling and add async support
This commit is contained in:
@@ -25,6 +25,8 @@ type ImageRequest struct {
|
|||||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||||
// Stream bool `json:"stream,omitempty"`
|
// Stream bool `json:"stream,omitempty"`
|
||||||
Watermark *bool `json:"watermark,omitempty"`
|
Watermark *bool `json:"watermark,omitempty"`
|
||||||
|
// 用匿名参数接收额外参数
|
||||||
|
Extra map[string]json.RawMessage `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
@@ -72,6 +74,7 @@ func (i *ImageRequest) SetModelName(modelName string) {
|
|||||||
type ImageResponse struct {
|
type ImageResponse struct {
|
||||||
Data []ImageData `json:"data"`
|
Data []ImageData `json:"data"`
|
||||||
Created int64 `json:"created"`
|
Created int64 `json:"created"`
|
||||||
|
Extra any `json:"extra,omitempty"`
|
||||||
}
|
}
|
||||||
type ImageData struct {
|
type ImageData struct {
|
||||||
Url string `json:"url"`
|
Url string `json:"url"`
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
if c.GetString("plugin") != "" {
|
if c.GetString("plugin") != "" {
|
||||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||||
}
|
}
|
||||||
|
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||||
|
req.Set("X-DashScope-Async", "enable")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,7 +93,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
aliRequest := oaiImage2Ali(request)
|
aliRequest, err := oaiImage2Ali(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||||
|
}
|
||||||
return aliRequest, nil
|
return aliRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,8 +131,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
adaptor := openai.Adaptor{}
|
switch info.RelayMode {
|
||||||
return adaptor.DoResponse(c, resp, info)
|
case constant.RelayModeImagesGenerations:
|
||||||
|
err, usage = aliImageHandler(c, resp, info)
|
||||||
|
case constant.RelayModeRerank:
|
||||||
|
err, usage = RerankHandler(c, resp, info)
|
||||||
|
default:
|
||||||
|
adaptor := openai.Adaptor{}
|
||||||
|
usage, err = adaptor.DoResponse(c, resp, info)
|
||||||
|
}
|
||||||
|
return usage, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -86,20 +86,25 @@ type AliResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AliImageRequest struct {
|
type AliImageRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input struct {
|
Input any `json:"input"`
|
||||||
Prompt string `json:"prompt"`
|
Parameters any `json:"parameters,omitempty"`
|
||||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
|
||||||
} `json:"input"`
|
|
||||||
Parameters struct {
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
N int `json:"n,omitempty"`
|
|
||||||
Steps string `json:"steps,omitempty"`
|
|
||||||
Scale string `json:"scale,omitempty"`
|
|
||||||
} `json:"parameters,omitempty"`
|
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AliImageParameters struct {
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Steps string `json:"steps,omitempty"`
|
||||||
|
Scale string `json:"scale,omitempty"`
|
||||||
|
Watermark *bool `json:"watermark,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliImageInput struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type AliRerankParameters struct {
|
type AliRerankParameters struct {
|
||||||
TopN *int `json:"top_n,omitempty"`
|
TopN *int `json:"top_n,omitempty"`
|
||||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||||
|
|||||||
@@ -18,15 +18,41 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||||
var imageRequest AliImageRequest
|
var imageRequest AliImageRequest
|
||||||
imageRequest.Input.Prompt = request.Prompt
|
|
||||||
imageRequest.Model = request.Model
|
imageRequest.Model = request.Model
|
||||||
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
|
|
||||||
imageRequest.Parameters.N = int(request.N)
|
|
||||||
imageRequest.ResponseFormat = request.ResponseFormat
|
imageRequest.ResponseFormat = request.ResponseFormat
|
||||||
|
|
||||||
return &imageRequest
|
if request.Extra != nil {
|
||||||
|
if val, ok := request.Extra["parameters"]; ok {
|
||||||
|
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val, ok := request.Extra["input"]; ok {
|
||||||
|
err := common.Unmarshal(val, &imageRequest.Input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid input field: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageRequest.Parameters == nil {
|
||||||
|
imageRequest.Parameters = AliImageParameters{
|
||||||
|
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||||
|
N: int(request.N),
|
||||||
|
Watermark: request.Watermark,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageRequest.Input == nil {
|
||||||
|
imageRequest.Input = AliImageInput{
|
||||||
|
Prompt: request.Prompt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &imageRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
||||||
@@ -52,7 +78,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
|||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
var response AliResponse
|
var response AliResponse
|
||||||
err = json.Unmarshal(responseBody, &response)
|
err = common.Unmarshal(responseBody, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("updateTask NewDecoder err: " + err.Error())
|
common.SysLog("updateTask NewDecoder err: " + err.Error())
|
||||||
return &aliResponse, err, nil
|
return &aliResponse, err, nil
|
||||||
@@ -61,8 +87,8 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
|||||||
return &response, nil, responseBody
|
return &response, nil, responseBody
|
||||||
}
|
}
|
||||||
|
|
||||||
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
||||||
waitSeconds := 3
|
waitSeconds := 5
|
||||||
step := 0
|
step := 0
|
||||||
maxStep := 20
|
maxStep := 20
|
||||||
|
|
||||||
@@ -70,11 +96,14 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []
|
|||||||
var responseBody []byte
|
var responseBody []byte
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
||||||
step++
|
step++
|
||||||
rsp, err, body := updateTask(info, taskID)
|
rsp, err, body := updateTask(info, taskID)
|
||||||
responseBody = body
|
responseBody = body
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &taskResponse, responseBody, err
|
logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
|
||||||
|
time.Sleep(time.Duration(waitSeconds) * time.Second)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if rsp.Output.TaskStatus == "" {
|
if rsp.Output.TaskStatus == "" {
|
||||||
@@ -124,6 +153,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
|||||||
RevisedPrompt: "",
|
RevisedPrompt: "",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
imageResponse.Extra = response
|
||||||
return &imageResponse
|
return &imageResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,7 +176,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
aliResponse, _, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||||
}
|
}
|
||||||
@@ -161,7 +191,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
}
|
}
|
||||||
|
|
||||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user