Merge pull request #1363 from feitianbubu/pr/add-jimeng-image

feat: 增加即梦生图功能
This commit is contained in:
Calcium-Ion
2025-07-18 20:45:49 +08:00
committed by GitHub
8 changed files with 421 additions and 2 deletions

View File

@@ -63,6 +63,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = constant.APITypeXai
case constant.ChannelTypeCoze:
apiType = constant.APITypeCoze
case constant.ChannelTypeJimeng:
apiType = constant.APITypeJimeng
}
if apiType == -1 {
return constant.APITypeOpenAI, false

View File

@@ -30,5 +30,6 @@ const (
APITypeXinference
APITypeXai
APITypeCoze
APITypeJimeng
APITypeDummy // this one is only for count, do not add any channel after this
)

View File

@@ -203,6 +203,9 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
}
}
func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
return doRequest(c, req, info)
}
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error

View File

@@ -0,0 +1,136 @@
package jimeng
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
return errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
type LogoInfo struct {
AddLogo bool `json:"add_logo,omitempty"`
Position int `json:"position,omitempty"`
Language int `json:"language,omitempty"`
Opacity float64 `json:"opacity,omitempty"`
LogoTextContent string `json:"logo_text_content,omitempty"`
}
type imageRequestPayload struct {
ReqKey string `json:"req_key"` // Service identifier, fixed value: jimeng_high_aes_general_v21_L
Prompt string `json:"prompt"` // Prompt for image generation, supports both Chinese and English
Seed int64 `json:"seed,omitempty"` // Random seed, default -1 (random)
Width int `json:"width,omitempty"` // Image width, default 512, range [256, 768]
Height int `json:"height,omitempty"` // Image height, default 512, range [256, 768]
UsePreLLM bool `json:"use_pre_llm,omitempty"` // Enable text expansion, default true
UseSR bool `json:"use_sr,omitempty"` // Enable super resolution, default true
ReturnURL bool `json:"return_url,omitempty"` // Whether to return image URL (valid for 24 hours)
LogoInfo LogoInfo `json:"logo_info,omitempty"` // Watermark information
ImageUrls []string `json:"image_urls,omitempty"` // Image URLs for input
BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
payload := imageRequestPayload{
ReqKey: request.Model,
Prompt: request.Prompt,
}
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
payload.ReturnURL = true // Default to returning image URLs
}
if len(request.ExtraFields) > 0 {
if err := json.Unmarshal(request.ExtraFields, &payload); err != nil {
return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err)
}
}
return payload, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = Sign(c, req, info.ApiKey)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := channel.DoRequest(c, req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
usage, err = jimengImageHandler(c, resp, info)
} else if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,9 @@
package jimeng
const (
ChannelName = "jimeng"
)
var ModelList = []string{
"jimeng_high_aes_general_v21_L",
}

View File

@@ -0,0 +1,89 @@
package jimeng
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
type ImageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
BinaryDataBase64 []string `json:"binary_data_base64"`
ImageUrls []string `json:"image_urls"`
RephraseResult string `json:"rephraser_result"`
RequestID string `json:"request_id"`
// Other fields are omitted for brevity
} `json:"data"`
RequestID string `json:"request_id"`
Status int `json:"status"`
TimeElapsed string `json:"time_elapsed"`
}
func responseJimeng2OpenAIImage(_ *gin.Context, response *ImageResponse, info *relaycommon.RelayInfo) *dto.ImageResponse {
imageResponse := dto.ImageResponse{
Created: info.StartTime.Unix(),
}
for _, base64Data := range response.Data.BinaryDataBase64 {
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
B64Json: base64Data,
})
}
for _, imageUrl := range response.Data.ImageUrls {
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: imageUrl,
})
}
return &imageResponse
}
// jimengImageHandler handles the Jimeng image generation response
func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
var jimengResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// Check if the response indicates an error
if jimengResponse.Code != 10000 {
return nil, types.WithOpenAIError(types.OpenAIError{
Message: jimengResponse.Message,
Type: "jimeng_error",
Param: "",
Code: fmt.Sprintf("%d", jimengResponse.Code),
}, resp.StatusCode)
}
// Convert Jimeng response to OpenAI format
fullTextResponse := responseJimeng2OpenAIImage(c, &jimengResponse, info)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
return &dto.Usage{}, nil
}

View File

