Merge pull request #1274 from feitianbubu/feat/add-channel-jimeng
feat: 支持即梦视频渠道
This commit is contained in:
@@ -242,6 +242,7 @@ const (
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -298,4 +299,5 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ const (
|
||||
TaskPlatformSuno TaskPlatform = "suno"
|
||||
TaskPlatformMidjourney = "mj"
|
||||
TaskPlatformKling TaskPlatform = "kling"
|
||||
TaskPlatformJimeng TaskPlatform = "jimeng"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -43,6 +43,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if channel.Type == common.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeJimeng {
|
||||
return errors.New("jimeng channel test is not supported"), nil
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
|
||||
@@ -74,8 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling:
|
||||
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
||||
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
}
|
||||
|
||||
@@ -2,27 +2,26 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
@@ -39,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, ta
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
@@ -67,60 +66,53 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
"action": task.Action,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
var responseItem map[string]interface{}
|
||||
err = json.Unmarshal(responseBody, &responseItem)
|
||||
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
|
||||
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if taskResult.Code != 0 {
|
||||
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
//}
|
||||
|
||||
code, _ := responseItem["code"].(float64)
|
||||
if code != 0 {
|
||||
return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
return fmt.Errorf("task %s status is empty", taskId)
|
||||
}
|
||||
|
||||
data, ok := responseItem["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
|
||||
return fmt.Errorf("video task data format error for task %s", taskId)
|
||||
}
|
||||
|
||||
if status, ok := data["task_status"].(string); ok {
|
||||
switch status {
|
||||
case "submitted", "queued":
|
||||
task.Status = model.TaskStatusSubmitted
|
||||
case "processing":
|
||||
task.Status = model.TaskStatusInProgress
|
||||
case "succeed":
|
||||
task.Status = model.TaskStatusSuccess
|
||||
task.Progress = "100%"
|
||||
if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
|
||||
task.FailReason = url
|
||||
} else {
|
||||
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
|
||||
}
|
||||
case "failed":
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if reason, ok := data["fail_reason"].(string); ok {
|
||||
task.FailReason = reason
|
||||
}
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
}
|
||||
|
||||
// If task failed, refund quota
|
||||
if task.Status == model.TaskStatusFailure {
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Url
|
||||
case model.TaskStatusFailure:
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
@@ -130,6 +122,11 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
|
||||
task.Data = responseBody
|
||||
|
||||
@@ -171,13 +171,23 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||
shouldSelectChannel = false
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
var platform string
|
||||
var relayMode int
|
||||
if strings.HasPrefix(modelRequest.Model, "jimeng") {
|
||||
platform = string(constant.TaskPlatformJimeng)
|
||||
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeJimengFetchByID {
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
} else {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
platform = string(constant.TaskPlatformKling)
|
||||
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformKling))
|
||||
c.Set("platform", platform)
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
|
||||
@@ -45,5 +45,5 @@ type TaskAdaptor interface {
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
|
||||
ParseResultUrl(resp map[string]any) (string, error)
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
379
relay/channel/task/jimeng/adaptor.go
Normal file
379
relay/channel/task/jimeng/adaptor.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/model"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type requestPayload struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Data struct {
|
||||
TaskID string `json:"task_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type responseTask struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
BinaryDataBase64 []interface{} `json:"binary_data_base64"`
|
||||
ImageUrls interface{} `json:"image_urls"`
|
||||
RespData string `json:"resp_data"`
|
||||
Status string `json:"status"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
} `json:"data"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
accessKey string
|
||||
secretKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
|
||||
// apiKey format: "access_key,secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, ",")
|
||||
if len(keyParts) == 2 {
|
||||
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
info.Action = action
|
||||
|
||||
req := relaycommon.TaskSubmitReq{}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Jimeng specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// Parse Jimeng response
|
||||
var jResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &jResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if jResp.Code != 10000 {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
|
||||
return jResp.Data.TaskID, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
|
||||
payload := map[string]string{
|
||||
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
||||
"task_id": taskID,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal fetch task payload failed")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
keyParts := strings.Split(key, ",")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak,sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
|
||||
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return []string{"jimeng_vgfm_t2v_l20"}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return "jimeng"
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
|
||||
var bodyBytes []byte
|
||||
var err error
|
||||
|
||||
if req.Body != nil {
|
||||
bodyBytes, err = io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read request body failed")
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
|
||||
} else {
|
||||
bodyBytes = []byte{}
|
||||
}
|
||||
|
||||
payloadHash := sha256.Sum256(bodyBytes)
|
||||
hexPayloadHash := hex.EncodeToString(payloadHash[:])
|
||||
|
||||
t := time.Now().UTC()
|
||||
xDate := t.Format("20060102T150405Z")
|
||||
shortDate := t.Format("20060102")
|
||||
|
||||
req.Header.Set("Host", req.URL.Host)
|
||||
req.Header.Set("X-Date", xDate)
|
||||
req.Header.Set("X-Content-Sha256", hexPayloadHash)
|
||||
|
||||
// Sort and encode query parameters to create canonical query string
|
||||
queryParams := req.URL.Query()
|
||||
sortedKeys := make([]string, 0, len(queryParams))
|
||||
for k := range queryParams {
|
||||
sortedKeys = append(sortedKeys, k)
|
||||
}
|
||||
sort.Strings(sortedKeys)
|
||||
var queryParts []string
|
||||
for _, k := range sortedKeys {
|
||||
values := queryParams[k]
|
||||
sort.Strings(values)
|
||||
for _, v := range values {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
|
||||
}
|
||||
}
|
||||
canonicalQueryString := strings.Join(queryParts, "&")
|
||||
|
||||
headersToSign := map[string]string{
|
||||
"host": req.URL.Host,
|
||||
"x-date": xDate,
|
||||
"x-content-sha256": hexPayloadHash,
|
||||
}
|
||||
if req.Header.Get("Content-Type") != "" {
|
||||
headersToSign["content-type"] = req.Header.Get("Content-Type")
|
||||
}
|
||||
|
||||
var signedHeaderKeys []string
|
||||
for k := range headersToSign {
|
||||
signedHeaderKeys = append(signedHeaderKeys, k)
|
||||
}
|
||||
sort.Strings(signedHeaderKeys)
|
||||
|
||||
var canonicalHeaders strings.Builder
|
||||
for _, k := range signedHeaderKeys {
|
||||
canonicalHeaders.WriteString(k)
|
||||
canonicalHeaders.WriteString(":")
|
||||
canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
|
||||
canonicalHeaders.WriteString("\n")
|
||||
}
|
||||
signedHeaders := strings.Join(signedHeaderKeys, ";")
|
||||
|
||||
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
||||
req.Method,
|
||||
req.URL.Path,
|
||||
canonicalQueryString,
|
||||
canonicalHeaders.String(),
|
||||
signedHeaders,
|
||||
hexPayloadHash,
|
||||
)
|
||||
|
||||
hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
|
||||
hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
|
||||
|
||||
region := "cn-north-1"
|
||||
serviceName := "cv"
|
||||
credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
|
||||
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
|
||||
xDate,
|
||||
credentialScope,
|
||||
hexHashedCanonicalRequest,
|
||||
)
|
||||
|
||||
kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
|
||||
kRegion := hmacSHA256(kDate, []byte(region))
|
||||
kService := hmacSHA256(kRegion, []byte(serviceName))
|
||||
kSigning := hmacSHA256(kService, []byte("request"))
|
||||
signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
|
||||
|
||||
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
accessKey,
|
||||
credentialScope,
|
||||
signedHeaders,
|
||||
signature,
|
||||
)
|
||||
req.Header.Set("Authorization", authorization)
|
||||
return nil
|
||||
}
|
||||
|
||||
func hmacSHA256(key []byte, data []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
ReqKey: "jimeng_vgfm_i2v_l20",
|
||||
Prompt: req.Prompt,
|
||||
AspectRatio: "16:9", // Default aspect ratio
|
||||
Seed: -1, // Default to random
|
||||
}
|
||||
|
||||
// Handle one-of image_urls or binary_data_base64
|
||||
if req.Image != "" {
|
||||
if strings.HasPrefix(req.Image, "http") {
|
||||
r.ImageUrls = []string{req.Image}
|
||||
} else {
|
||||
r.BinaryDataBase64 = []string{req.Image}
|
||||
}
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
taskResult := relaycommon.TaskInfo{}
|
||||
if resTask.Code == 10000 {
|
||||
taskResult.Code = 0
|
||||
} else {
|
||||
taskResult.Code = resTask.Code // todo uni code
|
||||
taskResult.Reason = resTask.Message
|
||||
taskResult.Status = model.TaskStatusFailure
|
||||
taskResult.Progress = "100%"
|
||||
}
|
||||
switch resTask.Data.Status {
|
||||
case "in_queue":
|
||||
taskResult.Status = model.TaskStatusQueued
|
||||
taskResult.Progress = "10%"
|
||||
case "done":
|
||||
taskResult.Status = model.TaskStatusSuccess
|
||||
taskResult.Progress = "100%"
|
||||
}
|
||||
taskResult.Url = resTask.Data.VideoUrl
|
||||
return &taskResult, nil
|
||||
}
|
||||
@@ -2,12 +2,12 @@ package kling
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -47,10 +47,22 @@ type requestPayload struct {
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Data struct {
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
TaskStatusMsg string `json:"task_status_msg"`
|
||||
TaskResult struct {
|
||||
Videos []struct {
|
||||
Id string `json:"id"`
|
||||
Url string `json:"url"`
|
||||
Duration string `json:"duration"`
|
||||
} `json:"videos"`
|
||||
} `json:"task_result"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
@@ -94,7 +106,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("kling_request", req)
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -120,7 +132,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
|
||||
// BuildRequestBody converts request into Kling specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("kling_request")
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
@@ -156,8 +168,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
// Attempt Kling response parse first.
|
||||
var kResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
|
||||
return kResp.Data.TaskID, responseBody, nil
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
|
||||
return kResp.Data.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
// Fallback generic task response.
|
||||
@@ -199,10 +211,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
token = key
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
@@ -305,27 +313,33 @@ func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (strin
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
|
||||
// ParseResultUrl 提取视频任务结果的 url
|
||||
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("data field not found or invalid")
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resPayload := responsePayload{}
|
||||
err := json.Unmarshal(respBody, &resPayload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
taskResult, ok := data["task_result"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("task_result field not found or invalid")
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
taskInfo.Code = resPayload.Code
|
||||
taskInfo.TaskID = resPayload.Data.TaskId
|
||||
taskInfo.Reason = resPayload.Message
|
||||
//任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败)
|
||||
status := resPayload.Data.TaskStatus
|
||||
switch status {
|
||||
case "submitted":
|
||||
taskInfo.Status = model.TaskStatusSubmitted
|
||||
case "processing":
|
||||
taskInfo.Status = model.TaskStatusInProgress
|
||||
case "succeed":
|
||||
taskInfo.Status = model.TaskStatusSuccess
|
||||
case "failed":
|
||||
taskInfo.Status = model.TaskStatusFailure
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown task status: %s", status)
|
||||
}
|
||||
videos, ok := taskResult["videos"].([]interface{})
|
||||
if !ok || len(videos) == 0 {
|
||||
return "", fmt.Errorf("videos field not found or empty")
|
||||
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
|
||||
video := videos[0]
|
||||
taskInfo.Url = video.Url
|
||||
}
|
||||
video, ok := videos[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("video item invalid")
|
||||
}
|
||||
url, ok := video["url"].(string)
|
||||
if !ok || url == "" {
|
||||
return "", fmt.Errorf("url field not found or invalid")
|
||||
}
|
||||
return url, nil
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
@@ -22,8 +22,8 @@ type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||
return "", nil // todo implement this method if needed
|
||||
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
|
||||
return nil, fmt.Errorf("not implement") // todo implement this method if needed
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
|
||||
@@ -313,3 +313,22 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
type TaskSubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type TaskInfo struct {
|
||||
Code int `json:"code"`
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Url string `json:"url,omitempty"`
|
||||
Progress string `json:"progress,omitempty"`
|
||||
}
|
||||
|
||||
@@ -41,6 +41,9 @@ const (
|
||||
RelayModeKlingFetchByID
|
||||
RelayModeKlingSubmit
|
||||
|
||||
RelayModeJimengFetchByID
|
||||
RelayModeJimengSubmit
|
||||
|
||||
RelayModeRerank
|
||||
|
||||
RelayModeResponses
|
||||
@@ -146,3 +149,13 @@ func Path2RelayKling(method, path string) int {
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
func Path2RelayJimeng(method, path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
|
||||
relayMode = RelayModeJimengSubmit
|
||||
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
|
||||
relayMode = RelayModeJimengFetchByID
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"one-api/relay/channel/palm"
|
||||
"one-api/relay/channel/perplexity"
|
||||
"one-api/relay/channel/siliconflow"
|
||||
"one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
"one-api/relay/channel/tencent"
|
||||
@@ -104,6 +105,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &suno.TaskAdaptor{}
|
||||
case commonconstant.TaskPlatformKling:
|
||||
return &kling.TaskAdaptor{}
|
||||
case commonconstant.TaskPlatformJimeng:
|
||||
return &jimeng.TaskAdaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -245,7 +245,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
||||
}
|
||||
|
||||
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||
taskId := c.Param("id")
|
||||
taskId := c.Param("task_id")
|
||||
userId := c.GetInt("id")
|
||||
|
||||
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||
|
||||
@@ -230,8 +230,8 @@ const LogsTable = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const renderPlatform = (type) => {
|
||||
switch (type) {
|
||||
const renderPlatform = (platform) => {
|
||||
switch (platform) {
|
||||
case 'suno':
|
||||
return (
|
||||
<Tag color='green' size='large' shape='circle' prefixIcon={<Music size={14} />}>
|
||||
@@ -240,10 +240,16 @@ const LogsTable = () => {
|
||||
);
|
||||
case 'kling':
|
||||
return (
|
||||
<Tag color='blue' size='large' shape='circle' prefixIcon={<Video size={14} />}>
|
||||
<Tag color='orange' size='large' shape='circle' prefixIcon={<Video size={14} />}>
|
||||
Kling
|
||||
</Tag>
|
||||
);
|
||||
case 'jimeng':
|
||||
return (
|
||||
<Tag color='purple' size='large' shape='circle' prefixIcon={<Video size={14} />}>
|
||||
Jimeng
|
||||
</Tag>
|
||||
);
|
||||
default:
|
||||
return (
|
||||
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||
|
||||
@@ -130,6 +130,11 @@ export const CHANNEL_OPTIONS = [
|
||||
color: 'green',
|
||||
label: '可灵',
|
||||
},
|
||||
{
|
||||
value: 51,
|
||||
color: 'blue',
|
||||
label: '即梦',
|
||||
},
|
||||
];
|
||||
|
||||
export const MODEL_TABLE_PAGE_SIZE = 10;
|
||||
|
||||
@@ -801,6 +801,7 @@
|
||||
"获取无水印": "Get no watermark",
|
||||
"生成图片": "Generate pictures",
|
||||
"可灵": "Kling",
|
||||
"即梦": "Jimeng",
|
||||
"正在提交": "Submitting",
|
||||
"执行中": "processing",
|
||||
"平台": "platform",
|
||||
|
||||
Reference in New Issue
Block a user