diff --git a/common/constants.go b/common/constants.go
index bee00506..ac803148 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -241,6 +241,7 @@ const (
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
+ ChannelTypeKling = 50
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
+ "https://api.klingai.com", //50
}
diff --git a/constant/task.go b/constant/task.go
index 1a68b812..d466fc8a 100644
--- a/constant/task.go
+++ b/constant/task.go
@@ -5,6 +5,7 @@ type TaskPlatform string
const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
+ TaskPlatformKling TaskPlatform = "kling"
)
const (
diff --git a/controller/relay.go b/controller/relay.go
index c1c45114..4da4262b 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
- case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
+ case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)
diff --git a/controller/task.go b/controller/task.go
index 34e14f3f..f7523e87 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
+ case constant.TaskPlatformKling:
+ _ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
default:
common.SysLog("未知平台")
}
diff --git a/controller/task_video.go b/controller/task_video.go
new file mode 100644
index 00000000..3f2c9588
--- /dev/null
+++ b/controller/task_video.go
@@ -0,0 +1,142 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/model"
+ "one-api/relay"
+ "one-api/relay/channel"
+)
+
+func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+ for channelId, taskIds := range taskChannelM {
+ if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
+ common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+ }
+ }
+ return nil
+}
+
+func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+ common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+ if len(taskIds) == 0 {
+ return nil
+ }
+ cacheGetChannel, err := model.CacheGetChannel(channelId)
+ if err != nil {
+ errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
+ "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if errUpdate != nil {
+ common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+ }
+ return fmt.Errorf("CacheGetChannel failed: %w", err)
+ }
+ adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
+ if adaptor == nil {
+ return fmt.Errorf("video adaptor not found")
+ }
+ for _, taskId := range taskIds {
+ if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+ common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+ }
+ }
+ return nil
+}
+
+func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
+ baseURL := common.ChannelBaseURLs[channel.Type]
+ if channel.GetBaseURL() != "" {
+ baseURL = channel.GetBaseURL()
+ }
+ resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
+ "task_id": taskId,
+ })
+ if err != nil {
+ return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
+ }
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
+ }
+ defer resp.Body.Close()
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
+ }
+
+ var responseItem map[string]interface{}
+ err = json.Unmarshal(responseBody, &responseItem)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
+ return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
+ }
+
+ code, _ := responseItem["code"].(float64)
+ if code != 0 {
+ return fmt.Errorf("video task fetch failed for task %s", taskId)
+ }
+
+ data, ok := responseItem["data"].(map[string]interface{})
+ if !ok {
+ common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
+ return fmt.Errorf("video task data format error for task %s", taskId)
+ }
+
+ task := taskM[taskId]
+ if task == nil {
+ common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+ return fmt.Errorf("task %s not found", taskId)
+ }
+
+ if status, ok := data["task_status"].(string); ok {
+ switch status {
+ case "submitted", "queued":
+ task.Status = model.TaskStatusSubmitted
+ case "processing":
+ task.Status = model.TaskStatusInProgress
+ case "succeed":
+ task.Status = model.TaskStatusSuccess
+ task.Progress = "100%"
+ if url, err := adaptor.(interface {
+ ParseResultUrl(map[string]any) (string, error)
+ }).ParseResultUrl(responseItem); err == nil {
+ task.FailReason = url
+ } else {
+ common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
+ }
+ case "failed":
+ task.Status = model.TaskStatusFailure
+ task.Progress = "100%"
+ if reason, ok := data["fail_reason"].(string); ok {
+ task.FailReason = reason
+ }
+ }
+ }
+
+ // If task failed, refund quota
+ if task.Status == model.TaskStatusFailure {
+ common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+ quota := task.Quota
+ if quota != 0 {
+ if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
+ common.LogError(ctx, "Failed to increase user quota: "+err.Error())
+ }
+ logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ }
+ }
+
+ task.Data = responseBody
+ if err := task.Update(); err != nil {
+ common.SysError("UpdateVideoTask task error: " + err.Error())
+ }
+
+ return nil
+}
diff --git a/dto/video.go b/dto/video.go
new file mode 100644
index 00000000..5b48146a
--- /dev/null
+++ b/dto/video.go
@@ -0,0 +1,47 @@
+package dto
+
+type VideoRequest struct {
+ Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
+ Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
+ Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
+ Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
+ Width int `json:"width" example:"512"` // Video width
+ Height int `json:"height" example:"512"` // Video height
+ Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
+ Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
+ N int `json:"n,omitempty" example:"1"` // Number of videos to generate
+ ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
+ User string `json:"user,omitempty" example:"user-1234"` // User identifier
+ Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
+}
+
+// VideoResponse 视频生成提交任务后的响应
+type VideoResponse struct {
+ TaskId string `json:"task_id"`
+ Status string `json:"status"`
+}
+
+// VideoTaskResponse 查询视频生成任务状态的响应
+type VideoTaskResponse struct {
+ TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
+ Status string `json:"status" example:"succeeded"` // 任务状态
+ Url string `json:"url,omitempty"` // 视频资源URL(成功时)
+ Format string `json:"format,omitempty" example:"mp4"` // 视频格式
+ Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
+ Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
+}
+
+// VideoTaskMetadata 视频任务元数据
+type VideoTaskMetadata struct {
+ Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
+ Fps int `json:"fps" example:"30"` // 实际帧率
+ Width int `json:"width" example:"512"` // 实际宽度
+ Height int `json:"height" example:"512"` // 实际高度
+ Seed int `json:"seed" example:"20231234"` // 使用的随机种子
+}
+
+// VideoTaskError 视频任务错误信息
+type VideoTaskError struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 84eb182e..9d074ce8 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -170,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
+ } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
+ relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeKlingFetchByID {
+ shouldSelectChannel = false
+ } else {
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ }
+ c.Set("platform", string(constant.TaskPlatformKling))
+ c.Set("relay_mode", relayMode)
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
relayMode := relayconstant.RelayModeGemini
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index 50255d0a..873997f6 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
+
+ ParseResultUrl(resp map[string]any) (string, error)
}
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
new file mode 100644
index 00000000..9c6773f5
--- /dev/null
+++ b/relay/channel/task/kling/adaptor.go
@@ -0,0 +1,312 @@
+package kling
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt"
+ "github.com/pkg/errors"
+
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type SubmitReq struct {
+ Prompt string `json:"prompt"`
+ Model string `json:"model,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Image string `json:"image,omitempty"`
+ Size string `json:"size,omitempty"`
+ Duration int `json:"duration,omitempty"`
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+type requestPayload struct {
+ Prompt string `json:"prompt,omitempty"`
+ Image string `json:"image,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Duration string `json:"duration,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty"`
+ Model string `json:"model,omitempty"`
+ ModelName string `json:"model_name,omitempty"`
+ CfgScale float64 `json:"cfg_scale,omitempty"`
+}
+
+type responsePayload struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data struct {
+ TaskID string `json:"task_id"`
+ } `json:"data"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+ ChannelType int
+ accessKey string
+ secretKey string
+ baseURL string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+ a.ChannelType = info.ChannelType
+ a.baseURL = info.BaseUrl
+
+ // apiKey format: "access_key,secret_key"
+ keyParts := strings.Split(info.ApiKey, ",")
+ if len(keyParts) == 2 {
+ a.accessKey = strings.TrimSpace(keyParts[0])
+ a.secretKey = strings.TrimSpace(keyParts[1])
+ }
+}
+
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+ // Accept only POST /v1/video/generations as "generate" action.
+ action := "generate"
+ info.Action = action
+
+ var req SubmitReq
+ if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+ return
+ }
+ if strings.TrimSpace(req.Prompt) == "" {
+ taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
+ return
+ }
+
+ // Store into context for later usage
+ c.Set("kling_request", req)
+ return nil
+}
+
+// BuildRequestURL constructs the upstream URL.
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+ return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
+}
+
+// BuildRequestHeader sets required headers.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+ token, err := a.createJWTToken()
+ if err != nil {
+ token = info.ApiKey // fallback
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", "kling-sdk/1.0")
+ return nil
+}
+
+// BuildRequestBody converts request into Kling specific format.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+ v, exists := c.Get("kling_request")
+ if !exists {
+ return nil, fmt.Errorf("request not found in context")
+ }
+ req := v.(SubmitReq)
+
+ body := a.convertToRequestPayload(&req)
+ data, err := json.Marshal(body)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to common helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles upstream response, returns taskID etc.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Attempt Kling response parse first.
+ var kResp responsePayload
+ if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
+ c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
+ return kResp.Data.TaskID, responseBody, nil
+ }
+
+ // Fallback generic task response.
+ var generic dto.TaskResponse[string]
+ if err := json.Unmarshal(responseBody, &generic); err != nil {
+ taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ if !generic.IsSuccess() {
+ taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
+ return generic.Data, responseBody, nil
+}
+
+// FetchTask fetch task status
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+ taskID, ok := body["task_id"].(string)
+ if !ok {
+ return nil, fmt.Errorf("invalid task_id")
+ }
+ url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
+
+ req, err := http.NewRequest(http.MethodGet, url, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ token, err := a.createJWTTokenWithKey(key)
+ if err != nil {
+ token = key
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancel()
+
+ req = req.WithContext(ctx)
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", "kling-sdk/1.0")
+
+ return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+ return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+ return "kling"
+}
+
+// ============================
+// helpers
+// ============================
+
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
+ r := &requestPayload{
+ Prompt: req.Prompt,
+ Image: req.Image,
+ Mode: defaultString(req.Mode, "std"),
+ Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+ AspectRatio: a.getAspectRatio(req.Size),
+ Model: req.Model,
+ ModelName: req.Model,
+ CfgScale: 0.5,
+ }
+ if r.Model == "" {
+ r.Model = "kling-v1"
+ r.ModelName = "kling-v1"
+ }
+ return r
+}
+
+func (a *TaskAdaptor) getAspectRatio(size string) string {
+ switch size {
+ case "1024x1024", "512x512":
+ return "1:1"
+ case "1280x720", "1920x1080":
+ return "16:9"
+ case "720x1280", "1080x1920":
+ return "9:16"
+ default:
+ return "1:1"
+ }
+}
+
+func defaultString(s, def string) string {
+ if strings.TrimSpace(s) == "" {
+ return def
+ }
+ return s
+}
+
+func defaultInt(v int, def int) int {
+ if v == 0 {
+ return def
+ }
+ return v
+}
+
+// ============================
+// JWT helpers
+// ============================
+
+func (a *TaskAdaptor) createJWTToken() (string, error) {
+ return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
+}
+
+func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
+ parts := strings.Split(apiKey, ",")
+ if len(parts) != 2 {
+ return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
+ }
+ return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
+}
+
+func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
+ if accessKey == "" || secretKey == "" {
+ return "", fmt.Errorf("access key and secret key are required")
+ }
+ now := time.Now().Unix()
+ claims := jwt.MapClaims{
+ "iss": accessKey,
+ "exp": now + 1800, // 30 minutes
+ "nbf": now - 5,
+ }
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ token.Header["typ"] = "JWT"
+ return token.SignedString([]byte(secretKey))
+}
+
+// ParseResultUrl 提取视频任务结果的 url
+func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
+ data, ok := resp["data"].(map[string]any)
+ if !ok {
+ return "", fmt.Errorf("data field not found or invalid")
+ }
+ taskResult, ok := data["task_result"].(map[string]any)
+ if !ok {
+ return "", fmt.Errorf("task_result field not found or invalid")
+ }
+ videos, ok := taskResult["videos"].([]interface{})
+ if !ok || len(videos) == 0 {
+ return "", fmt.Errorf("videos field not found or empty")
+ }
+ video, ok := videos[0].(map[string]interface{})
+ if !ok {
+ return "", fmt.Errorf("video item invalid")
+ }
+ url, ok := video["url"].(string)
+ if !ok || url == "" {
+ return "", fmt.Errorf("url field not found or invalid")
+ }
+ return url, nil
+}
diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go
index 03d60516..f7042348 100644
--- a/relay/channel/task/suno/adaptor.go
+++ b/relay/channel/task/suno/adaptor.go
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
ChannelType int
}
+func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
+ return "", nil // todo implement this method if needed
+}
+
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index f22a20bd..02a286e2 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -38,6 +38,9 @@ const (
RelayModeSunoFetchByID
RelayModeSunoSubmit
+ RelayModeKlingFetchByID
+ RelayModeKlingSubmit
+
RelayModeRerank
RelayModeResponses
@@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int {
}
return relayMode
}
+
+func Path2RelayKling(method, path string) int {
+ relayMode := RelayModeUnknown
+ if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
+ relayMode = RelayModeKlingSubmit
+ } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
+ relayMode = RelayModeKlingFetchByID
+ }
+ return relayMode
+}
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 7bf0da9f..626bb7e4 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -22,6 +22,7 @@ import (
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
+ "one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
@@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
// return &aiproxy.Adaptor{}
case commonconstant.TaskPlatformSuno:
return &suno.TaskAdaptor{}
+ case commonconstant.TaskPlatformKling:
+ return &kling.TaskAdaptor{}
}
return nil
}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index 3da9a20f..245fd681 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -37,6 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
+ if platform == constant.TaskPlatformKling {
+ modelName = relayInfo.OriginModelName
+ }
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
if !success {
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
@@ -136,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
relayInfo.ConsumeQuota = true
// insert task
- task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
+ task := model.InitTask(platform, relayInfo)
task.TaskID = taskID
task.Quota = quota
task.Data = taskData
+ task.Action = relayInfo.Action
err = task.Insert()
if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
@@ -149,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
- relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
- relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
+ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
+ relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
+ relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
}
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
@@ -225,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
return
}
+func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+ taskId := c.Param("id")
+ userId := c.GetInt("id")
+
+ originTask, exist, err := model.GetByTaskId(userId, taskId)
+ if err != nil {
+ taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
+ return
+ }
+ if !exist {
+ taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
+ return
+ }
+
+ respBody, err = json.Marshal(dto.TaskResponse[any]{
+ Code: "success",
+ Data: TaskModel2Dto(originTask),
+ })
+ return
+}
+
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
TaskID: task.TaskID,
diff --git a/router/main.go b/router/main.go
index b8ac4055..0d2bfdce 100644
--- a/router/main.go
+++ b/router/main.go
@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router)
SetDashboardRouter(router)
SetRelayRouter(router)
+ SetVideoRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
diff --git a/router/video-router.go b/router/video-router.go
new file mode 100644
index 00000000..7201c34a
--- /dev/null
+++ b/router/video-router.go
@@ -0,0 +1,17 @@
+package router
+
+import (
+ "one-api/controller"
+ "one-api/middleware"
+
+ "github.com/gin-gonic/gin"
+)
+
+func SetVideoRouter(router *gin.Engine) {
+ videoV1Router := router.Group("/v1")
+ videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
+ {
+ videoV1Router.POST("/video/generations", controller.RelayTask)
+ videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+ }
+}
diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js
index b3d0ab7b..37bdde57 100644
--- a/web/src/components/table/TaskLogsTable.js
+++ b/web/src/components/table/TaskLogsTable.js
@@ -11,7 +11,9 @@ import {
XCircle,
Loader,
List,
- Hash
+ Hash,
+ Video,
+ Sparkles
} from 'lucide-react';
import {
API,
@@ -80,6 +82,7 @@ const COLUMN_KEYS = {
TASK_STATUS: 'task_status',
PROGRESS: 'progress',
FAIL_REASON: 'fail_reason',
+ RESULT_URL: 'result_url',
};
const renderTimestamp = (timestampInSeconds) => {
@@ -150,6 +153,7 @@ const LogsTable = () => {
[COLUMN_KEYS.TASK_STATUS]: true,
[COLUMN_KEYS.PROGRESS]: true,
[COLUMN_KEYS.FAIL_REASON]: true,
+ [COLUMN_KEYS.RESULT_URL]: true,
};
};
@@ -203,6 +207,12 @@ const LogsTable = () => {
{t('生成歌词')}
);
+ case 'generate':
+ return (
+ }>
+ {t('生成视频')}
+
+ );
default:
return (
}>
@@ -220,6 +230,12 @@ const LogsTable = () => {
Suno
);
+ case 'kling':
+ return (
+ }>
+ Kling
+
+ );
default:
return (
}>
@@ -411,10 +427,21 @@ const LogsTable = () => {
},
{
key: COLUMN_KEYS.FAIL_REASON,
- title: t('失败原因'),
+ title: t('详情'),
dataIndex: 'fail_reason',
fixed: 'right',
render: (text, record, index) => {
+ // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
+ const isVideoTask = record.action === 'generate';
+ const isSuccess = record.status === 'SUCCESS';
+ const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
+ if (isSuccess && isVideoTask && isUrl) {
+ return (
+
+ {t('点击预览视频')}
+
+ );
+ }
if (!text) {
return t('无');
}
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index 20fed5b7..c4220bd4 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -125,4 +125,9 @@ export const CHANNEL_OPTIONS = [
color: 'blue',
label: 'Coze',
},
+ {
+ value: 50,
+ color: 'green',
+ label: '可灵',
+ },
];