@@ -0,0 +1,176 @@
package jimeng
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"net/url"
"one-api/common"
"sort"
"strings"
"time"
)
// SignRequestForJimeng 对即梦 API 请求进行签名,支持 http.Request 或 header+url+body 方式
//func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error {
// var bodyBytes []byte
// var err error
//
// if req.Body != nil {
// bodyBytes, err = io.ReadAll(req.Body)
// if err != nil {
// return fmt.Errorf("read request body failed: %w", err)
// }
// _ = req.Body.Close()
// req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind
// } else {
// bodyBytes = []byte{}
// }
//
// return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey)
//}
const HexPayloadHashKey = "HexPayloadHash"
func SetPayloadHash(c *gin.Context, req any) error {
body, err := json.Marshal(req)
if err != nil {
return err
}
common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
payloadHash := sha256.Sum256(body)
hexPayloadHash := hex.EncodeToString(payloadHash[:])
c.Set(HexPayloadHashKey, hexPayloadHash)
return nil
}
func getPayloadHash(c *gin.Context) string {
return c.GetString(HexPayloadHashKey)
}
func Sign(c *gin.Context, req *http.Request, apiKey string) error {
header := req.Header
var bodyBytes []byte
var err error
if req.Body != nil {
bodyBytes, err = io.ReadAll(req.Body)
if err != nil {
return err
}
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
}
payloadHash := sha256.Sum256(bodyBytes)
hexPayloadHash := hex.EncodeToString(payloadHash[:])
method := c.Request.Method
u := req.URL
keyParts := strings.Split(apiKey, "|")
if len(keyParts) != 2 {
return errors.New("invalid api key format for jimeng: expected 'ak|sk'")
}
accessKey := strings.TrimSpace(keyParts[0])
secretKey := strings.TrimSpace(keyParts[1])
t := time.Now().UTC()
xDate := t.Format("20060102T150405Z")
shortDate := t.Format("20060102")
host := u.Host
header.Set("Host", host)
header.Set("X-Date", xDate)
header.Set("X-Content-Sha256", hexPayloadHash)
// Sort and encode query parameters to create canonical query string
queryParams := u.Query()
sortedKeys := make([]string, 0, len(queryParams))
for k := range queryParams {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
var queryParts []string
for _, k := range sortedKeys {
values := queryParams[k]
sort.Strings(values)
for _, v := range values {
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
}
}
canonicalQueryString := strings.Join(queryParts, "&")
headersToSign := map[string]string{
"host": host,
"x-date": xDate,
"x-content-sha256": hexPayloadHash,
}
if header.Get("Content-Type") == "" {
header.Set("Content-Type", "application/json")
}
headersToSign["content-type"] = header.Get("Content-Type")
var signedHeaderKeys []string
for k := range headersToSign {
signedHeaderKeys = append(signedHeaderKeys, k)
}
sort.Strings(signedHeaderKeys)
var canonicalHeaders strings.Builder
for _, k := range signedHeaderKeys {
canonicalHeaders.WriteString(k)
canonicalHeaders.WriteString(":")
canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
canonicalHeaders.WriteString("\n")
}
signedHeaders := strings.Join(signedHeaderKeys, ";")
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
method,
u.Path,
canonicalQueryString,
canonicalHeaders.String(),
signedHeaders,
hexPayloadHash,
)
hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
region := "cn-north-1"
serviceName := "cv"
credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
xDate,
credentialScope,
hexHashedCanonicalRequest,
)
kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(serviceName))
kSigning := hmacSHA256(kService, []byte("request"))
signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
accessKey,
credentialScope,
signedHeaders,
signature,
)
header.Set("Authorization", authorization)
return nil
}
// hmacSHA256 计算 HMAC-SHA256
func hmacSHA256(key []byte, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}

View File

@@ -15,6 +15,7 @@ import (
"one-api/relay/channel/deepseek"
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
"one-api/relay/channel/jimeng"
"one-api/relay/channel/jina"
"one-api/relay/channel/mistral"
"one-api/relay/channel/mokaai"
@@ -23,7 +24,7 @@ import (
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
"one-api/relay/channel/task/jimeng"
taskjimeng "one-api/relay/channel/task/jimeng"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
@@ -93,6 +94,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &xai.Adaptor{}
case constant.APITypeCoze:
return &coze.Adaptor{}
case constant.APITypeJimeng:
return &jimeng.Adaptor{}
}
return nil
}
@@ -106,7 +109,7 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
case commonconstant.TaskPlatformKling:
return &kling.TaskAdaptor{}
case commonconstant.TaskPlatformJimeng:
return &jimeng.TaskAdaptor{}
return &taskjimeng.TaskAdaptor{}
}
return nil
}