// Package vidu implements the task relay adaptor for the Vidu
// image-to-video / text-to-video API.
package vidu

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"

	"one-api/constant"
	"one-api/dto"
	"one-api/model"
	"one-api/relay/channel"
	relaycommon "one-api/relay/common"
	"one-api/service"
)

// ============================
// Request / Response structures
// ============================

// SubmitReq is the client-facing request body accepted by this adaptor.
type SubmitReq struct {
	Prompt string `json:"prompt"`
	Model  string `json:"model,omitempty"`
	// Mode is accepted from clients but is not forwarded upstream.
	Mode string `json:"mode,omitempty"`
	// Image, when non-empty, becomes the single entry of the upstream "images"
	// list and selects the image-to-video endpoint.
	Image string `json:"image,omitempty"`
	// Size is forwarded as the upstream "resolution" (default "1080p").
	Size     string `json:"size,omitempty"`
	Duration int    `json:"duration,omitempty"`
	// Metadata is merged into the upstream payload via a JSON round-trip, so
	// keys matching requestPayload's JSON tags override the computed defaults.
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

// requestPayload is the JSON body sent to the upstream generation endpoints.
type requestPayload struct {
	Model             string   `json:"model"`
	Images            []string `json:"images"`
	Prompt            string   `json:"prompt,omitempty"`
	Duration          int      `json:"duration,omitempty"`
	Seed              int      `json:"seed,omitempty"`
	Resolution        string   `json:"resolution,omitempty"`
	MovementAmplitude string   `json:"movement_amplitude,omitempty"`
	Bgm               bool     `json:"bgm,omitempty"`
	Payload           string   `json:"payload,omitempty"`
	CallbackUrl       string   `json:"callback_url,omitempty"`
}

// responsePayload is the upstream response to a task submission.
type responsePayload struct {
	TaskId            string   `json:"task_id"`
	State             string   `json:"state"`
	Model             string   `json:"model"`
	Images            []string `json:"images"`
	Prompt            string   `json:"prompt"`
	Duration          int      `json:"duration"`
	Seed              int      `json:"seed"`
	Resolution        string   `json:"resolution"`
	Bgm               bool     `json:"bgm"`
	MovementAmplitude string   `json:"movement_amplitude"`
	Payload           string   `json:"payload"`
	CreatedAt         string   `json:"created_at"`
}

// taskResultResponse is the upstream response of the task "creations" query
// issued by FetchTask and decoded by ParseTaskResult.
type taskResultResponse struct {
	State     string     `json:"state"`
	ErrCode   string     `json:"err_code"`
	Credits   int        `json:"credits"`
	Payload   string     `json:"payload"`
	Creations []creation `json:"creations"`
}

// creation is one generated artifact in a taskResultResponse.
type creation struct {
	ID       string `json:"id"`
	URL      string `json:"url"`
	CoverURL string `json:"cover_url"`
}

// ============================
// Adaptor implementation
// ============================

// TaskAdaptor relays video-generation tasks to the Vidu API.
type TaskAdaptor struct {
	ChannelType int
	baseURL     string
}

// Init records the channel type and base URL from the relay info for use
// when building upstream requests.
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
	a.ChannelType = info.ChannelType
	a.baseURL = info.ChannelBaseUrl
}

// ValidateRequestAndSetAction binds the client JSON body into a SubmitReq and
// requires a non-empty prompt. It picks image-to-video when Image is set and
// text-to-video otherwise. The parsed request is stored in the gin context
// under "task_request" for BuildRequestBody.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
	var req SubmitReq
	if err := c.ShouldBindJSON(&req); err != nil {
		return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
	}
	if req.Prompt == "" {
		return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
	}

	if req.Image != "" {
		info.Action = constant.TaskActionGenerate
	} else {
		info.Action = constant.TaskActionTextGenerate
	}

	c.Set("task_request", req)
	return nil
}

// BuildRequestBody converts the SubmitReq stashed by
// ValidateRequestAndSetAction into the upstream JSON payload.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) {
	v, exists := c.Get("task_request")
	if !exists {
		return nil, fmt.Errorf("request not found in context")
	}
	req, ok := v.(SubmitReq)
	if !ok {
		return nil, fmt.Errorf("unexpected task_request type %T in context", v)
	}

	body, err := a.convertToRequestPayload(&req)
	if err != nil {
		return nil, err
	}

	// Metadata can add or clear "images", so re-derive the action from the
	// final payload. DoRequest copies this into info.Action, which
	// BuildRequestURL switches on to pick the endpoint.
	if len(body.Images) == 0 {
		c.Set("action", constant.TaskActionTextGenerate)
	} else {
		c.Set("action", constant.TaskActionGenerate)
	}

	data, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}
	return bytes.NewReader(data), nil
}

// BuildRequestURL returns the upstream submission URL: /ent/v2/img2video for
// image-to-video and /ent/v2/text2video for every other action.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
	path := "/text2video"
	if info.Action == constant.TaskActionGenerate {
		path = "/img2video"
	}
	return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
}

// BuildRequestHeader sets JSON content negotiation headers and the
// "Token <key>" Authorization header.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Token "+info.ApiKey)
	return nil
}

