diff --git a/constant/channel.go b/constant/channel.go index 224121e7..2e1cc5b0 100644 --- a/constant/channel.go +++ b/constant/channel.go @@ -49,6 +49,7 @@ const ( ChannelTypeCoze = 49 ChannelTypeKling = 50 ChannelTypeJimeng = 51 + ChannelTypeVidu = 52 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{ "https://api.coze.cn", //49 "https://api.klingai.com", //50 "https://visual.volcengineapi.com", //51 + "https://api.vidu.cn", //52 } diff --git a/controller/channel-test.go b/controller/channel-test.go index 8c4a26ae..c1c3c21d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -69,6 +69,12 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: nil, } } + if channel.Type == constant.ChannelTypeVidu { + return testResult{ + localErr: errors.New("vidu channel test is not supported"), + newAPIError: nil, + } + } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) diff --git a/controller/task_video.go b/controller/task_video.go index 684f30fa..914bf6e6 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -83,7 +83,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha taskResult := &relaycommon.TaskInfo{} // try parse as New API response format var responseItems dto.TaskResponse[model.Task] - if err = json.Unmarshal(responseBody, &responseItems); err == nil { + if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { t := responseItems.Data taskResult.TaskID = t.TaskID taskResult.Status = string(t.Status) diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go new file mode 100644 index 00000000..f40b480c --- /dev/null +++ b/relay/channel/task/vidu/adaptor.go @@ -0,0 +1,285 @@ +package vidu + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + + "one-api/constant" + "one-api/dto" + "one-api/model" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + + "github.com/pkg/errors" +) + +// ============================ +// Request / Response structures +// ============================ + +type SubmitReq struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type requestPayload struct { + Model string `json:"model"` + Images []string `json:"images"` + Prompt string `json:"prompt,omitempty"` + Duration int `json:"duration,omitempty"` + Seed int `json:"seed,omitempty"` + Resolution string `json:"resolution,omitempty"` + MovementAmplitude string `json:"movement_amplitude,omitempty"` + Bgm bool `json:"bgm,omitempty"` + Payload string `json:"payload,omitempty"` + CallbackUrl string `json:"callback_url,omitempty"` +} + +type responsePayload struct { + TaskId string `json:"task_id"` + State string `json:"state"` + Model string `json:"model"` + Images []string `json:"images"` + Prompt string `json:"prompt"` + Duration int `json:"duration"` + Seed int `json:"seed"` + Resolution string `json:"resolution"` + Bgm bool `json:"bgm"` + MovementAmplitude string `json:"movement_amplitude"` + Payload string `json:"payload"` + CreatedAt string `json:"created_at"` +} + +type taskResultResponse struct { + State string `json:"state"` + ErrCode string `json:"err_code"` + Credits int `json:"credits"` + Payload string `json:"payload"` + Creations []creation `json:"creations"` +} + +type creation struct { + ID string `json:"id"` + URL string `json:"url"` + CoverURL string `json:"cover_url"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.BaseUrl +} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError { + var req SubmitReq + if err := c.ShouldBindJSON(&req); err != nil { + return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest) + } + + if req.Prompt == "" { + return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest) + } + + if req.Image != "" { + info.Action = constant.TaskActionGenerate + } else { + info.Action = constant.TaskActionTextGenerate + } + + c.Set("task_request", req) + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, exists := c.Get("task_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(SubmitReq) + + body, err := a.convertToRequestPayload(&req) + if err != nil { + return nil, err + } + + if len(body.Images) == 0 { + c.Set("action", constant.TaskActionTextGenerate) + } + + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + var path string + switch info.Action { + case constant.TaskActionGenerate: + path = "/img2video" + default: + path = "/text2video" + } + return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil +} + +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Token "+info.ApiKey) + return nil +} + +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + if action := c.GetString("action"); action != "" { + info.Action = action + } + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + + var vResp responsePayload + err = json.Unmarshal(responseBody, &vResp) + if err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError) + return + } + + if vResp.State == "failed" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest) + return + } + + c.JSON(http.StatusOK, vResp) + return vResp.TaskId, responseBody, nil +} + +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + + url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Token "+key) + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return []string{"viduq1", "vidu2.0", "vidu1.5"} +} + +func (a *TaskAdaptor) GetChannelName() string { + return "vidu" +} + +// ============================ +// helpers +// ============================ + +func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { + var images []string + if req.Image != "" { + images = []string{req.Image} + } + + r := requestPayload{ + Model: defaultString(req.Model, "viduq1"), + Images: images, + Prompt: req.Prompt, + Duration: defaultInt(req.Duration, 5), + Resolution: defaultString(req.Size, "1080p"), + MovementAmplitude: "auto", + Bgm: false, + } + metadata := req.Metadata + medaBytes, err := json.Marshal(metadata) + if err != nil { + return nil, errors.Wrap(err, "metadata marshal metadata failed") + } + err = json.Unmarshal(medaBytes, &r) + if err != nil { + return nil, errors.Wrap(err, "unmarshal metadata failed") + } + return &r, nil +} + +func defaultString(value, defaultValue string) string { + if value == "" { + return defaultValue + } + return value +} + +func defaultInt(value, defaultValue int) int { + if value == 0 { + return defaultValue + } + return value +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + taskInfo := &relaycommon.TaskInfo{} + + var taskResp taskResultResponse + err := json.Unmarshal(respBody, &taskResp) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal response body") + } + + state := taskResp.State + switch state { + case "created", "queueing": + taskInfo.Status = model.TaskStatusSubmitted + case "processing": + taskInfo.Status = model.TaskStatusInProgress + case "success": + taskInfo.Status = model.TaskStatusSuccess + if len(taskResp.Creations) > 0 { + taskInfo.Url = taskResp.Creations[0].URL + } + case "failed": + taskInfo.Status = model.TaskStatusFailure + if taskResp.ErrCode != "" { + taskInfo.Reason = taskResp.ErrCode + } + default: + return nil, fmt.Errorf("unknown task state: %s", state) + } + + return taskInfo, nil +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 1e9c46e8..cc9c5bbb 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -27,6 +27,7 @@ import ( taskjimeng "one-api/relay/channel/task/jimeng" "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" + taskVidu "one-api/relay/channel/task/vidu" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" "one-api/relay/channel/volcengine" @@ -122,6 +123,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor { return &kling.TaskAdaptor{} case constant.ChannelTypeJimeng: return &taskjimeng.TaskAdaptor{} + case constant.ChannelTypeVidu: + return &taskVidu.TaskAdaptor{} } } return nil diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index c2468ec7..43372a25 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -154,6 +154,11 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: '即梦', }, + { + value: 52, + color: 'purple', + label: 'Vidu', + }, ]; export const MODEL_TABLE_PAGE_SIZE = 10;