diff --git a/common/constants.go b/common/constants.go
index ac803148..67625439 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -242,6 +242,7 @@ const (
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeKling = 50
+ ChannelTypeJimeng = 51
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -298,4 +299,5 @@ var ChannelBaseURLs = []string{
"https://api.x.ai", //48
"https://api.coze.cn", //49
"https://api.klingai.com", //50
+ "https://visual.volcengineapi.com", //51
}
diff --git a/constant/task.go b/constant/task.go
index d466fc8a..21831d3a 100644
--- a/constant/task.go
+++ b/constant/task.go
@@ -6,6 +6,7 @@ const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
TaskPlatformKling TaskPlatform = "kling"
+ TaskPlatformJimeng TaskPlatform = "jimeng"
)
const (
diff --git a/controller/channel-test.go b/controller/channel-test.go
index db8c9db0..b3badf35 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -43,6 +43,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if channel.Type == common.ChannelTypeKling {
return errors.New("kling channel test is not supported"), nil
}
+ if channel.Type == common.ChannelTypeJimeng {
+ return errors.New("jimeng channel test is not supported"), nil
+ }
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
diff --git a/controller/task.go b/controller/task.go
index f7523e87..5cfa728a 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -74,8 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
- case constant.TaskPlatformKling:
- _ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
+ case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
+ _ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
default:
common.SysLog("未知平台")
}
diff --git a/controller/task_video.go b/controller/task_video.go
index 2e980310..a17351b5 100644
--- a/controller/task_video.go
+++ b/controller/task_video.go
@@ -2,27 +2,26 @@ package controller
import (
"context"
- "encoding/json"
"fmt"
"io"
- "net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
+ "time"
)
-func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
- if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
+ if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
-func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
@@ -39,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, ta
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
- adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
+ adaptor := relay.GetTaskAdaptor(platform)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
@@ -67,60 +66,53 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
"action": task.Action,
})
if err != nil {
- return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
- }
- if resp.StatusCode != http.StatusOK {
- return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
+ return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
}
+ //if resp.StatusCode != http.StatusOK {
+ //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
+ //}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
+ return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
- var responseItem map[string]interface{}
- err = json.Unmarshal(responseBody, &responseItem)
+ taskResult, err := adaptor.ParseTaskResult(responseBody)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
- return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
+ return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
}
+ //if taskResult.Code != 0 {
+ // return fmt.Errorf("video task fetch failed for task %s", taskId)
+ //}
- code, _ := responseItem["code"].(float64)
- if code != 0 {
- return fmt.Errorf("video task fetch failed for task %s", taskId)
+ now := time.Now().Unix()
+ if taskResult.Status == "" {
+ return fmt.Errorf("task %s status is empty", taskId)
}
-
- data, ok := responseItem["data"].(map[string]interface{})
- if !ok {
- common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
- return fmt.Errorf("video task data format error for task %s", taskId)
- }
-
- if status, ok := data["task_status"].(string); ok {
- switch status {
- case "submitted", "queued":
- task.Status = model.TaskStatusSubmitted
- case "processing":
- task.Status = model.TaskStatusInProgress
- case "succeed":
- task.Status = model.TaskStatusSuccess
- task.Progress = "100%"
- if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
- task.FailReason = url
- } else {
- common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
- }
- case "failed":
- task.Status = model.TaskStatusFailure
- task.Progress = "100%"
- if reason, ok := data["fail_reason"].(string); ok {
- task.FailReason = reason
- }
+ task.Status = model.TaskStatus(taskResult.Status)
+ switch taskResult.Status {
+ case model.TaskStatusSubmitted:
+ task.Progress = "10%"
+ case model.TaskStatusQueued:
+ task.Progress = "20%"
+ case model.TaskStatusInProgress:
+ task.Progress = "30%"
+ if task.StartTime == 0 {
+ task.StartTime = now
}
- }
-
- // If task failed, refund quota
- if task.Status == model.TaskStatusFailure {
+ case model.TaskStatusSuccess:
+ task.Progress = "100%"
+ if task.FinishTime == 0 {
+ task.FinishTime = now
+ }
+ task.FailReason = taskResult.Url
+ case model.TaskStatusFailure:
+ task.Status = model.TaskStatusFailure
+ task.Progress = "100%"
+ if task.FinishTime == 0 {
+ task.FinishTime = now
+ }
+ task.FailReason = taskResult.Reason
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
@@ -130,6 +122,11 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
+ default:
+ return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
+ }
+ if taskResult.Progress != "" {
+ task.Progress = taskResult.Progress
}
task.Data = responseBody
diff --git a/middleware/distributor.go b/middleware/distributor.go
index d1159f8c..0a6a9af4 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -171,13 +171,23 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
- relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
- if relayMode == relayconstant.RelayModeKlingFetchByID {
- shouldSelectChannel = false
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ var platform string
+ var relayMode int
+ if strings.HasPrefix(modelRequest.Model, "jimeng") {
+ platform = string(constant.TaskPlatformJimeng)
+ relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeJimengFetchByID {
+ shouldSelectChannel = false
+ }
} else {
- err = common.UnmarshalBodyReusable(c, &modelRequest)
+ platform = string(constant.TaskPlatformKling)
+ relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeKlingFetchByID {
+ shouldSelectChannel = false
+ }
}
- c.Set("platform", string(constant.TaskPlatformKling))
+ c.Set("platform", platform)
c.Set("relay_mode", relayMode)
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index 873997f6..2ff34e01 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -45,5 +45,5 @@ type TaskAdaptor interface {
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
- ParseResultUrl(resp map[string]any) (string, error)
+ ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}
diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go
new file mode 100644
index 00000000..3298bdcb
--- /dev/null
+++ b/relay/channel/task/jimeng/adaptor.go
@@ -0,0 +1,379 @@
+package jimeng
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "one-api/model"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/pkg/errors"
+
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type requestPayload struct {
+ ReqKey string `json:"req_key"`
+ BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
+ ImageUrls []string `json:"image_urls,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+ Seed int64 `json:"seed"`
+ AspectRatio string `json:"aspect_ratio"`
+}
+
+type responsePayload struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+ Data struct {
+ TaskID string `json:"task_id"`
+ } `json:"data"`
+}
+
+type responseTask struct {
+ Code int `json:"code"`
+ Data struct {
+ BinaryDataBase64 []interface{} `json:"binary_data_base64"`
+ ImageUrls interface{} `json:"image_urls"`
+ RespData string `json:"resp_data"`
+ Status string `json:"status"`
+ VideoUrl string `json:"video_url"`
+ } `json:"data"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+ Status int `json:"status"`
+ TimeElapsed string `json:"time_elapsed"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+ ChannelType int
+ accessKey string
+ secretKey string
+ baseURL string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+ a.ChannelType = info.ChannelType
+ a.baseURL = info.BaseUrl
+
+ // apiKey format: "access_key,secret_key"
+ keyParts := strings.Split(info.ApiKey, ",")
+ if len(keyParts) == 2 {
+ a.accessKey = strings.TrimSpace(keyParts[0])
+ a.secretKey = strings.TrimSpace(keyParts[1])
+ }
+}
+
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+ // Accept only POST /v1/video/generations as "generate" action.
+ action := "generate"
+ info.Action = action
+
+ req := relaycommon.TaskSubmitReq{}
+ if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+ return
+ }
+ if strings.TrimSpace(req.Prompt) == "" {
+ taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
+ return
+ }
+
+ // Store into context for later usage
+ c.Set("task_request", req)
+ return nil
+}
+
+// BuildRequestURL constructs the upstream URL.
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+ return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
+}
+
+// BuildRequestHeader sets required headers.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ return a.signRequest(req, a.accessKey, a.secretKey)
+}
+
+// BuildRequestBody converts request into Jimeng specific format.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+ v, exists := c.Get("task_request")
+ if !exists {
+ return nil, fmt.Errorf("request not found in context")
+ }
+ req := v.(relaycommon.TaskSubmitReq)
+
+ body, err := a.convertToRequestPayload(&req)
+ if err != nil {
+ return nil, errors.Wrap(err, "convert request payload failed")
+ }
+ data, err := json.Marshal(body)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to common helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles upstream response, returns taskID etc.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+ _ = resp.Body.Close()
+
+ // Parse Jimeng response
+ var jResp responsePayload
+ if err := json.Unmarshal(responseBody, &jResp); err != nil {
+ taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ if jResp.Code != 10000 {
+ taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
+ return jResp.Data.TaskID, responseBody, nil
+}
+
+// FetchTask fetch task status
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+ taskID, ok := body["task_id"].(string)
+ if !ok {
+ return nil, fmt.Errorf("invalid task_id")
+ }
+
+ uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
+ payload := map[string]string{
+ "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
+ "task_id": taskID,
+ }
+ payloadBytes, err := json.Marshal(payload)
+ if err != nil {
+ return nil, errors.Wrap(err, "marshal fetch task payload failed")
+ }
+
+ req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Content-Type", "application/json")
+
+ keyParts := strings.Split(key, ",")
+ if len(keyParts) != 2 {
+ return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak,sk'")
+ }
+ accessKey := strings.TrimSpace(keyParts[0])
+ secretKey := strings.TrimSpace(keyParts[1])
+
+ if err := a.signRequest(req, accessKey, secretKey); err != nil {
+ return nil, errors.Wrap(err, "sign request failed")
+ }
+
+ return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+ return []string{"jimeng_vgfm_t2v_l20"}
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+ return "jimeng"
+}
+
+func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
+ var bodyBytes []byte
+ var err error
+
+ if req.Body != nil {
+ bodyBytes, err = io.ReadAll(req.Body)
+ if err != nil {
+ return errors.Wrap(err, "read request body failed")
+ }
+ _ = req.Body.Close()
+ req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
+ } else {
+ bodyBytes = []byte{}
+ }
+
+ payloadHash := sha256.Sum256(bodyBytes)
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
+
+ t := time.Now().UTC()
+ xDate := t.Format("20060102T150405Z")
+ shortDate := t.Format("20060102")
+
+ req.Header.Set("Host", req.URL.Host)
+ req.Header.Set("X-Date", xDate)
+ req.Header.Set("X-Content-Sha256", hexPayloadHash)
+
+ // Sort and encode query parameters to create canonical query string
+ queryParams := req.URL.Query()
+ sortedKeys := make([]string, 0, len(queryParams))
+ for k := range queryParams {
+ sortedKeys = append(sortedKeys, k)
+ }
+ sort.Strings(sortedKeys)
+ var queryParts []string
+ for _, k := range sortedKeys {
+ values := queryParams[k]
+ sort.Strings(values)
+ for _, v := range values {
+ queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
+ }
+ }
+ canonicalQueryString := strings.Join(queryParts, "&")
+
+ headersToSign := map[string]string{
+ "host": req.URL.Host,
+ "x-date": xDate,
+ "x-content-sha256": hexPayloadHash,
+ }
+ if req.Header.Get("Content-Type") != "" {
+ headersToSign["content-type"] = req.Header.Get("Content-Type")
+ }
+
+ var signedHeaderKeys []string
+ for k := range headersToSign {
+ signedHeaderKeys = append(signedHeaderKeys, k)
+ }
+ sort.Strings(signedHeaderKeys)
+
+ var canonicalHeaders strings.Builder
+ for _, k := range signedHeaderKeys {
+ canonicalHeaders.WriteString(k)
+ canonicalHeaders.WriteString(":")
+ canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
+ canonicalHeaders.WriteString("\n")
+ }
+ signedHeaders := strings.Join(signedHeaderKeys, ";")
+
+ canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+ req.Method,
+ req.URL.Path,
+ canonicalQueryString,
+ canonicalHeaders.String(),
+ signedHeaders,
+ hexPayloadHash,
+ )
+
+ hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
+ hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
+
+ region := "cn-north-1"
+ serviceName := "cv"
+ credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
+ stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
+ xDate,
+ credentialScope,
+ hexHashedCanonicalRequest,
+ )
+
+ kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
+ kRegion := hmacSHA256(kDate, []byte(region))
+ kService := hmacSHA256(kRegion, []byte(serviceName))
+ kSigning := hmacSHA256(kService, []byte("request"))
+ signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
+
+ authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+ accessKey,
+ credentialScope,
+ signedHeaders,
+ signature,
+ )
+ req.Header.Set("Authorization", authorization)
+ return nil
+}
+
+func hmacSHA256(key []byte, data []byte) []byte {
+ h := hmac.New(sha256.New, key)
+ h.Write(data)
+ return h.Sum(nil)
+}
+
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+ r := requestPayload{
+ ReqKey: "jimeng_vgfm_i2v_l20",
+ Prompt: req.Prompt,
+ AspectRatio: "16:9", // Default aspect ratio
+ Seed: -1, // Default to random
+ }
+
+ // Handle one-of image_urls or binary_data_base64
+ if req.Image != "" {
+ if strings.HasPrefix(req.Image, "http") {
+ r.ImageUrls = []string{req.Image}
+ } else {
+ r.BinaryDataBase64 = []string{req.Image}
+ }
+ }
+ metadata := req.Metadata
+ medaBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return nil, errors.Wrap(err, "metadata marshal metadata failed")
+ }
+ err = json.Unmarshal(medaBytes, &r)
+ if err != nil {
+ return nil, errors.Wrap(err, "unmarshal metadata failed")
+ }
+ return &r, nil
+}
+
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+ resTask := responseTask{}
+ if err := json.Unmarshal(respBody, &resTask); err != nil {
+ return nil, errors.Wrap(err, "unmarshal task result failed")
+ }
+ taskResult := relaycommon.TaskInfo{}
+ if resTask.Code == 10000 {
+ taskResult.Code = 0
+ } else {
+ taskResult.Code = resTask.Code // todo uni code
+ taskResult.Reason = resTask.Message
+ taskResult.Status = model.TaskStatusFailure
+ taskResult.Progress = "100%"
+ }
+ switch resTask.Data.Status {
+ case "in_queue":
+ taskResult.Status = model.TaskStatusQueued
+ taskResult.Progress = "10%"
+ case "done":
+ taskResult.Status = model.TaskStatusSuccess
+ taskResult.Progress = "100%"
+ }
+ taskResult.Url = resTask.Data.VideoUrl
+ return &taskResult, nil
+}
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
index 55f21196..2a02472b 100644
--- a/relay/channel/task/kling/adaptor.go
+++ b/relay/channel/task/kling/adaptor.go
@@ -2,12 +2,12 @@ package kling
import (
"bytes"
- "context"
"encoding/json"
"fmt"
"github.com/samber/lo"
"io"
"net/http"
+ "one-api/model"
"strings"
"time"
@@ -47,10 +47,22 @@ type requestPayload struct {
}
type responsePayload struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Data struct {
- TaskID string `json:"task_id"`
+ Code int `json:"code"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+ Data struct {
+ TaskId string `json:"task_id"`
+ TaskStatus string `json:"task_status"`
+ TaskStatusMsg string `json:"task_status_msg"`
+ TaskResult struct {
+ Videos []struct {
+ Id string `json:"id"`
+ Url string `json:"url"`
+ Duration string `json:"duration"`
+ } `json:"videos"`
+ } `json:"task_result"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
} `json:"data"`
}
@@ -94,7 +106,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
}
// Store into context for later usage
- c.Set("kling_request", req)
+ c.Set("task_request", req)
return nil
}
@@ -120,7 +132,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
// BuildRequestBody converts request into Kling specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
- v, exists := c.Get("kling_request")
+ v, exists := c.Get("task_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
@@ -156,8 +168,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
// Attempt Kling response parse first.
var kResp responsePayload
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
- c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
- return kResp.Data.TaskID, responseBody, nil
+ c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
+ return kResp.Data.TaskId, responseBody, nil
}
// Fallback generic task response.
@@ -199,10 +211,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
token = key
}
- ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
- defer cancel()
-
- req = req.WithContext(ctx)
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
@@ -305,27 +313,33 @@ func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (strin
return token.SignedString([]byte(secretKey))
}
-// ParseResultUrl 提取视频任务结果的 url
-func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
- data, ok := resp["data"].(map[string]any)
- if !ok {
- return "", fmt.Errorf("data field not found or invalid")
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+ resPayload := responsePayload{}
+ err := json.Unmarshal(respBody, &resPayload)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to unmarshal response body")
}
- taskResult, ok := data["task_result"].(map[string]any)
- if !ok {
- return "", fmt.Errorf("task_result field not found or invalid")
+ taskInfo := &relaycommon.TaskInfo{}
+ taskInfo.Code = resPayload.Code
+ taskInfo.TaskID = resPayload.Data.TaskId
+ taskInfo.Reason = resPayload.Message
+ //任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败)
+ status := resPayload.Data.TaskStatus
+ switch status {
+ case "submitted":
+ taskInfo.Status = model.TaskStatusSubmitted
+ case "processing":
+ taskInfo.Status = model.TaskStatusInProgress
+ case "succeed":
+ taskInfo.Status = model.TaskStatusSuccess
+ case "failed":
+ taskInfo.Status = model.TaskStatusFailure
+ default:
+ return nil, fmt.Errorf("unknown task status: %s", status)
}
- videos, ok := taskResult["videos"].([]interface{})
- if !ok || len(videos) == 0 {
- return "", fmt.Errorf("videos field not found or empty")
+ if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
+ video := videos[0]
+ taskInfo.Url = video.Url
}
- video, ok := videos[0].(map[string]interface{})
- if !ok {
- return "", fmt.Errorf("video item invalid")
- }
- url, ok := video["url"].(string)
- if !ok || url == "" {
- return "", fmt.Errorf("url field not found or invalid")
- }
- return url, nil
+ return taskInfo, nil
}
diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go
index f7042348..9c04c7ad 100644
--- a/relay/channel/task/suno/adaptor.go
+++ b/relay/channel/task/suno/adaptor.go
@@ -22,8 +22,8 @@ type TaskAdaptor struct {
ChannelType int
}
-func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
- return "", nil // todo implement this method if needed
+func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
+ return nil, fmt.Errorf("not implement") // todo implement this method if needed
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index f3fc9ce9..5fd94788 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -313,3 +313,22 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
}
return info
}
+
+type TaskSubmitReq struct {
+ Prompt string `json:"prompt"`
+ Model string `json:"model,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Image string `json:"image,omitempty"`
+ Size string `json:"size,omitempty"`
+ Duration int `json:"duration,omitempty"`
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+type TaskInfo struct {
+ Code int `json:"code"`
+ TaskID string `json:"task_id"`
+ Status string `json:"status"`
+ Reason string `json:"reason,omitempty"`
+ Url string `json:"url,omitempty"`
+ Progress string `json:"progress,omitempty"`
+}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 1b8a1b2d..cc8a1494 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -41,6 +41,9 @@ const (
RelayModeKlingFetchByID
RelayModeKlingSubmit
+ RelayModeJimengFetchByID
+ RelayModeJimengSubmit
+
RelayModeRerank
RelayModeResponses
@@ -146,3 +149,13 @@ func Path2RelayKling(method, path string) int {
}
return relayMode
}
+
+func Path2RelayJimeng(method, path string) int {
+ relayMode := RelayModeUnknown
+ if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
+ relayMode = RelayModeJimengSubmit
+ } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
+ relayMode = RelayModeJimengFetchByID
+ }
+ return relayMode
+}
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 626bb7e4..f648b4d5 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -22,6 +22,7 @@ import (
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
+ "one-api/relay/channel/task/jimeng"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
@@ -104,6 +105,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
return &suno.TaskAdaptor{}
case commonconstant.TaskPlatformKling:
return &kling.TaskAdaptor{}
+ case commonconstant.TaskPlatformJimeng:
+ return &jimeng.TaskAdaptor{}
}
return nil
}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index b8004105..702cff4c 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -245,7 +245,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
}
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("id")
+ taskId := c.Param("task_id")
userId := c.GetInt("id")
originTask, exist, err := model.GetByTaskId(userId, taskId)
diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js
index 5b77ce39..65a8e2a6 100644
--- a/web/src/components/table/TaskLogsTable.js
+++ b/web/src/components/table/TaskLogsTable.js
@@ -230,8 +230,8 @@ const LogsTable = () => {
}
};
- const renderPlatform = (type) => {
- switch (type) {
+ const renderPlatform = (platform) => {
+ switch (platform) {
case 'suno':
return (
}>
@@ -240,10 +240,16 @@ const LogsTable = () => {
);
case 'kling':
return (
- }>
+ }>
Kling
);
+ case 'jimeng':
+ return (
+ }>
+ Jimeng
+
+ );
default:
return (
}>
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index 4018aa4f..b145ea11 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -130,6 +130,11 @@ export const CHANNEL_OPTIONS = [
color: 'green',
label: '可灵',
},
+ {
+ value: 51,
+ color: 'blue',
+ label: '即梦',
+ },
];
export const MODEL_TABLE_PAGE_SIZE = 10;
diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json
index 060f3e65..8a38ef0c 100644
--- a/web/src/i18n/locales/en.json
+++ b/web/src/i18n/locales/en.json
@@ -801,6 +801,7 @@
"获取无水印": "Get no watermark",
"生成图片": "Generate pictures",
"可灵": "Kling",
+ "即梦": "Jimeng",
"正在提交": "Submitting",
"执行中": "processing",
"平台": "platform",