// DoRequest applies the action chosen by BuildRequestBody (if any) to info and
// sends the request through the shared channel helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
	if action := c.GetString("action"); action != "" {
		info.Action = action
	}
	return channel.DoTaskApiRequest(a, c, info, requestBody)
}

// DoResponse reads the upstream submission response and writes it to the
// client on success. It returns the upstream task ID and the raw response
// body.
//
// It returns a task error when:
//   - the body cannot be read,
//   - the upstream status is not 2xx,
//   - the body is not valid JSON,
//   - the task is already "failed", or
//   - no task_id was returned.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
	// Closing an already-closed http response body is harmless; this
	// guarantees the connection is released even if the caller does not close.
	defer resp.Body.Close()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}

	// An upstream error body would otherwise unmarshal into an empty
	// responsePayload and be reported to the client as a successful submit.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", nil, service.TaskErrorWrapper(
			fmt.Errorf("upstream returned status %d: %s", resp.StatusCode, responseBody),
			"upstream_error", resp.StatusCode)
	}

	var vResp responsePayload
	if err := json.Unmarshal(responseBody, &vResp); err != nil {
		return "", nil, service.TaskErrorWrapper(errors.Wrap(err, string(responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
	}

	if vResp.State == "failed" {
		return "", nil, service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
	}
	if vResp.TaskId == "" {
		return "", nil, service.TaskErrorWrapper(
			fmt.Errorf("upstream response missing task_id: %s", responseBody),
			"invalid_response", http.StatusInternalServerError)
	}

	c.JSON(http.StatusOK, vResp)
	return vResp.TaskId, responseBody, nil
}

// FetchTask queries the upstream creations endpoint for the task whose ID is
// body["task_id"] (a non-empty string is required).
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
	taskID, ok := body["task_id"].(string)
	if !ok || taskID == "" {
		return nil, fmt.Errorf("invalid task_id")
	}

	requestURL := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
	req, err := http.NewRequest(http.MethodGet, requestURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Token "+key)

	return service.GetHttpClient().Do(req)
}

// GetModelList returns the model names this adaptor advertises.
func (a *TaskAdaptor) GetModelList() []string {
	return []string{"viduq1", "vidu2.0", "vidu1.5"}
}

// GetChannelName returns the channel identifier "vidu".
func (a *TaskAdaptor) GetChannelName() string {
	return "vidu"
}

// ============================
// helpers
// ============================

// convertToRequestPayload builds the upstream payload from a SubmitReq.
//
// Defaults are model "viduq1", duration 5, resolution "1080p", movement
// "auto" and no BGM. Entries in req.Metadata are then overlaid by JSON field
// name.
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
	var images []string
	if req.Image != "" {
		images = []string{req.Image}
	}

	r := requestPayload{
		Model:             defaultString(req.Model, "viduq1"),
		Images:            images,
		Prompt:            req.Prompt,
		Duration:          defaultInt(req.Duration, 5),
		Resolution:        defaultString(req.Size, "1080p"),
		MovementAmplitude: "auto",
		Bgm:               false,
	}

	// Round-trip through JSON so metadata keys override fields by their JSON
	// tags. A nil map marshals to "null", which Unmarshal treats as a no-op.
	metaBytes, err := json.Marshal(req.Metadata)
	if err != nil {
		return nil, errors.Wrap(err, "marshal metadata failed")
	}
	if err := json.Unmarshal(metaBytes, &r); err != nil {
		return nil, errors.Wrap(err, "unmarshal metadata failed")
	}

	return &r, nil
}

// defaultString returns value, or defaultValue when value is empty.
func defaultString(value, defaultValue string) string {
	if value == "" {
		return defaultValue
	}
	return value
}

// defaultInt returns value, or defaultValue when value is zero.
func defaultInt(value, defaultValue int) int {
	if value == 0 {
		return defaultValue
	}
	return value
}

// ParseTaskResult maps an upstream task-result body onto a TaskInfo.
//
// States map as follows:
//   - "created" and "queueing" → submitted
//   - "processing" → in progress
//   - "success" → success, with the first creation's URL
//   - "failed" → failure, with err_code as the reason
//
// Any other state is an error.
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
	var taskResp taskResultResponse
	if err := json.Unmarshal(respBody, &taskResp); err != nil {
		return nil, errors.Wrap(err, "failed to unmarshal response body")
	}

	taskInfo := &relaycommon.TaskInfo{}
	switch state := taskResp.State; state {
	case "created", "queueing":
		taskInfo.Status = model.TaskStatusSubmitted
	case "processing":
		taskInfo.Status = model.TaskStatusInProgress
	case "success":
		taskInfo.Status = model.TaskStatusSuccess
		if len(taskResp.Creations) > 0 {
			taskInfo.Url = taskResp.Creations[0].URL
		}
	case "failed":
		taskInfo.Status = model.TaskStatusFailure
		if taskResp.ErrCode != "" {
			taskInfo.Reason = taskResp.ErrCode
		}
	default:
		return nil, fmt.Errorf("unknown task state: %s", state)
	}

	return taskInfo, nil
}