From cd2870aebc8563b0bbe87031327f443c0989e66e Mon Sep 17 00:00:00 2001 From: skynono Date: Mon, 23 Jun 2025 21:22:01 +0800 Subject: [PATCH] feat: add origin kling api --- controller/task_video.go | 13 ++++--- middleware/kling_adapter.go | 45 +++++++++++++++++++++++ relay/channel/task/kling/adaptor.go | 39 +++++++++++++++----- router/video-router.go | 7 ++++ web/src/components/table/TaskLogsTable.js | 10 ++++- 5 files changed, 96 insertions(+), 18 deletions(-) create mode 100644 middleware/kling_adapter.go diff --git a/controller/task_video.go b/controller/task_video.go index a2c2431d..2e980310 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -56,8 +56,15 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } + + task := taskM[taskId] + if task == nil { + common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ "task_id": taskId, + "action": task.Action, }) if err != nil { return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err) @@ -89,12 +96,6 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha return fmt.Errorf("video task data format error for task %s", taskId) } - task := taskM[taskId] - if task == nil { - common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) - return fmt.Errorf("task %s not found", taskId) - } - if status, ok := data["task_status"].(string); ok { switch status { case "submitted", "queued": diff --git a/middleware/kling_adapter.go b/middleware/kling_adapter.go new file mode 100644 index 00000000..b6ecf727 --- /dev/null +++ b/middleware/kling_adapter.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "one-api/common" +) + +func KlingRequestConvert() func(c *gin.Context) { + return func(c *gin.Context) { + var originalReq map[string]interface{} + if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil { + c.Next() + return + } + + model, _ := originalReq["model"].(string) + prompt, _ := originalReq["prompt"].(string) + + unifiedReq := map[string]interface{}{ + "model": model, + "prompt": prompt, + "metadata": originalReq, + } + + jsonData, err := json.Marshal(unifiedReq) + if err != nil { + c.Next() + return + } + + // Rewrite request body and path + c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData)) + c.Request.URL.Path = "/v1/video/generations" + if image := originalReq["image"]; image == "" { + c.Set("action", "textGenerate") + } + + // We have to reset the request body for the next handlers + c.Set(common.KeyRequestBody, jsonData) + c.Next() + } +} diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 2995a07b..55f21196 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/samber/lo" "io" "net/http" "strings" @@ -41,7 +42,6 @@ type requestPayload struct { Mode string `json:"mode,omitempty"` Duration string `json:"duration,omitempty"` AspectRatio string `json:"aspect_ratio,omitempty"` - Model string `json:"model,omitempty"` ModelName string `json:"model_name,omitempty"` CfgScale float64 `json:"cfg_scale,omitempty"` } @@ -100,7 +100,8 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil + path := lo.Ternary(info.Action == "generate", "/v1/videos/image2video", "/v1/videos/text2video") + return fmt.Sprintf("%s%s", a.baseURL, path), nil } // BuildRequestHeader sets required headers. @@ -125,7 +126,10 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel } req := v.(SubmitReq) - body := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req) + if err != nil { + return nil, err + } data, err := json.Marshal(body) if err != nil { return nil, err @@ -135,6 +139,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel // DoRequest delegates to common helper. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + if action := c.GetString("action"); action != "" { + info.Action = action + } return channel.DoTaskApiRequest(a, c, info, requestBody) } @@ -175,7 +182,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http if !ok { return nil, fmt.Errorf("invalid task_id") } - url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID) + action, ok := body["action"].(string) + if !ok { + return nil, fmt.Errorf("invalid action") + } + path := lo.Ternary(action == "generate", "/v1/videos/image2video", "/v1/videos/text2video") + url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -210,22 +222,29 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload { - r := &requestPayload{ +func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { + r := requestPayload{ Prompt: req.Prompt, Image: req.Image, Mode: defaultString(req.Mode, "std"), Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), AspectRatio: a.getAspectRatio(req.Size), - Model: req.Model, ModelName: req.Model, CfgScale: 0.5, } - if r.Model == "" { - r.Model = "kling-v1" + if r.ModelName == "" { r.ModelName = "kling-v1" } - return r + metadata := req.Metadata + medaBytes, err := json.Marshal(metadata) + if err != nil { + return nil, errors.Wrap(err, "metadata marshal metadata failed") + } + err = json.Unmarshal(medaBytes, &r) + if err != nil { + return nil, errors.Wrap(err, "unmarshal metadata failed") + } + return &r, nil } func (a *TaskAdaptor) getAspectRatio(size string) string { diff --git a/router/video-router.go b/router/video-router.go index 7201c34a..9e605d54 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -14,4 +14,11 @@ func SetVideoRouter(router *gin.Engine) { videoV1Router.POST("/video/generations", controller.RelayTask) videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) } + + klingV1Router := router.Group("/kling/v1") + klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute()) + { + klingV1Router.POST("/videos/text2video", controller.RelayTask) + klingV1Router.POST("/videos/image2video", controller.RelayTask) + } } diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js index 8b309942..5b77ce39 100644 --- a/web/src/components/table/TaskLogsTable.js +++ b/web/src/components/table/TaskLogsTable.js @@ -212,7 +212,13 @@ const LogsTable = () => { case 'generate': return ( }> - {t('生成视频')} + {t('图生视频')} + + ); + case 'textGenerate': + return ( + }> + {t('文生视频')} ); default: @@ -434,7 +440,7 @@ const LogsTable = () => { fixed: 'right', render: (text, record, index) => { // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 - const isVideoTask = record.action === 'generate'; + const isVideoTask = record.action === 'generate' || record.action === 'textGenerate'; const isSuccess = record.status === 'SUCCESS'; const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); if (isSuccess && isVideoTask && isUrl) {