This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase.
381 lines
11 KiB
Go
381 lines
11 KiB
Go
package jimeng
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"one-api/model"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/pkg/errors"
|
|
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
"one-api/relay/channel"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/service"
|
|
)
|
|
|
|
// ============================
// Request / Response structures
// ============================

type (
	// requestPayload is the JSON body sent to CVSync2AsyncSubmitTask.
	// Field order fixes the key order of the marshalled JSON, so keep it stable.
	requestPayload struct {
		ReqKey           string   `json:"req_key"`
		BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
		ImageUrls        []string `json:"image_urls,omitempty"`
		Prompt           string   `json:"prompt,omitempty"`
		Seed             int64    `json:"seed"`
		AspectRatio      string   `json:"aspect_ratio"`
	}

	// responsePayload is the envelope decoded from the submit call; on success
	// Data.TaskID carries the upstream task identifier.
	responsePayload struct {
		Code      int    `json:"code"`
		Message   string `json:"message"`
		RequestId string `json:"request_id"`
		Data      struct {
			TaskID string `json:"task_id"`
		} `json:"data"`
	}

	// responseTask is the task-status body decoded by ParseTaskResult
	// (presumably the CVSync2AsyncGetResult response — confirm against caller).
	responseTask struct {
		Code int `json:"code"`
		Data struct {
			BinaryDataBase64 []any `json:"binary_data_base64"`
			ImageUrls        any   `json:"image_urls"`
			RespData         string `json:"resp_data"`
			Status           string `json:"status"`
			VideoUrl         string `json:"video_url"`
		} `json:"data"`
		Message     string `json:"message"`
		RequestId   string `json:"request_id"`
		Status      int    `json:"status"`
		TimeElapsed string `json:"time_elapsed"`
	}
)
|
|
|
|
// ============================
// Adaptor implementation
// ============================

