diff --git a/common/constants.go b/common/constants.go index ac803148..67625439 100644 --- a/common/constants.go +++ b/common/constants.go @@ -242,6 +242,7 @@ const ( ChannelTypeXai = 48 ChannelTypeCoze = 49 ChannelTypeKling = 50 + ChannelTypeJimeng = 51 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -298,4 +299,5 @@ var ChannelBaseURLs = []string{ "https://api.x.ai", //48 "https://api.coze.cn", //49 "https://api.klingai.com", //50 + "https://visual.volcengineapi.com", //51 } diff --git a/constant/task.go b/constant/task.go index d466fc8a..21831d3a 100644 --- a/constant/task.go +++ b/constant/task.go @@ -6,6 +6,7 @@ const ( TaskPlatformSuno TaskPlatform = "suno" TaskPlatformMidjourney = "mj" TaskPlatformKling TaskPlatform = "kling" + TaskPlatformJimeng TaskPlatform = "jimeng" ) const ( diff --git a/controller/channel-test.go b/controller/channel-test.go index db8c9db0..b3badf35 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -43,6 +43,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if channel.Type == common.ChannelTypeKling { return errors.New("kling channel test is not supported"), nil } + if channel.Type == common.ChannelTypeJimeng { + return errors.New("jimeng channel test is not supported"), nil + } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) diff --git a/controller/task.go b/controller/task.go index f7523e87..5cfa728a 100644 --- a/controller/task.go +++ b/controller/task.go @@ -74,8 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) case constant.TaskPlatformSuno: _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) - case constant.TaskPlatformKling: - _ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM) + case constant.TaskPlatformKling, constant.TaskPlatformJimeng: + _ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM) default: common.SysLog("未知平台") } diff --git a/controller/task_video.go b/controller/task_video.go index 2e980310..a17351b5 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -2,27 +2,26 @@ package controller import ( "context" - "encoding/json" "fmt" "io" - "net/http" "one-api/common" "one-api/constant" "one-api/model" "one-api/relay" "one-api/relay/channel" + "time" ) -func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { +func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { for channelId, taskIds := range taskChannelM { - if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil { + if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) } } return nil } -func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { +func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil @@ -39,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, ta } return fmt.Errorf("CacheGetChannel failed: %w", err) } - adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling) + adaptor := relay.GetTaskAdaptor(platform) if adaptor == nil { return fmt.Errorf("video adaptor not found") } @@ -67,60 +66,53 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha "action": task.Action, }) if err != nil { - return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err) - } - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode) + return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) } + //if resp.StatusCode != http.StatusOK { + //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode) + //} defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err) + return fmt.Errorf("readAll failed for task %s: %w", taskId, err) } - var responseItem map[string]interface{} - err = json.Unmarshal(responseBody, &responseItem) + taskResult, err := adaptor.ParseTaskResult(responseBody) if err != nil { - common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody))) - return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err) + return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) } + //if taskResult.Code != 0 { + // return fmt.Errorf("video task fetch failed for task %s", taskId) + //} - code, _ := responseItem["code"].(float64) - if code != 0 { - return fmt.Errorf("video task fetch failed for task %s", taskId) + now := time.Now().Unix() + if taskResult.Status == "" { + return fmt.Errorf("task %s status is empty", taskId) } - - data, ok := responseItem["data"].(map[string]interface{}) - if !ok { - common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody))) - return fmt.Errorf("video task data format error for task %s", taskId) - } - - if status, ok := data["task_status"].(string); ok { - switch status { - case "submitted", "queued": - task.Status = model.TaskStatusSubmitted - case "processing": - task.Status = model.TaskStatusInProgress - case "succeed": - task.Status = model.TaskStatusSuccess - task.Progress = "100%" - if url, err := adaptor.ParseResultUrl(responseItem); err == nil { - task.FailReason = url - } else { - common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error())) - } - case "failed": - task.Status = model.TaskStatusFailure - task.Progress = "100%" - if reason, ok := data["fail_reason"].(string); ok { - task.FailReason = reason - } + task.Status = model.TaskStatus(taskResult.Status) + switch taskResult.Status { + case model.TaskStatusSubmitted: + task.Progress = "10%" + case model.TaskStatusQueued: + task.Progress = "20%" + case model.TaskStatusInProgress: + task.Progress = "30%" + if task.StartTime == 0 { + task.StartTime = now } - } - - // If task failed, refund quota - if task.Status == model.TaskStatusFailure { + case model.TaskStatusSuccess: + task.Progress = "100%" + if task.FinishTime == 0 { + task.FinishTime = now + } + task.FailReason = taskResult.Url + case model.TaskStatusFailure: + task.Status = model.TaskStatusFailure + task.Progress = "100%" + if task.FinishTime == 0 { + task.FinishTime = now + } + task.FailReason = taskResult.Reason common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) quota := task.Quota if quota != 0 { @@ -130,6 +122,11 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } + default: + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) + } + if taskResult.Progress != "" { + task.Progress = taskResult.Progress } task.Data = responseBody diff --git a/middleware/distributor.go b/middleware/distributor.go index d1159f8c..0a6a9af4 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -171,13 +171,23 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { - relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path) - if relayMode == relayconstant.RelayModeKlingFetchByID { - shouldSelectChannel = false + err = common.UnmarshalBodyReusable(c, &modelRequest) + var platform string + var relayMode int + if strings.HasPrefix(modelRequest.Model, "jimeng") { + platform = string(constant.TaskPlatformJimeng) + relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeJimengFetchByID { + shouldSelectChannel = false + } } else { - err = common.UnmarshalBodyReusable(c, &modelRequest) + platform = string(constant.TaskPlatformKling) + relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeKlingFetchByID { + shouldSelectChannel = false + } } - c.Set("platform", string(constant.TaskPlatformKling)) + c.Set("platform", platform) c.Set("relay_mode", relayMode) } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 873997f6..2ff34e01 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -45,5 +45,5 @@ type TaskAdaptor interface { // FetchTask FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) - ParseResultUrl(resp map[string]any) (string, error) + ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go new file mode 100644 index 00000000..3298bdcb --- /dev/null +++ b/relay/channel/task/jimeng/adaptor.go @@ -0,0 +1,379 @@ +package jimeng + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "one-api/model" + "sort" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + + "one-api/common" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +// ============================ +// Request / Response structures +// ============================ + +type requestPayload struct { + ReqKey string `json:"req_key"` + BinaryDataBase64 []string `json:"binary_data_base64,omitempty"` + ImageUrls []string `json:"image_urls,omitempty"` + Prompt string `json:"prompt,omitempty"` + Seed int64 `json:"seed"` + AspectRatio string `json:"aspect_ratio"` +} + +type responsePayload struct { + Code int `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` + Data struct { + TaskID string `json:"task_id"` + } `json:"data"` +} + +type responseTask struct { + Code int `json:"code"` + Data struct { + BinaryDataBase64 []interface{} `json:"binary_data_base64"` + ImageUrls interface{} `json:"image_urls"` + RespData string `json:"resp_data"` + Status string `json:"status"` + VideoUrl string `json:"video_url"` + } `json:"data"` + Message string `json:"message"` + RequestId string `json:"request_id"` + Status int `json:"status"` + TimeElapsed string `json:"time_elapsed"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + accessKey string + secretKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.BaseUrl + + // apiKey format: "access_key,secret_key" + keyParts := strings.Split(info.ApiKey, ",") + if len(keyParts) == 2 { + a.accessKey = strings.TrimSpace(keyParts[0]) + a.secretKey = strings.TrimSpace(keyParts[1]) + } +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + // Accept only POST /v1/video/generations as "generate" action. + action := "generate" + info.Action = action + + req := relaycommon.TaskSubmitReq{} + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Prompt) == "" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) + return + } + + // Store into context for later usage + c.Set("task_request", req) + return nil +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil +} + +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + return a.signRequest(req, a.accessKey, a.secretKey) +} + +// BuildRequestBody converts request into Jimeng specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, exists := c.Get("task_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(relaycommon.TaskSubmitReq) + + body, err := a.convertToRequestPayload(&req) + if err != nil { + return nil, errors.Wrap(err, "convert request payload failed") + } + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + _ = resp.Body.Close() + + // Parse Jimeng response + var jResp responsePayload + if err := json.Unmarshal(responseBody, &jResp); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if jResp.Code != 10000 { + taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID}) + return jResp.Data.TaskID, responseBody, nil +} + +// FetchTask fetch task status +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + + uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl) + payload := map[string]string{ + "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 + "task_id": taskID, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, errors.Wrap(err, "marshal fetch task payload failed") + } + + req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes)) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + + keyParts := strings.Split(key, ",") + if len(keyParts) != 2 { + return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak,sk'") + } + accessKey := strings.TrimSpace(keyParts[0]) + secretKey := strings.TrimSpace(keyParts[1]) + + if err := a.signRequest(req, accessKey, secretKey); err != nil { + return nil, errors.Wrap(err, "sign request failed") + } + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return []string{"jimeng_vgfm_t2v_l20"} +} + +func (a *TaskAdaptor) GetChannelName() string { + return "jimeng" +} + +func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error { + var bodyBytes []byte + var err error + + if req.Body != nil { + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return errors.Wrap(err, "read request body failed") + } + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind + } else { + bodyBytes = []byte{} + } + + payloadHash := sha256.Sum256(bodyBytes) + hexPayloadHash := hex.EncodeToString(payloadHash[:]) + + t := time.Now().UTC() + xDate := t.Format("20060102T150405Z") + shortDate := t.Format("20060102") + + req.Header.Set("Host", req.URL.Host) + req.Header.Set("X-Date", xDate) + req.Header.Set("X-Content-Sha256", hexPayloadHash) + + // Sort and encode query parameters to create canonical query string + queryParams := req.URL.Query() + sortedKeys := make([]string, 0, len(queryParams)) + for k := range queryParams { + sortedKeys = append(sortedKeys, k) + } + sort.Strings(sortedKeys) + var queryParts []string + for _, k := range sortedKeys { + values := queryParams[k] + sort.Strings(values) + for _, v := range values { + queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v))) + } + } + canonicalQueryString := strings.Join(queryParts, "&") + + headersToSign := map[string]string{ + "host": req.URL.Host, + "x-date": xDate, + "x-content-sha256": hexPayloadHash, + } + if req.Header.Get("Content-Type") != "" { + headersToSign["content-type"] = req.Header.Get("Content-Type") + } + + var signedHeaderKeys []string + for k := range headersToSign { + signedHeaderKeys = append(signedHeaderKeys, k) + } + sort.Strings(signedHeaderKeys) + + var canonicalHeaders strings.Builder + for _, k := range signedHeaderKeys { + canonicalHeaders.WriteString(k) + canonicalHeaders.WriteString(":") + canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k])) + canonicalHeaders.WriteString("\n") + } + signedHeaders := strings.Join(signedHeaderKeys, ";") + + canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", + req.Method, + req.URL.Path, + canonicalQueryString, + canonicalHeaders.String(), + signedHeaders, + hexPayloadHash, + ) + + hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest)) + hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:]) + + region := "cn-north-1" + serviceName := "cv" + credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName) + stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s", + xDate, + credentialScope, + hexHashedCanonicalRequest, + ) + + kDate := hmacSHA256([]byte(secretKey), []byte(shortDate)) + kRegion := hmacSHA256(kDate, []byte(region)) + kService := hmacSHA256(kRegion, []byte(serviceName)) + kSigning := hmacSHA256(kService, []byte("request")) + signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign))) + + authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", + accessKey, + credentialScope, + signedHeaders, + signature, + ) + req.Header.Set("Authorization", authorization) + return nil +} + +func hmacSHA256(key []byte, data []byte) []byte { + h := hmac.New(sha256.New, key) + h.Write(data) + return h.Sum(nil) +} + +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { + r := requestPayload{ + ReqKey: "jimeng_vgfm_i2v_l20", + Prompt: req.Prompt, + AspectRatio: "16:9", // Default aspect ratio + Seed: -1, // Default to random + } + + // Handle one-of image_urls or binary_data_base64 + if req.Image != "" { + if strings.HasPrefix(req.Image, "http") { + r.ImageUrls = []string{req.Image} + } else { + r.BinaryDataBase64 = []string{req.Image} + } + } + metadata := req.Metadata + medaBytes, err := json.Marshal(metadata) + if err != nil { + return nil, errors.Wrap(err, "metadata marshal metadata failed") + } + err = json.Unmarshal(medaBytes, &r) + if err != nil { + return nil, errors.Wrap(err, "unmarshal metadata failed") + } + return &r, nil +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + resTask := responseTask{} + if err := json.Unmarshal(respBody, &resTask); err != nil { + return nil, errors.Wrap(err, "unmarshal task result failed") + } + taskResult := relaycommon.TaskInfo{} + if resTask.Code == 10000 { + taskResult.Code = 0 + } else { + taskResult.Code = resTask.Code // todo uni code + taskResult.Reason = resTask.Message + taskResult.Status = model.TaskStatusFailure + taskResult.Progress = "100%" + } + switch resTask.Data.Status { + case "in_queue": + taskResult.Status = model.TaskStatusQueued + taskResult.Progress = "10%" + case "done": + taskResult.Status = model.TaskStatusSuccess + taskResult.Progress = "100%" + } + taskResult.Url = resTask.Data.VideoUrl + return &taskResult, nil +} diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 55f21196..2a02472b 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -2,12 +2,12 @@ package kling import ( "bytes" - "context" "encoding/json" "fmt" "github.com/samber/lo" "io" "net/http" + "one-api/model" "strings" "time" @@ -47,10 +47,22 @@ type requestPayload struct { } type responsePayload struct { - Code int `json:"code"` - Message string `json:"message"` - Data struct { - TaskID string `json:"task_id"` + Code int `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` + Data struct { + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + TaskStatusMsg string `json:"task_status_msg"` + TaskResult struct { + Videos []struct { + Id string `json:"id"` + Url string `json:"url"` + Duration string `json:"duration"` + } `json:"videos"` + } `json:"task_result"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } `json:"data"` } @@ -94,7 +106,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } // Store into context for later usage - c.Set("kling_request", req) + c.Set("task_request", req) return nil } @@ -120,7 +132,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info // BuildRequestBody converts request into Kling specific format. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { - v, exists := c.Get("kling_request") + v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") } @@ -156,8 +168,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Attempt Kling response parse first. var kResp responsePayload if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 { - c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID}) - return kResp.Data.TaskID, responseBody, nil + c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId}) + return kResp.Data.TaskId, responseBody, nil } // Fallback generic task response. @@ -199,10 +211,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http token = key } - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - - req = req.WithContext(ctx) req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("User-Agent", "kling-sdk/1.0") @@ -305,27 +313,33 @@ func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (strin return token.SignedString([]byte(secretKey)) } -// ParseResultUrl 提取视频任务结果的 url -func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) { - data, ok := resp["data"].(map[string]any) - if !ok { - return "", fmt.Errorf("data field not found or invalid") +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + resPayload := responsePayload{} + err := json.Unmarshal(respBody, &resPayload) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal response body") } - taskResult, ok := data["task_result"].(map[string]any) - if !ok { - return "", fmt.Errorf("task_result field not found or invalid") + taskInfo := &relaycommon.TaskInfo{} + taskInfo.Code = resPayload.Code + taskInfo.TaskID = resPayload.Data.TaskId + taskInfo.Reason = resPayload.Message + //任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败) + status := resPayload.Data.TaskStatus + switch status { + case "submitted": + taskInfo.Status = model.TaskStatusSubmitted + case "processing": + taskInfo.Status = model.TaskStatusInProgress + case "succeed": + taskInfo.Status = model.TaskStatusSuccess + case "failed": + taskInfo.Status = model.TaskStatusFailure + default: + return nil, fmt.Errorf("unknown task status: %s", status) } - videos, ok := taskResult["videos"].([]interface{}) - if !ok || len(videos) == 0 { - return "", fmt.Errorf("videos field not found or empty") + if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 { + video := videos[0] + taskInfo.Url = video.Url } - video, ok := videos[0].(map[string]interface{}) - if !ok { - return "", fmt.Errorf("video item invalid") - } - url, ok := video["url"].(string) - if !ok || url == "" { - return "", fmt.Errorf("url field not found or invalid") - } - return url, nil + return taskInfo, nil } diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index f7042348..9c04c7ad 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -22,8 +22,8 @@ type TaskAdaptor struct { ChannelType int } -func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) { - return "", nil // todo implement this method if needed +func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { + return nil, fmt.Errorf("not implement") // todo implement this method if needed } func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index f3fc9ce9..5fd94788 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -313,3 +313,22 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { } return info } + +type TaskSubmitReq struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type TaskInfo struct { + Code int `json:"code"` + TaskID string `json:"task_id"` + Status string `json:"status"` + Reason string `json:"reason,omitempty"` + Url string `json:"url,omitempty"` + Progress string `json:"progress,omitempty"` +} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 1b8a1b2d..cc8a1494 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -41,6 +41,9 @@ const ( RelayModeKlingFetchByID RelayModeKlingSubmit + RelayModeJimengFetchByID + RelayModeJimengSubmit + RelayModeRerank RelayModeResponses @@ -146,3 +149,13 @@ func Path2RelayKling(method, path string) int { } return relayMode } + +func Path2RelayJimeng(method, path string) int { + relayMode := RelayModeUnknown + if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") { + relayMode = RelayModeJimengSubmit + } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") { + relayMode = RelayModeJimengFetchByID + } + return relayMode +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 626bb7e4..f648b4d5 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -22,6 +22,7 @@ import ( "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" + "one-api/relay/channel/task/jimeng" "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" @@ -104,6 +105,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { return &suno.TaskAdaptor{} case commonconstant.TaskPlatformKling: return &kling.TaskAdaptor{} + case commonconstant.TaskPlatformJimeng: + return &jimeng.TaskAdaptor{} } return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index b8004105..702cff4c 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -245,7 +245,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt } func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { - taskId := c.Param("id") + taskId := c.Param("task_id") userId := c.GetInt("id") originTask, exist, err := model.GetByTaskId(userId, taskId) diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js index 5b77ce39..65a8e2a6 100644 --- a/web/src/components/table/TaskLogsTable.js +++ b/web/src/components/table/TaskLogsTable.js @@ -230,8 +230,8 @@ const LogsTable = () => { } }; - const renderPlatform = (type) => { - switch (type) { + const renderPlatform = (platform) => { + switch (platform) { case 'suno': return ( }> @@ -240,10 +240,16 @@ const LogsTable = () => { ); case 'kling': return ( - }> + }> Kling ); + case 'jimeng': + return ( + }> + Jimeng + + ); default: return ( }> diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 4018aa4f..b145ea11 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -130,6 +130,11 @@ export const CHANNEL_OPTIONS = [ color: 'green', label: '可灵', }, + { + value: 51, + color: 'blue', + label: '即梦', + }, ]; export const MODEL_TABLE_PAGE_SIZE = 10; diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 060f3e65..8a38ef0c 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -801,6 +801,7 @@ "获取无水印": "Get no watermark", "生成图片": "Generate pictures", "可灵": "Kling", + "即梦": "Jimeng", "正在提交": "Submitting", "执行中": "processing", "平台": "platform",