Merge pull request #1291 from feitianbubu/pr/add-origin-kling-api

feat: add origin kling api
This commit is contained in:
Xyfacai
2025-06-27 16:08:03 +08:00
committed by GitHub
5 changed files with 96 additions and 18 deletions

View File

@@ -56,8 +56,15 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
if channel.GetBaseURL() != "" { if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
} }
task := taskM[taskId]
if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId, "task_id": taskId,
"action": task.Action,
}) })
if err != nil { if err != nil {
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err) return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
@@ -89,12 +96,6 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
return fmt.Errorf("video task data format error for task %s", taskId) return fmt.Errorf("video task data format error for task %s", taskId)
} }
task := taskM[taskId]
if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
if status, ok := data["task_status"].(string); ok { if status, ok := data["task_status"].(string); ok {
switch status { switch status {
case "submitted", "queued": case "submitted", "queued":

View File

@@ -0,0 +1,45 @@
package middleware
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"one-api/common"
)
func KlingRequestConvert() func(c *gin.Context) {
return func(c *gin.Context) {
var originalReq map[string]interface{}
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
c.Next()
return
}
model, _ := originalReq["model"].(string)
prompt, _ := originalReq["prompt"].(string)
unifiedReq := map[string]interface{}{
"model": model,
"prompt": prompt,
"metadata": originalReq,
}
jsonData, err := json.Marshal(unifiedReq)
if err != nil {
c.Next()
return
}
// Rewrite request body and path
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
c.Request.URL.Path = "/v1/video/generations"
if image := originalReq["image"]; image == "" {
c.Set("action", "textGenerate")
}
// We have to reset the request body for the next handlers
c.Set(common.KeyRequestBody, jsonData)
c.Next()
}
}

View File

@@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/samber/lo"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -41,7 +42,6 @@ type requestPayload struct {
Mode string `json:"mode,omitempty"` Mode string `json:"mode,omitempty"`
Duration string `json:"duration,omitempty"` Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"` AspectRatio string `json:"aspect_ratio,omitempty"`
Model string `json:"model,omitempty"`
ModelName string `json:"model_name,omitempty"` ModelName string `json:"model_name,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"` CfgScale float64 `json:"cfg_scale,omitempty"`
} }
@@ -100,7 +100,8 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL. // BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil path := lo.Ternary(info.Action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
return fmt.Sprintf("%s%s", a.baseURL, path), nil
} }
// BuildRequestHeader sets required headers. // BuildRequestHeader sets required headers.
@@ -125,7 +126,10 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
} }
req := v.(SubmitReq) req := v.(SubmitReq)
body := a.convertToRequestPayload(&req) body, err := a.convertToRequestPayload(&req)
if err != nil {
return nil, err
}
data, err := json.Marshal(body) data, err := json.Marshal(body)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -135,6 +139,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
// DoRequest delegates to common helper. // DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
if action := c.GetString("action"); action != "" {
info.Action = action
}
return channel.DoTaskApiRequest(a, c, info, requestBody) return channel.DoTaskApiRequest(a, c, info, requestBody)
} }
@@ -175,7 +182,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
if !ok { if !ok {
return nil, fmt.Errorf("invalid task_id") return nil, fmt.Errorf("invalid task_id")
} }
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID) action, ok := body["action"].(string)
if !ok {
return nil, fmt.Errorf("invalid action")
}
path := lo.Ternary(action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil { if err != nil {
@@ -210,22 +222,29 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers // helpers
// ============================ // ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload { func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
r := &requestPayload{ r := requestPayload{
Prompt: req.Prompt, Prompt: req.Prompt,
Image: req.Image, Image: req.Image,
Mode: defaultString(req.Mode, "std"), Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size), AspectRatio: a.getAspectRatio(req.Size),
Model: req.Model,
ModelName: req.Model, ModelName: req.Model,
CfgScale: 0.5, CfgScale: 0.5,
} }
if r.Model == "" { if r.ModelName == "" {
r.Model = "kling-v1"
r.ModelName = "kling-v1" r.ModelName = "kling-v1"
} }
return r metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
} }
func (a *TaskAdaptor) getAspectRatio(size string) string { func (a *TaskAdaptor) getAspectRatio(size string) string {

View File

@@ -14,4 +14,11 @@ func SetVideoRouter(router *gin.Engine) {
videoV1Router.POST("/video/generations", controller.RelayTask) videoV1Router.POST("/video/generations", controller.RelayTask)
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
} }
klingV1Router := router.Group("/kling/v1")
klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
{
klingV1Router.POST("/videos/text2video", controller.RelayTask)
klingV1Router.POST("/videos/image2video", controller.RelayTask)
}
} }

View File

@@ -212,7 +212,13 @@ const LogsTable = () => {
case 'generate': case 'generate':
return ( return (
<Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}> <Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
{t('生视频')} {t('生视频')}
</Tag>
);
case 'textGenerate':
return (
<Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
{t('文生视频')}
</Tag> </Tag>
); );
default: default:
@@ -434,7 +440,7 @@ const LogsTable = () => {
fixed: 'right', fixed: 'right',
render: (text, record, index) => { render: (text, record, index) => {
// 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
const isVideoTask = record.action === 'generate'; const isVideoTask = record.action === 'generate' || record.action === 'textGenerate';
const isSuccess = record.status === 'SUCCESS'; const isSuccess = record.status === 'SUCCESS';
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
if (isSuccess && isVideoTask && isUrl) { if (isSuccess && isVideoTask && isUrl) {