// TaskAdaptor relays video-generation tasks to the Jimeng upstream.
// Its fields are populated by Init.
type TaskAdaptor struct {
	ChannelType int

	// accessKey and secretKey are parsed by Init from a channel key of the
	// form "access_key|secret_key".
	accessKey, secretKey string
	// baseURL is the channel base URL that request paths are appended to.
	baseURL string
}
|
|
|
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
|
a.ChannelType = info.ChannelType
|
|
a.baseURL = info.ChannelBaseUrl
|
|
|
|
// apiKey format: "access_key|secret_key"
|
|
keyParts := strings.Split(info.ApiKey, "|")
|
|
if len(keyParts) == 2 {
|
|
a.accessKey = strings.TrimSpace(keyParts[0])
|
|
a.secretKey = strings.TrimSpace(keyParts[1])
|
|
}
|
|
}
|
|
|
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
|
// Accept only POST /v1/video/generations as "generate" action.
|
|
action := constant.TaskActionGenerate
|
|
info.Action = action
|
|
|
|
req := relaycommon.TaskSubmitReq{}
|
|
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
|
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if strings.TrimSpace(req.Prompt) == "" {
|
|
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Store into context for later usage
|
|
c.Set("task_request", req)
|
|
return nil
|
|
}
|
|
|
|
// BuildRequestURL constructs the upstream URL.
|
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
|
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
|
}
|
|
|
|
// BuildRequestHeader sets required headers.
|
|
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
return a.signRequest(req, a.accessKey, a.secretKey)
|
|
}
|
|
|
|
// BuildRequestBody converts request into Jimeng specific format.
|
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
|
v, exists := c.Get("task_request")
|
|
if !exists {
|
|
return nil, fmt.Errorf("request not found in context")
|
|
}
|
|
req := v.(relaycommon.TaskSubmitReq)
|
|
|
|
body, err := a.convertToRequestPayload(&req)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "convert request payload failed")
|
|
}
|
|
data, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return bytes.NewReader(data), nil
|
|
}
|
|
|
|
// DoRequest delegates to common helper.
|
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
|
}
|
|
|
|
// DoResponse handles upstream response, returns taskID etc.
|
|
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
// Parse Jimeng response
|
|
var jResp responsePayload
|
|
if err := json.Unmarshal(responseBody, &jResp); err != nil {
|
|
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if jResp.Code != 10000 {
|
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
|
|
return jResp.Data.TaskID, responseBody, nil
|
|
}
|
|
|
|
// FetchTask fetch task status
|
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
|
taskID, ok := body["task_id"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid task_id")
|
|
}
|
|
|
|
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
|
|
payload := map[string]string{
|
|
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
|
"task_id": taskID,
|
|
}
|
|
payloadBytes, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "marshal fetch task payload failed")
|
|
}
|
|
|
|
req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.Set("Accept", "application/json")
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
keyParts := strings.Split(key, "|")
|
|
if len(keyParts) != 2 {
|
|
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
|
}
|
|
accessKey := strings.TrimSpace(keyParts[0])
|
|
secretKey := strings.TrimSpace(keyParts[1])
|
|
|
|
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
|
return nil, errors.Wrap(err, "sign request failed")
|
|
}
|
|
|
|
return service.GetHttpClient().Do(req)
|
|
}
|
|
|
|
func (a *TaskAdaptor) GetModelList() []string {
|
|
return []string{"jimeng_vgfm_t2v_l20"}
|
|
}
|
|
|
|
func (a *TaskAdaptor) GetChannelName() string {
|
|
return "jimeng"
|
|
}
|
|
|
|
func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
|
|
var bodyBytes []byte
|
|
var err error
|
|
|
|
if req.Body != nil {
|
|
bodyBytes, err = io.ReadAll(req.Body)
|
|
if err != nil {
|
|
return errors.Wrap(err, "read request body failed")
|
|
}
|
|
_ = req.Body.Close()
|
|
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
|
|
} else {
|
|
bodyBytes = []byte{}
|
|
}
|
|
|
|
payloadHash := sha256.Sum256(bodyBytes)
|
|
hexPayloadHash := hex.EncodeToString(payloadHash[:])
|
|
|
|
t := time.Now().UTC()
|
|
xDate := t.Format("20060102T150405Z")
|
|
shortDate := t.Format("20060102")
|
|
|
|
req.Header.Set("Host", req.URL.Host)
|
|
req.Header.Set("X-Date", xDate)
|
|
req.Header.Set("X-Content-Sha256", hexPayloadHash)
|
|
|
|
// Sort and encode query parameters to create canonical query string
|
|
queryParams := req.URL.Query()
|
|
sortedKeys := make([]string, 0, len(queryParams))
|
|
for k := range queryParams {
|
|
sortedKeys = append(sortedKeys, k)
|
|
}
|
|
sort.Strings(sortedKeys)
|
|
var queryParts []string
|
|
for _, k := range sortedKeys {
|
|
values := queryParams[k]
|
|
sort.Strings(values)
|
|
for _, v := range values {
|
|
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
|
|
}
|
|
}
|
|
canonicalQueryString := strings.Join(queryParts, "&")
|
|
|
|
headersToSign := map[string]string{
|
|
"host": req.URL.Host,
|
|
"x-date": xDate,
|
|
"x-content-sha256": hexPayloadHash,
|
|
}
|
|
if req.Header.Get("Content-Type") != "" {
|
|
headersToSign["content-type"] = req.Header.Get("Content-Type")
|
|
}
|
|
|
|
var signedHeaderKeys []string
|
|
for k := range headersToSign {
|
|
signedHeaderKeys = append(signedHeaderKeys, k)
|
|
}
|
|
sort.Strings(signedHeaderKeys)
|
|
|
|
var canonicalHeaders strings.Builder
|
|
for _, k := range signedHeaderKeys {
|
|
canonicalHeaders.WriteString(k)
|
|
canonicalHeaders.WriteString(":")
|
|
canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
|
|
canonicalHeaders.WriteString("\n")
|
|
}
|
|
signedHeaders := strings.Join(signedHeaderKeys, ";")
|
|
|
|
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
|
req.Method,
|
|
req.URL.Path,
|
|
canonicalQueryString,
|
|
canonicalHeaders.String(),
|
|
signedHeaders,
|
|
hexPayloadHash,
|
|
)
|
|
|
|
hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
|
|
hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
|
|
|
|
region := "cn-north-1"
|
|
serviceName := "cv"
|
|
credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
|
|
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
|
|
xDate,
|
|
credentialScope,
|
|
hexHashedCanonicalRequest,
|
|
)
|
|
|
|
kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
|
|
kRegion := hmacSHA256(kDate, []byte(region))
|
|
kService := hmacSHA256(kRegion, []byte(serviceName))
|
|
kSigning := hmacSHA256(kService, []byte("request"))
|
|
signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
|
|
|
|
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
|
accessKey,
|
|
credentialScope,
|
|
signedHeaders,
|
|
signature,
|
|
)
|
|
req.Header.Set("Authorization", authorization)
|
|
return nil
|
|
}
|
|
|
|
// hmacSHA256 returns the HMAC-SHA256 of data keyed with key (32 bytes).
func hmacSHA256(key, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	// hash.Hash.Write is documented never to return an error.
	_, _ = mac.Write(data)
	return mac.Sum(nil)
}
|
|
|
|
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
|
r := requestPayload{
|
|
ReqKey: "jimeng_vgfm_i2v_l20",
|
|
Prompt: req.Prompt,
|
|
AspectRatio: "16:9", // Default aspect ratio
|
|
Seed: -1, // Default to random
|
|
}
|
|
|
|
// Handle one-of image_urls or binary_data_base64
|
|
if req.Image != "" {
|
|
if strings.HasPrefix(req.Image, "http") {
|
|
r.ImageUrls = []string{req.Image}
|
|
} else {
|
|
r.BinaryDataBase64 = []string{req.Image}
|
|
}
|
|
}
|
|
metadata := req.Metadata
|
|
medaBytes, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
|
}
|
|
err = json.Unmarshal(medaBytes, &r)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
|
}
|
|
return &r, nil
|
|
}
|
|
|
|
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
|
resTask := responseTask{}
|
|
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
|
return nil, errors.Wrap(err, "unmarshal task result failed")
|
|
}
|
|
taskResult := relaycommon.TaskInfo{}
|
|
if resTask.Code == 10000 {
|
|
taskResult.Code = 0
|
|
} else {
|
|
taskResult.Code = resTask.Code // todo uni code
|
|
taskResult.Reason = resTask.Message
|
|
taskResult.Status = model.TaskStatusFailure
|
|
taskResult.Progress = "100%"
|
|
}
|
|
switch resTask.Data.Status {
|
|
case "in_queue":
|
|
taskResult.Status = model.TaskStatusQueued
|
|
taskResult.Progress = "10%"
|
|
case "done":
|
|
taskResult.Status = model.TaskStatusSuccess
|
|
taskResult.Progress = "100%"
|
|
}
|
|
taskResult.Url = resTask.Data.VideoUrl
|
|
return &taskResult, nil
|
|
}
|