feat: add video channel kling
This commit is contained in:
@@ -241,6 +241,7 @@ const (
|
|||||||
ChannelTypeXinference = 47
|
ChannelTypeXinference = 47
|
||||||
ChannelTypeXai = 48
|
ChannelTypeXai = 48
|
||||||
ChannelTypeCoze = 49
|
ChannelTypeCoze = 49
|
||||||
|
ChannelTypeKling = 50
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
|
|||||||
"", //47
|
"", //47
|
||||||
"https://api.x.ai", //48
|
"https://api.x.ai", //48
|
||||||
"https://api.coze.cn", //49
|
"https://api.coze.cn", //49
|
||||||
|
"https://api.klingai.com", //50
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ type TaskPlatform string
|
|||||||
const (
|
const (
|
||||||
TaskPlatformSuno TaskPlatform = "suno"
|
TaskPlatformSuno TaskPlatform = "suno"
|
||||||
TaskPlatformMidjourney = "mj"
|
TaskPlatformMidjourney = "mj"
|
||||||
|
TaskPlatformKling TaskPlatform = "kling"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||||
var err *dto.TaskError
|
var err *dto.TaskError
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||||
err = relay.RelayTaskFetch(c, relayMode)
|
err = relay.RelayTaskFetch(c, relayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayTaskSubmit(c, relayMode)
|
err = relay.RelayTaskSubmit(c, relayMode)
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
|||||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||||
case constant.TaskPlatformSuno:
|
case constant.TaskPlatformSuno:
|
||||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
|
case constant.TaskPlatformKling:
|
||||||
|
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
default:
|
default:
|
||||||
common.SysLog("未知平台")
|
common.SysLog("未知平台")
|
||||||
}
|
}
|
||||||
|
|||||||
142
controller/task_video.go
Normal file
142
controller/task_video.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/relay"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||||
|
for channelId, taskIds := range taskChannelM {
|
||||||
|
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||||
|
if len(taskIds) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||||
|
if err != nil {
|
||||||
|
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||||
|
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||||
|
"status": "FAILURE",
|
||||||
|
"progress": "100%",
|
||||||
|
})
|
||||||
|
if errUpdate != nil {
|
||||||
|
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||||
|
}
|
||||||
|
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||||
|
}
|
||||||
|
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
|
||||||
|
if adaptor == nil {
|
||||||
|
return fmt.Errorf("video adaptor not found")
|
||||||
|
}
|
||||||
|
for _, taskId := range taskIds {
|
||||||
|
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||||
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||||
|
if channel.GetBaseURL() != "" {
|
||||||
|
baseURL = channel.GetBaseURL()
|
||||||
|
}
|
||||||
|
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||||
|
"task_id": taskId,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseItem map[string]interface{}
|
||||||
|
err = json.Unmarshal(responseBody, &responseItem)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
|
||||||
|
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code, _ := responseItem["code"].(float64)
|
||||||
|
if code != 0 {
|
||||||
|
return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := responseItem["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
|
||||||
|
return fmt.Errorf("video task data format error for task %s", taskId)
|
||||||
|
}
|
||||||
|
|
||||||
|
task := taskM[taskId]
|
||||||
|
if task == nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||||
|
return fmt.Errorf("task %s not found", taskId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status, ok := data["task_status"].(string); ok {
|
||||||
|
switch status {
|
||||||
|
case "submitted", "queued":
|
||||||
|
task.Status = model.TaskStatusSubmitted
|
||||||
|
case "processing":
|
||||||
|
task.Status = model.TaskStatusInProgress
|
||||||
|
case "succeed":
|
||||||
|
task.Status = model.TaskStatusSuccess
|
||||||
|
task.Progress = "100%"
|
||||||
|
if url, err := adaptor.(interface {
|
||||||
|
ParseResultUrl(map[string]any) (string, error)
|
||||||
|
}).ParseResultUrl(responseItem); err == nil {
|
||||||
|
task.FailReason = url
|
||||||
|
} else {
|
||||||
|
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
|
||||||
|
}
|
||||||
|
case "failed":
|
||||||
|
task.Status = model.TaskStatusFailure
|
||||||
|
task.Progress = "100%"
|
||||||
|
if reason, ok := data["fail_reason"].(string); ok {
|
||||||
|
task.FailReason = reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If task failed, refund quota
|
||||||
|
if task.Status == model.TaskStatusFailure {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||||
|
quota := task.Quota
|
||||||
|
if quota != 0 {
|
||||||
|
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||||
|
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||||
|
}
|
||||||
|
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
task.Data = responseBody
|
||||||
|
if err := task.Update(); err != nil {
|
||||||
|
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
47
dto/video.go
Normal file
47
dto/video.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type VideoRequest struct {
|
||||||
|
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
|
||||||
|
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
|
||||||
|
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
|
||||||
|
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
|
||||||
|
Width int `json:"width" example:"512"` // Video width
|
||||||
|
Height int `json:"height" example:"512"` // Video height
|
||||||
|
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
|
||||||
|
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
|
||||||
|
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
|
||||||
|
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
|
||||||
|
User string `json:"user,omitempty" example:"user-1234"` // User identifier
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoResponse 视频生成提交任务后的响应
|
||||||
|
type VideoResponse struct {
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskResponse 查询视频生成任务状态的响应
|
||||||
|
type VideoTaskResponse struct {
|
||||||
|
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
|
||||||
|
Status string `json:"status" example:"succeeded"` // 任务状态
|
||||||
|
Url string `json:"url,omitempty"` // 视频资源URL(成功时)
|
||||||
|
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
|
||||||
|
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
|
||||||
|
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskMetadata 视频任务元数据
|
||||||
|
type VideoTaskMetadata struct {
|
||||||
|
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
|
||||||
|
Fps int `json:"fps" example:"30"` // 实际帧率
|
||||||
|
Width int `json:"width" example:"512"` // 实际宽度
|
||||||
|
Height int `json:"height" example:"512"` // 实际高度
|
||||||
|
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskError 视频任务错误信息
|
||||||
|
type VideoTaskError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
@@ -170,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
}
|
}
|
||||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
|
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||||
|
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||||
|
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||||
|
shouldSelectChannel = false
|
||||||
|
} else {
|
||||||
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
|
}
|
||||||
|
c.Set("platform", string(constant.TaskPlatformKling))
|
||||||
|
c.Set("relay_mode", relayMode)
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||||
relayMode := relayconstant.RelayModeGemini
|
relayMode := relayconstant.RelayModeGemini
|
||||||
|
|||||||
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
|
|||||||
|
|
||||||
// FetchTask
|
// FetchTask
|
||||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||||
|
|
||||||
|
ParseResultUrl(resp map[string]any) (string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
312
relay/channel/task/kling/adaptor.go
Normal file
312
relay/channel/task/kling/adaptor.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
package kling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Request / Response structures
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type SubmitReq struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Mode string `json:"mode,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Duration int `json:"duration,omitempty"`
|
||||||
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestPayload struct {
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
Mode string `json:"mode,omitempty"`
|
||||||
|
Duration string `json:"duration,omitempty"`
|
||||||
|
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
ModelName string `json:"model_name,omitempty"`
|
||||||
|
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type responsePayload struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data struct {
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Adaptor implementation
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type TaskAdaptor struct {
|
||||||
|
ChannelType int
|
||||||
|
accessKey string
|
||||||
|
secretKey string
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
|
a.ChannelType = info.ChannelType
|
||||||
|
a.baseURL = info.BaseUrl
|
||||||
|
|
||||||
|
// apiKey format: "access_key,secret_key"
|
||||||
|
keyParts := strings.Split(info.ApiKey, ",")
|
||||||
|
if len(keyParts) == 2 {
|
||||||
|
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||||
|
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||||
|
// Accept only POST /v1/video/generations as "generate" action.
|
||||||
|
action := "generate"
|
||||||
|
info.Action = action
|
||||||
|
|
||||||
|
var req SubmitReq
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Prompt) == "" {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store into context for later usage
|
||||||
|
c.Set("kling_request", req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestURL constructs the upstream URL.
|
||||||
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||||
|
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestHeader sets required headers.
|
||||||
|
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||||
|
token, err := a.createJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
token = info.ApiKey // fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestBody converts request into Kling specific format.
|
||||||
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||||
|
v, exists := c.Get("kling_request")
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("request not found in context")
|
||||||
|
}
|
||||||
|
req := v.(SubmitReq)
|
||||||
|
|
||||||
|
body := a.convertToRequestPayload(&req)
|
||||||
|
data, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return bytes.NewReader(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRequest delegates to common helper.
|
||||||
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoResponse handles upstream response, returns taskID etc.
|
||||||
|
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt Kling response parse first.
|
||||||
|
var kResp responsePayload
|
||||||
|
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
|
||||||
|
return kResp.Data.TaskID, responseBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback generic task response.
|
||||||
|
var generic dto.TaskResponse[string]
|
||||||
|
if err := json.Unmarshal(responseBody, &generic); err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generic.IsSuccess() {
|
||||||
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
|
||||||
|
return generic.Data, responseBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchTask fetch task status
|
||||||
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||||
|
taskID, ok := body["task_id"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid task_id")
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := a.createJWTTokenWithKey(key)
|
||||||
|
if err != nil {
|
||||||
|
token = key
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||||
|
|
||||||
|
return service.GetHttpClient().Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetModelList() []string {
|
||||||
|
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetChannelName() string {
|
||||||
|
return "kling"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// helpers
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
|
||||||
|
r := &requestPayload{
|
||||||
|
Prompt: req.Prompt,
|
||||||
|
Image: req.Image,
|
||||||
|
Mode: defaultString(req.Mode, "std"),
|
||||||
|
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||||
|
AspectRatio: a.getAspectRatio(req.Size),
|
||||||
|
Model: req.Model,
|
||||||
|
ModelName: req.Model,
|
||||||
|
CfgScale: 0.5,
|
||||||
|
}
|
||||||
|
if r.Model == "" {
|
||||||
|
r.Model = "kling-v1"
|
||||||
|
r.ModelName = "kling-v1"
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||||||
|
switch size {
|
||||||
|
case "1024x1024", "512x512":
|
||||||
|
return "1:1"
|
||||||
|
case "1280x720", "1920x1080":
|
||||||
|
return "16:9"
|
||||||
|
case "720x1280", "1080x1920":
|
||||||
|
return "9:16"
|
||||||
|
default:
|
||||||
|
return "1:1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultString(s, def string) string {
|
||||||
|
if strings.TrimSpace(s) == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultInt(v int, def int) int {
|
||||||
|
if v == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// JWT helpers
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||||
|
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||||
|
parts := strings.Split(apiKey, ",")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
||||||
|
}
|
||||||
|
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
|
||||||
|
if accessKey == "" || secretKey == "" {
|
||||||
|
return "", fmt.Errorf("access key and secret key are required")
|
||||||
|
}
|
||||||
|
now := time.Now().Unix()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": accessKey,
|
||||||
|
"exp": now + 1800, // 30 minutes
|
||||||
|
"nbf": now - 5,
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
token.Header["typ"] = "JWT"
|
||||||
|
return token.SignedString([]byte(secretKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseResultUrl 提取视频任务结果的 url
|
||||||
|
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||||
|
data, ok := resp["data"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("data field not found or invalid")
|
||||||
|
}
|
||||||
|
taskResult, ok := data["task_result"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("task_result field not found or invalid")
|
||||||
|
}
|
||||||
|
videos, ok := taskResult["videos"].([]interface{})
|
||||||
|
if !ok || len(videos) == 0 {
|
||||||
|
return "", fmt.Errorf("videos field not found or empty")
|
||||||
|
}
|
||||||
|
video, ok := videos[0].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("video item invalid")
|
||||||
|
}
|
||||||
|
url, ok := video["url"].(string)
|
||||||
|
if !ok || url == "" {
|
||||||
|
return "", fmt.Errorf("url field not found or invalid")
|
||||||
|
}
|
||||||
|
return url, nil
|
||||||
|
}
|
||||||
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
|
|||||||
ChannelType int
|
ChannelType int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||||
|
return "", nil // todo implement this method if needed
|
||||||
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
a.ChannelType = info.ChannelType
|
a.ChannelType = info.ChannelType
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ const (
|
|||||||
RelayModeSunoFetchByID
|
RelayModeSunoFetchByID
|
||||||
RelayModeSunoSubmit
|
RelayModeSunoSubmit
|
||||||
|
|
||||||
|
RelayModeKlingFetchByID
|
||||||
|
RelayModeKlingSubmit
|
||||||
|
|
||||||
RelayModeRerank
|
RelayModeRerank
|
||||||
|
|
||||||
RelayModeResponses
|
RelayModeResponses
|
||||||
@@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int {
|
|||||||
}
|
}
|
||||||
return relayMode
|
return relayMode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Path2RelayKling(method, path string) int {
|
||||||
|
relayMode := RelayModeUnknown
|
||||||
|
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
|
||||||
|
relayMode = RelayModeKlingSubmit
|
||||||
|
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
|
||||||
|
relayMode = RelayModeKlingFetchByID
|
||||||
|
}
|
||||||
|
return relayMode
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"one-api/relay/channel/palm"
|
"one-api/relay/channel/palm"
|
||||||
"one-api/relay/channel/perplexity"
|
"one-api/relay/channel/perplexity"
|
||||||
"one-api/relay/channel/siliconflow"
|
"one-api/relay/channel/siliconflow"
|
||||||
|
"one-api/relay/channel/task/kling"
|
||||||
"one-api/relay/channel/task/suno"
|
"one-api/relay/channel/task/suno"
|
||||||
"one-api/relay/channel/tencent"
|
"one-api/relay/channel/tencent"
|
||||||
"one-api/relay/channel/vertex"
|
"one-api/relay/channel/vertex"
|
||||||
@@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
|
|||||||
// return &aiproxy.Adaptor{}
|
// return &aiproxy.Adaptor{}
|
||||||
case commonconstant.TaskPlatformSuno:
|
case commonconstant.TaskPlatformSuno:
|
||||||
return &suno.TaskAdaptor{}
|
return &suno.TaskAdaptor{}
|
||||||
|
case commonconstant.TaskPlatformKling:
|
||||||
|
return &kling.TaskAdaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
||||||
|
if platform == constant.TaskPlatformKling {
|
||||||
|
modelName = relayInfo.OriginModelName
|
||||||
|
}
|
||||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||||
if !success {
|
if !success {
|
||||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
||||||
@@ -136,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
}
|
}
|
||||||
relayInfo.ConsumeQuota = true
|
relayInfo.ConsumeQuota = true
|
||||||
// insert task
|
// insert task
|
||||||
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
|
task := model.InitTask(platform, relayInfo)
|
||||||
task.TaskID = taskID
|
task.TaskID = taskID
|
||||||
task.Quota = quota
|
task.Quota = quota
|
||||||
task.Data = taskData
|
task.Data = taskData
|
||||||
|
task.Action = relayInfo.Action
|
||||||
err = task.Insert()
|
err = task.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||||
@@ -149,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||||
|
relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||||
@@ -225,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||||
|
taskId := c.Param("id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||||
|
if err != nil {
|
||||||
|
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !exist {
|
||||||
|
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||||
|
Code: "success",
|
||||||
|
Data: TaskModel2Dto(originTask),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
||||||
return &dto.TaskDto{
|
return &dto.TaskDto{
|
||||||
TaskID: task.TaskID,
|
TaskID: task.TaskID,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
|||||||
SetApiRouter(router)
|
SetApiRouter(router)
|
||||||
SetDashboardRouter(router)
|
SetDashboardRouter(router)
|
||||||
SetRelayRouter(router)
|
SetRelayRouter(router)
|
||||||
|
SetVideoRouter(router)
|
||||||
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
|
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
|
||||||
if common.IsMasterNode && frontendBaseUrl != "" {
|
if common.IsMasterNode && frontendBaseUrl != "" {
|
||||||
frontendBaseUrl = ""
|
frontendBaseUrl = ""
|
||||||
|
|||||||
17
router/video-router.go
Normal file
17
router/video-router.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/controller"
|
||||||
|
"one-api/middleware"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetVideoRouter(router *gin.Engine) {
|
||||||
|
videoV1Router := router.Group("/v1")
|
||||||
|
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||||
|
{
|
||||||
|
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||||
|
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,7 +11,9 @@ import {
|
|||||||
XCircle,
|
XCircle,
|
||||||
Loader,
|
Loader,
|
||||||
List,
|
List,
|
||||||
Hash
|
Hash,
|
||||||
|
Video,
|
||||||
|
Sparkles
|
||||||
} from 'lucide-react';
|
} from 'lucide-react';
|
||||||
import {
|
import {
|
||||||
API,
|
API,
|
||||||
@@ -80,6 +82,7 @@ const COLUMN_KEYS = {
|
|||||||
TASK_STATUS: 'task_status',
|
TASK_STATUS: 'task_status',
|
||||||
PROGRESS: 'progress',
|
PROGRESS: 'progress',
|
||||||
FAIL_REASON: 'fail_reason',
|
FAIL_REASON: 'fail_reason',
|
||||||
|
RESULT_URL: 'result_url',
|
||||||
};
|
};
|
||||||
|
|
||||||
const renderTimestamp = (timestampInSeconds) => {
|
const renderTimestamp = (timestampInSeconds) => {
|
||||||
@@ -150,6 +153,7 @@ const LogsTable = () => {
|
|||||||
[COLUMN_KEYS.TASK_STATUS]: true,
|
[COLUMN_KEYS.TASK_STATUS]: true,
|
||||||
[COLUMN_KEYS.PROGRESS]: true,
|
[COLUMN_KEYS.PROGRESS]: true,
|
||||||
[COLUMN_KEYS.FAIL_REASON]: true,
|
[COLUMN_KEYS.FAIL_REASON]: true,
|
||||||
|
[COLUMN_KEYS.RESULT_URL]: true,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -203,6 +207,12 @@ const LogsTable = () => {
|
|||||||
{t('生成歌词')}
|
{t('生成歌词')}
|
||||||
</Tag>
|
</Tag>
|
||||||
);
|
);
|
||||||
|
case 'generate':
|
||||||
|
return (
|
||||||
|
<Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
|
||||||
|
{t('生成视频')}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
default:
|
default:
|
||||||
return (
|
return (
|
||||||
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||||
@@ -220,6 +230,12 @@ const LogsTable = () => {
|
|||||||
Suno
|
Suno
|
||||||
</Tag>
|
</Tag>
|
||||||
);
|
);
|
||||||
|
case 'kling':
|
||||||
|
return (
|
||||||
|
<Tag color='blue' size='large' shape='circle' prefixIcon={<Video size={14} />}>
|
||||||
|
Kling
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
default:
|
default:
|
||||||
return (
|
return (
|
||||||
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||||
@@ -411,10 +427,21 @@ const LogsTable = () => {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
key: COLUMN_KEYS.FAIL_REASON,
|
key: COLUMN_KEYS.FAIL_REASON,
|
||||||
title: t('失败原因'),
|
title: t('详情'),
|
||||||
dataIndex: 'fail_reason',
|
dataIndex: 'fail_reason',
|
||||||
fixed: 'right',
|
fixed: 'right',
|
||||||
render: (text, record, index) => {
|
render: (text, record, index) => {
|
||||||
|
// 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
|
||||||
|
const isVideoTask = record.action === 'generate';
|
||||||
|
const isSuccess = record.status === 'SUCCESS';
|
||||||
|
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
|
||||||
|
if (isSuccess && isVideoTask && isUrl) {
|
||||||
|
return (
|
||||||
|
<a href={text} target="_blank" rel="noopener noreferrer">
|
||||||
|
{t('点击预览视频')}
|
||||||
|
</a>
|
||||||
|
);
|
||||||
|
}
|
||||||
if (!text) {
|
if (!text) {
|
||||||
return t('无');
|
return t('无');
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,4 +125,9 @@ export const CHANNEL_OPTIONS = [
|
|||||||
color: 'blue',
|
color: 'blue',
|
||||||
label: 'Coze',
|
label: 'Coze',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
value: 50,
|
||||||
|
color: 'green',
|
||||||
|
label: '可灵',
|
||||||
|
},
|
||||||
];
|
];
|
||||||
|
|||||||
Reference in New Issue
Block a user