This commit refactors the logging mechanism across the application, replacing direct logger calls with a centralized logging approach based on the `common` package. Key changes:

- Replaced `logger.SysLog` and `logger.FatalLog` calls with `common.SysLog` and `common.FatalLog` for consistent logging.
- Updated resource-initialization error handling to use the new logging structure, improving maintainability and readability.
- Made minor adjustments to improve code clarity and organization across several modules.

This change streamlines logging and improves the overall architecture of the codebase.
286 lines
7.7 KiB
Go
286 lines
7.7 KiB
Go
package vidu
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
"one-api/model"
|
|
"one-api/relay/channel"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/service"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// ============================
|
|
// Request / Response structures
|
|
// ============================
|
|
|
|
// SubmitReq is the client-facing request body for a Vidu task. It is bound
// from the incoming JSON in ValidateRequestAndSetAction and later translated
// into the upstream requestPayload by convertToRequestPayload.
type SubmitReq struct {
	Prompt   string         `json:"prompt"`             // text prompt; validation rejects an empty one
	Model    string         `json:"model,omitempty"`    // upstream model; empty falls back to a default
	Mode     string         `json:"mode,omitempty"`     // NOTE(review): not read anywhere in this file
	Image    string         `json:"image,omitempty"`    // single input image; non-empty selects image-to-video
	Size     string         `json:"size,omitempty"`     // forwarded upstream as "resolution"
	Duration int            `json:"duration,omitempty"` // zero falls back to a default
	Metadata map[string]any `json:"metadata,omitempty"` // extra keys decoded over the upstream payload
}
|
|
|
|
// requestPayload is the JSON body sent to the upstream img2video/text2video
// endpoints. Field order is significant: it fixes the key order of the
// marshalled body.
type requestPayload struct {
	Model             string   `json:"model"`
	Images            []string `json:"images"` // no omitempty: a nil slice is sent as null
	Prompt            string   `json:"prompt,omitempty"`
	Duration          int      `json:"duration,omitempty"`
	Seed              int      `json:"seed,omitempty"`
	Resolution        string   `json:"resolution,omitempty"`
	MovementAmplitude string   `json:"movement_amplitude,omitempty"`
	Bgm               bool     `json:"bgm,omitempty"`
	Payload           string   `json:"payload,omitempty"`
	CallbackUrl       string   `json:"callback_url,omitempty"`
}
|
|
|
|
// responsePayload is the upstream reply to a task submission. DoResponse
// decodes into it and echoes it back to the client with c.JSON, so the field
// order below is also the key order the client sees.
type responsePayload struct {
	TaskId            string   `json:"task_id"` // identifier used later by FetchTask
	State             string   `json:"state"`   // DoResponse treats "failed" as an error
	Model             string   `json:"model"`
	Images            []string `json:"images"`
	Prompt            string   `json:"prompt"`
	Duration          int      `json:"duration"`
	Seed              int      `json:"seed"`
	Resolution        string   `json:"resolution"`
	Bgm               bool     `json:"bgm"`
	MovementAmplitude string   `json:"movement_amplitude"`
	Payload           string   `json:"payload"`
	CreatedAt         string   `json:"created_at"`
}
|
|
|
|
type taskResultResponse struct {
|
|
State string `json:"state"`
|
|
ErrCode string `json:"err_code"`
|
|
Credits int `json:"credits"`
|
|
Payload string `json:"payload"`
|
|
Creations []creation `json:"creations"`
|
|
}
|
|
|
|
// creation is a single generated output listed in taskResultResponse.
type creation struct {
	ID       string `json:"id"`
	URL      string `json:"url"`       // ParseTaskResult reports the first creation's URL
	CoverURL string `json:"cover_url"` // presumably a cover/thumbnail image URL; not read in this file
}
|
|
|
|
// ============================
|
|
// Adaptor implementation
|
|
// ============================
|
|
|
|
// TaskAdaptor relays video-generation tasks to the Vidu API. Both fields are
// populated by Init from the relay info.
type TaskAdaptor struct {
	ChannelType int    // channel type copied from the relay info
	baseURL     string // upstream base URL; endpoints are appended under /ent/v2
}
|
|
|
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
|
a.ChannelType = info.ChannelType
|
|
a.baseURL = info.ChannelBaseUrl
|
|
}
|
|
|
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
|
|
var req SubmitReq
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
|
}
|
|
|
|
if req.Prompt == "" {
|
|
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
|
}
|
|
|
|
if req.Image != "" {
|
|
info.Action = constant.TaskActionGenerate
|
|
} else {
|
|
info.Action = constant.TaskActionTextGenerate
|
|
}
|
|
|
|
c.Set("task_request", req)
|
|
return nil
|
|
}
|
|
|
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
|
v, exists := c.Get("task_request")
|
|
if !exists {
|
|
return nil, fmt.Errorf("request not found in context")
|
|
}
|
|
req := v.(SubmitReq)
|
|
|
|
body, err := a.convertToRequestPayload(&req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(body.Images) == 0 {
|
|
c.Set("action", constant.TaskActionTextGenerate)
|
|
}
|
|
|
|
data, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return bytes.NewReader(data), nil
|
|
}
|
|
|
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
|
var path string
|
|
switch info.Action {
|
|
case constant.TaskActionGenerate:
|
|
path = "/img2video"
|
|
default:
|
|
path = "/text2video"
|
|
}
|
|
return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
|
|
}
|
|
|
|
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
req.Header.Set("Authorization", "Token "+info.ApiKey)
|
|
return nil
|
|
}
|
|
|
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
if action := c.GetString("action"); action != "" {
|
|
info.Action = action
|
|
}
|
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
|
}
|
|
|
|
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
var vResp responsePayload
|
|
err = json.Unmarshal(responseBody, &vResp)
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if vResp.State == "failed" {
|
|
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, vResp)
|
|
return vResp.TaskId, responseBody, nil
|
|
}
|
|
|
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
|
taskID, ok := body["task_id"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid task_id")
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
|
|
|
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.Set("Accept", "application/json")
|
|
req.Header.Set("Authorization", "Token "+key)
|
|
|
|
return service.GetHttpClient().Do(req)
|
|
}
|
|
|
|
func (a *TaskAdaptor) GetModelList() []string {
|
|
return []string{"viduq1", "vidu2.0", "vidu1.5"}
|
|
}
|
|
|
|
func (a *TaskAdaptor) GetChannelName() string {
|
|
return "vidu"
|
|
}
|
|
|
|
// ============================
|
|
// helpers
|
|
// ============================
|
|
|
|
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
|
var images []string
|
|
if req.Image != "" {
|
|
images = []string{req.Image}
|
|
}
|
|
|
|
r := requestPayload{
|
|
Model: defaultString(req.Model, "viduq1"),
|
|
Images: images,
|
|
Prompt: req.Prompt,
|
|
Duration: defaultInt(req.Duration, 5),
|
|
Resolution: defaultString(req.Size, "1080p"),
|
|
MovementAmplitude: "auto",
|
|
Bgm: false,
|
|
}
|
|
metadata := req.Metadata
|
|
medaBytes, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
|
}
|
|
err = json.Unmarshal(medaBytes, &r)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
|
}
|
|
return &r, nil
|
|
}
|
|
|
|
// defaultString returns value unless it is the empty string, in which case it
// returns defaultValue. Whitespace-only values are kept as-is.
func defaultString(value, defaultValue string) string {
	if value != "" {
		return value
	}
	return defaultValue
}
|
|
|
|
// defaultInt returns value unless it is zero, in which case it returns
// defaultValue. Negative values pass through unchanged.
func defaultInt(value, defaultValue int) int {
	switch value {
	case 0:
		return defaultValue
	default:
		return value
	}
}
|
|
|
|
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
|
taskInfo := &relaycommon.TaskInfo{}
|
|
|
|
var taskResp taskResultResponse
|
|
err := json.Unmarshal(respBody, &taskResp)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
|
}
|
|
|
|
state := taskResp.State
|
|
switch state {
|
|
case "created", "queueing":
|
|
taskInfo.Status = model.TaskStatusSubmitted
|
|
case "processing":
|
|
taskInfo.Status = model.TaskStatusInProgress
|
|
case "success":
|
|
taskInfo.Status = model.TaskStatusSuccess
|
|
if len(taskResp.Creations) > 0 {
|
|
taskInfo.Url = taskResp.Creations[0].URL
|
|
}
|
|
case "failed":
|
|
taskInfo.Status = model.TaskStatusFailure
|
|
if taskResp.ErrCode != "" {
|
|
taskInfo.Reason = taskResp.ErrCode
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("unknown task state: %s", state)
|
|
}
|
|
|
|
return taskInfo, nil
|
|
}
|