diff --git a/common/constants.go b/common/constants.go index bee00506..ac803148 100644 --- a/common/constants.go +++ b/common/constants.go @@ -241,6 +241,7 @@ const ( ChannelTypeXinference = 47 ChannelTypeXai = 48 ChannelTypeCoze = 49 + ChannelTypeKling = 50 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{ "", //47 "https://api.x.ai", //48 "https://api.coze.cn", //49 + "https://api.klingai.com", //50 } diff --git a/constant/task.go b/constant/task.go index 1a68b812..d466fc8a 100644 --- a/constant/task.go +++ b/constant/task.go @@ -5,6 +5,7 @@ type TaskPlatform string const ( TaskPlatformSuno TaskPlatform = "suno" TaskPlatformMidjourney = "mj" + TaskPlatformKling TaskPlatform = "kling" ) const ( diff --git a/controller/relay.go b/controller/relay.go index c1c45114..4da4262b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) { func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { var err *dto.TaskError switch relayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID: + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID: err = relay.RelayTaskFetch(c, relayMode) default: err = relay.RelayTaskSubmit(c, relayMode) diff --git a/controller/task.go b/controller/task.go index 34e14f3f..f7523e87 100644 --- a/controller/task.go +++ b/controller/task.go @@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) case constant.TaskPlatformSuno: _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) + case constant.TaskPlatformKling: + _ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM) default: common.SysLog("未知平台") } diff --git a/controller/task_video.go b/controller/task_video.go new file mode 100644 index 00000000..3f2c9588 --- /dev/null +++ b/controller/task_video.go @@ -0,0 +1,142 @@ +package controller + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/constant" + "one-api/model" + "one-api/relay" + "one-api/relay/channel" +) + +func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil { + common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + } + } + return nil +} + +func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + cacheGetChannel, err := model.CacheGetChannel(channelId) + if err != nil { + errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ + "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if errUpdate != nil { + common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + } + return fmt.Errorf("CacheGetChannel failed: %w", err) + } + adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling) + if adaptor == nil { + return fmt.Errorf("video adaptor not found") + } + for _, taskId := range taskIds { + if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { + common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + } + } + return nil +} + +func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { + baseURL := common.ChannelBaseURLs[channel.Type] + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() + } + resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ + "task_id": taskId, + }) + if err != nil { + return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err) + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err) + } + + var responseItem map[string]interface{} + err = json.Unmarshal(responseBody, &responseItem) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody))) + return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err) + } + + code, _ := responseItem["code"].(float64) + if code != 0 { + return fmt.Errorf("video task fetch failed for task %s", taskId) + } + + data, ok := responseItem["data"].(map[string]interface{}) + if !ok { + common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody))) + return fmt.Errorf("video task data format error for task %s", taskId) + } + + task := taskM[taskId] + if task == nil { + common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } + + if status, ok := data["task_status"].(string); ok { + switch status { + case "submitted", "queued": + task.Status = model.TaskStatusSubmitted + case "processing": + task.Status = model.TaskStatusInProgress + case "succeed": + task.Status = model.TaskStatusSuccess + task.Progress = "100%" + if url, err := adaptor.(interface { + ParseResultUrl(map[string]any) (string, error) + }).ParseResultUrl(responseItem); err == nil { + task.FailReason = url + } else { + common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error())) + } + case "failed": + task.Status = model.TaskStatusFailure + task.Progress = "100%" + if reason, ok := data["fail_reason"].(string); ok { + task.FailReason = reason + } + } + } + + // If task failed, refund quota + if task.Status == model.TaskStatusFailure { + common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + quota := task.Quota + if quota != 0 { + if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { + common.LogError(ctx, "Failed to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } + + task.Data = responseBody + if err := task.Update(); err != nil { + common.SysError("UpdateVideoTask task error: " + err.Error()) + } + + return nil +} diff --git a/dto/video.go b/dto/video.go new file mode 100644 index 00000000..5b48146a --- /dev/null +++ b/dto/video.go @@ -0,0 +1,47 @@ +package dto + +type VideoRequest struct { + Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID + Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt + Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64) + Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds) + Width int `json:"width" example:"512"` // Video width + Height int `json:"height" example:"512"` // Video height + Fps int `json:"fps,omitempty" example:"30"` // Video frame rate + Seed int `json:"seed,omitempty" example:"20231234"` // Random seed + N int `json:"n,omitempty" example:"1"` // Number of videos to generate + ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format + User string `json:"user,omitempty" example:"user-1234"` // User identifier + Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.) +} + +// VideoResponse 视频生成提交任务后的响应 +type VideoResponse struct { + TaskId string `json:"task_id"` + Status string `json:"status"` +} + +// VideoTaskResponse 查询视频生成任务状态的响应 +type VideoTaskResponse struct { + TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID + Status string `json:"status" example:"succeeded"` // 任务状态 + Url string `json:"url,omitempty"` // 视频资源URL(成功时) + Format string `json:"format,omitempty" example:"mp4"` // 视频格式 + Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据 + Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时) +} + +// VideoTaskMetadata 视频任务元数据 +type VideoTaskMetadata struct { + Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长 + Fps int `json:"fps" example:"30"` // 实际帧率 + Width int `json:"width" example:"512"` // 实际宽度 + Height int `json:"height" example:"512"` // 实际高度 + Seed int `json:"seed" example:"20231234"` // 使用的随机种子 +} + +// VideoTaskError 视频任务错误信息 +type VideoTaskError struct { + Code int `json:"code"` + Message string `json:"message"` +} diff --git a/middleware/distributor.go b/middleware/distributor.go index 84eb182e..9d074ce8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -170,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) + } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { + relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeKlingFetchByID { + shouldSelectChannel = false + } else { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } + c.Set("platform", string(constant.TaskPlatformKling)) + c.Set("relay_mode", relayMode) } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent relayMode := relayconstant.RelayModeGemini diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 50255d0a..873997f6 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -44,4 +44,6 @@ type TaskAdaptor interface { // FetchTask FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) + + ParseResultUrl(resp map[string]any) (string, error) } diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go new file mode 100644 index 00000000..9c6773f5 --- /dev/null +++ b/relay/channel/task/kling/adaptor.go @@ -0,0 +1,312 @@ +package kling + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" + + "one-api/common" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +// ============================ +// Request / Response structures +// ============================ + +type SubmitReq struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type requestPayload struct { + Prompt string `json:"prompt,omitempty"` + Image string `json:"image,omitempty"` + Mode string `json:"mode,omitempty"` + Duration string `json:"duration,omitempty"` + AspectRatio string `json:"aspect_ratio,omitempty"` + Model string `json:"model,omitempty"` + ModelName string `json:"model_name,omitempty"` + CfgScale float64 `json:"cfg_scale,omitempty"` +} + +type responsePayload struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + TaskID string `json:"task_id"` + } `json:"data"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + accessKey string + secretKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.BaseUrl + + // apiKey format: "access_key,secret_key" + keyParts := strings.Split(info.ApiKey, ",") + if len(keyParts) == 2 { + a.accessKey = strings.TrimSpace(keyParts[0]) + a.secretKey = strings.TrimSpace(keyParts[1]) + } +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + // Accept only POST /v1/video/generations as "generate" action. + action := "generate" + info.Action = action + + var req SubmitReq + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Prompt) == "" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) + return + } + + // Store into context for later usage + c.Set("kling_request", req) + return nil +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil +} + +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + token, err := a.createJWTToken() + if err != nil { + token = info.ApiKey // fallback + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", "kling-sdk/1.0") + return nil +} + +// BuildRequestBody converts request into Kling specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, exists := c.Get("kling_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(SubmitReq) + + body := a.convertToRequestPayload(&req) + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + + // Attempt Kling response parse first. + var kResp responsePayload + if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 { + c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID}) + return kResp.Data.TaskID, responseBody, nil + } + + // Fallback generic task response. + var generic dto.TaskResponse[string] + if err := json.Unmarshal(responseBody, &generic); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if !generic.IsSuccess() { + taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{"task_id": generic.Data}) + return generic.Data, responseBody, nil +} + +// FetchTask fetch task status +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + token, err := a.createJWTTokenWithKey(key) + if err != nil { + token = key + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + req = req.WithContext(ctx) + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", "kling-sdk/1.0") + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return []string{"kling-v1", "kling-v1-6", "kling-v2-master"} +} + +func (a *TaskAdaptor) GetChannelName() string { + return "kling" +} + +// ============================ +// helpers +// ============================ + +func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload { + r := &requestPayload{ + Prompt: req.Prompt, + Image: req.Image, + Mode: defaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + AspectRatio: a.getAspectRatio(req.Size), + Model: req.Model, + ModelName: req.Model, + CfgScale: 0.5, + } + if r.Model == "" { + r.Model = "kling-v1" + r.ModelName = "kling-v1" + } + return r +} + +func (a *TaskAdaptor) getAspectRatio(size string) string { + switch size { + case "1024x1024", "512x512": + return "1:1" + case "1280x720", "1920x1080": + return "16:9" + case "720x1280", "1080x1920": + return "9:16" + default: + return "1:1" + } +} + +func defaultString(s, def string) string { + if strings.TrimSpace(s) == "" { + return def + } + return s +} + +func defaultInt(v int, def int) int { + if v == 0 { + return def + } + return v +} + +// ============================ +// JWT helpers +// ============================ + +func (a *TaskAdaptor) createJWTToken() (string, error) { + return a.createJWTTokenWithKeys(a.accessKey, a.secretKey) +} + +func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { + parts := strings.Split(apiKey, ",") + if len(parts) != 2 { + return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'") + } + return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])) +} + +func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) { + if accessKey == "" || secretKey == "" { + return "", fmt.Errorf("access key and secret key are required") + } + now := time.Now().Unix() + claims := jwt.MapClaims{ + "iss": accessKey, + "exp": now + 1800, // 30 minutes + "nbf": now - 5, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["typ"] = "JWT" + return token.SignedString([]byte(secretKey)) +} + +// ParseResultUrl 提取视频任务结果的 url +func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) { + data, ok := resp["data"].(map[string]any) + if !ok { + return "", fmt.Errorf("data field not found or invalid") + } + taskResult, ok := data["task_result"].(map[string]any) + if !ok { + return "", fmt.Errorf("task_result field not found or invalid") + } + videos, ok := taskResult["videos"].([]interface{}) + if !ok || len(videos) == 0 { + return "", fmt.Errorf("videos field not found or empty") + } + video, ok := videos[0].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("video item invalid") + } + url, ok := video["url"].(string) + if !ok || url == "" { + return "", fmt.Errorf("url field not found or invalid") + } + return url, nil +} diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 03d60516..f7042348 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -22,6 +22,10 @@ type TaskAdaptor struct { ChannelType int } +func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) { + return "", nil // todo implement this method if needed +} + func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType } diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index f22a20bd..02a286e2 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -38,6 +38,9 @@ const ( RelayModeSunoFetchByID RelayModeSunoSubmit + RelayModeKlingFetchByID + RelayModeKlingSubmit + RelayModeRerank RelayModeResponses @@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int { } return relayMode } + +func Path2RelayKling(method, path string) int { + relayMode := RelayModeUnknown + if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") { + relayMode = RelayModeKlingSubmit + } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") { + relayMode = RelayModeKlingFetchByID + } + return relayMode +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 7bf0da9f..626bb7e4 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -22,6 +22,7 @@ import ( "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" + "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" @@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { // return &aiproxy.Adaptor{} case commonconstant.TaskPlatformSuno: return &suno.TaskAdaptor{} + case commonconstant.TaskPlatformKling: + return &kling.TaskAdaptor{} } return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index 3da9a20f..245fd681 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -37,6 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) + if platform == constant.TaskPlatformKling { + modelName = relayInfo.OriginModelName + } modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName] @@ -136,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } relayInfo.ConsumeQuota = true // insert task - task := model.InitTask(constant.TaskPlatformSuno, relayInfo) + task := model.InitTask(platform, relayInfo) task.TaskID = taskID task.Quota = quota task.Data = taskData + task.Action = relayInfo.Action err = task.Insert() if err != nil { taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) @@ -149,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ - relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, - relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, + relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, + relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, + relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder, } func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { @@ -225,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt return } +func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { + taskId := c.Param("id") + userId := c.GetInt("id") + + originTask, exist, err := model.GetByTaskId(userId, taskId) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + + respBody, err = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: TaskModel2Dto(originTask), + }) + return +} + func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ TaskID: task.TaskID, diff --git a/router/main.go b/router/main.go index b8ac4055..0d2bfdce 100644 --- a/router/main.go +++ b/router/main.go @@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { SetApiRouter(router) SetDashboardRouter(router) SetRelayRouter(router) + SetVideoRouter(router) frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" diff --git a/router/video-router.go b/router/video-router.go new file mode 100644 index 00000000..7201c34a --- /dev/null +++ b/router/video-router.go @@ -0,0 +1,17 @@ +package router + +import ( + "one-api/controller" + "one-api/middleware" + + "github.com/gin-gonic/gin" +) + +func SetVideoRouter(router *gin.Engine) { + videoV1Router := router.Group("/v1") + videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + { + videoV1Router.POST("/video/generations", controller.RelayTask) + videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) + } +} diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js index b3d0ab7b..37bdde57 100644 --- a/web/src/components/table/TaskLogsTable.js +++ b/web/src/components/table/TaskLogsTable.js @@ -11,7 +11,9 @@ import { XCircle, Loader, List, - Hash + Hash, + Video, + Sparkles } from 'lucide-react'; import { API, @@ -80,6 +82,7 @@ const COLUMN_KEYS = { TASK_STATUS: 'task_status', PROGRESS: 'progress', FAIL_REASON: 'fail_reason', + RESULT_URL: 'result_url', }; const renderTimestamp = (timestampInSeconds) => { @@ -150,6 +153,7 @@ const LogsTable = () => { [COLUMN_KEYS.TASK_STATUS]: true, [COLUMN_KEYS.PROGRESS]: true, [COLUMN_KEYS.FAIL_REASON]: true, + [COLUMN_KEYS.RESULT_URL]: true, }; }; @@ -203,6 +207,12 @@ const LogsTable = () => { {t('生成歌词')} ); + case 'generate': + return ( + }> + {t('生成视频')} + + ); default: return ( }> @@ -220,6 +230,12 @@ const LogsTable = () => { Suno ); + case 'kling': + return ( + }> + Kling + + ); default: return ( }> @@ -411,10 +427,21 @@ const LogsTable = () => { }, { key: COLUMN_KEYS.FAIL_REASON, - title: t('失败原因'), + title: t('详情'), dataIndex: 'fail_reason', fixed: 'right', render: (text, record, index) => { + // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 + const isVideoTask = record.action === 'generate'; + const isSuccess = record.status === 'SUCCESS'; + const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); + if (isSuccess && isVideoTask && isUrl) { + return ( + + {t('点击预览视频')} + + ); + } if (!text) { return t('无'); } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 20fed5b7..c4220bd4 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -125,4 +125,9 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: 'Coze', }, + { + value: 50, + color: 'green', + label: '可灵', + }, ];