diff --git a/controller/task_video.go b/controller/task_video.go
index a2c2431d..2e980310 100644
--- a/controller/task_video.go
+++ b/controller/task_video.go
@@ -56,8 +56,15 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
+
+ task := taskM[taskId]
+ if task == nil {
+ common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+ return fmt.Errorf("task %s not found", taskId)
+ }
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
+ "action": task.Action,
})
if err != nil {
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
@@ -89,12 +96,6 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
return fmt.Errorf("video task data format error for task %s", taskId)
}
- task := taskM[taskId]
- if task == nil {
- common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
- return fmt.Errorf("task %s not found", taskId)
- }
-
if status, ok := data["task_status"].(string); ok {
switch status {
case "submitted", "queued":
diff --git a/middleware/kling_adapter.go b/middleware/kling_adapter.go
new file mode 100644
index 00000000..b6ecf727
--- /dev/null
+++ b/middleware/kling_adapter.go
@@ -0,0 +1,45 @@
+package middleware
+
+import (
+ "bytes"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "io"
+ "one-api/common"
+)
+
+func KlingRequestConvert() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ var originalReq map[string]interface{}
+ if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
+ c.Next()
+ return
+ }
+
+ model, _ := originalReq["model"].(string)
+ prompt, _ := originalReq["prompt"].(string)
+
+ unifiedReq := map[string]interface{}{
+ "model": model,
+ "prompt": prompt,
+ "metadata": originalReq,
+ }
+
+ jsonData, err := json.Marshal(unifiedReq)
+ if err != nil {
+ c.Next()
+ return
+ }
+
+ // Rewrite request body and path
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
+ c.Request.URL.Path = "/v1/video/generations"
+ if image := originalReq["image"]; image == "" {
+ c.Set("action", "textGenerate")
+ }
+
+ // We have to reset the request body for the next handlers
+ c.Set(common.KeyRequestBody, jsonData)
+ c.Next()
+ }
+}
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
index 2995a07b..55f21196 100644
--- a/relay/channel/task/kling/adaptor.go
+++ b/relay/channel/task/kling/adaptor.go
@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "github.com/samber/lo"
"io"
"net/http"
"strings"
@@ -41,7 +42,6 @@ type requestPayload struct {
Mode string `json:"mode,omitempty"`
Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
- Model string `json:"model,omitempty"`
ModelName string `json:"model_name,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"`
}
@@ -100,7 +100,8 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
- return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
+ path := lo.Ternary(info.Action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
+ return fmt.Sprintf("%s%s", a.baseURL, path), nil
}
// BuildRequestHeader sets required headers.
@@ -125,7 +126,10 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
}
req := v.(SubmitReq)
- body := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req)
+ if err != nil {
+ return nil, err
+ }
data, err := json.Marshal(body)
if err != nil {
return nil, err
@@ -135,6 +139,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+ if action := c.GetString("action"); action != "" {
+ info.Action = action
+ }
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
@@ -175,7 +182,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
- url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
+ action, ok := body["action"].(string)
+ if !ok {
+ return nil, fmt.Errorf("invalid action")
+ }
+ path := lo.Ternary(action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
+ url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
@@ -210,22 +222,29 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
-func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
- r := &requestPayload{
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+ r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
- Model: req.Model,
ModelName: req.Model,
CfgScale: 0.5,
}
- if r.Model == "" {
- r.Model = "kling-v1"
+ if r.ModelName == "" {
r.ModelName = "kling-v1"
}
- return r
+ metadata := req.Metadata
+ medaBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return nil, errors.Wrap(err, "metadata marshal metadata failed")
+ }
+ err = json.Unmarshal(medaBytes, &r)
+ if err != nil {
+ return nil, errors.Wrap(err, "unmarshal metadata failed")
+ }
+ return &r, nil
}
func (a *TaskAdaptor) getAspectRatio(size string) string {
diff --git a/router/video-router.go b/router/video-router.go
index 7201c34a..9e605d54 100644
--- a/router/video-router.go
+++ b/router/video-router.go
@@ -14,4 +14,11 @@ func SetVideoRouter(router *gin.Engine) {
videoV1Router.POST("/video/generations", controller.RelayTask)
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
}
+
+ klingV1Router := router.Group("/kling/v1")
+ klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
+ {
+ klingV1Router.POST("/videos/text2video", controller.RelayTask)
+ klingV1Router.POST("/videos/image2video", controller.RelayTask)
+ }
}
diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js
index 8b309942..5b77ce39 100644
--- a/web/src/components/table/TaskLogsTable.js
+++ b/web/src/components/table/TaskLogsTable.js
@@ -212,7 +212,13 @@ const LogsTable = () => {
case 'generate':
return (
}>
- {t('生成视频')}
+ {t('图生视频')}
+
+ );
+ case 'textGenerate':
+ return (
+ }>
+ {t('文生视频')}
);
default:
@@ -434,7 +440,7 @@ const LogsTable = () => {
fixed: 'right',
render: (text, record, index) => {
// 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
- const isVideoTask = record.action === 'generate';
+ const isVideoTask = record.action === 'generate' || record.action === 'textGenerate';
const isSuccess = record.status === 'SUCCESS';
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
if (isSuccess && isVideoTask && isUrl) {