diff --git a/common/api_type.go b/common/api_type.go index d9071236..f045866a 100644 --- a/common/api_type.go +++ b/common/api_type.go @@ -63,6 +63,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = constant.APITypeXai case constant.ChannelTypeCoze: apiType = constant.APITypeCoze + case constant.ChannelTypeJimeng: + apiType = constant.APITypeJimeng } if apiType == -1 { return constant.APITypeOpenAI, false diff --git a/constant/api_type.go b/constant/api_type.go index ae867870..6ba5f257 100644 --- a/constant/api_type.go +++ b/constant/api_type.go @@ -30,5 +30,6 @@ const ( APITypeXinference APITypeXai APITypeCoze + APITypeJimeng APITypeDummy // this one is only for count, do not add any channel after this ) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 97887266..ff7c63fa 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -203,6 +203,9 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error { } } +func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { + return doRequest(c, req, info) +} func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { var client *http.Client var err error diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go new file mode 100644 index 00000000..0b743879 --- /dev/null +++ b/relay/channel/jimeng/adaptor.go @@ -0,0 +1,136 @@ +package jimeng + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/types" +) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { + return errors.New("not implemented") +} + +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +type LogoInfo struct { + AddLogo bool `json:"add_logo,omitempty"` + Position int `json:"position,omitempty"` + Language int `json:"language,omitempty"` + Opacity float64 `json:"opacity,omitempty"` + LogoTextContent string `json:"logo_text_content,omitempty"` +} + +type imageRequestPayload struct { + ReqKey string `json:"req_key"` // Service identifier, fixed value: jimeng_high_aes_general_v21_L + Prompt string `json:"prompt"` // Prompt for image generation, supports both Chinese and English + Seed int64 `json:"seed,omitempty"` // Random seed, default -1 (random) + Width int `json:"width,omitempty"` // Image width, default 512, range [256, 768] + Height int `json:"height,omitempty"` // Image height, default 512, range [256, 768] + UsePreLLM bool `json:"use_pre_llm,omitempty"` // Enable text expansion, default true + UseSR bool `json:"use_sr,omitempty"` // Enable super resolution, default true + ReturnURL bool `json:"return_url,omitempty"` // Whether to return image URL (valid for 24 hours) + LogoInfo LogoInfo `json:"logo_info,omitempty"` // Watermark information + ImageUrls []string `json:"image_urls,omitempty"` // Image URLs for input + BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + payload := imageRequestPayload{ + ReqKey: request.Model, + Prompt: request.Prompt, + } + if request.ResponseFormat == "" || request.ResponseFormat == "url" { + payload.ReturnURL = true // Default to returning image URLs + } + + if len(request.ExtraFields) > 0 { + if err := json.Unmarshal(request.ExtraFields, &payload); err != nil { + return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err) + } + } + + return payload, nil +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + fullRequestURL, err := a.GetRequestURL(info) + if err != nil { + return nil, fmt.Errorf("get request url failed: %w", err) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + err = Sign(c, req, info.ApiKey) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := channel.DoRequest(c, req, info) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { + if info.RelayMode == relayconstant.RelayModeImagesGenerations { + usage, err = jimengImageHandler(c, resp, info) + } else if info.IsStream { + usage, err = openai.OaiStreamHandler(c, info, resp) + } else { + usage, err = openai.OpenaiHandler(c, info, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/jimeng/constants.go b/relay/channel/jimeng/constants.go new file mode 100644 index 00000000..0d1764e5 --- /dev/null +++ b/relay/channel/jimeng/constants.go @@ -0,0 +1,9 @@ +package jimeng + +const ( + ChannelName = "jimeng" +) + +var ModelList = []string{ + "jimeng_high_aes_general_v21_L", +} diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go new file mode 100644 index 00000000..3c6a1d99 --- /dev/null +++ b/relay/channel/jimeng/image.go @@ -0,0 +1,89 @@ +package jimeng + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type ImageResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + BinaryDataBase64 []string `json:"binary_data_base64"` + ImageUrls []string `json:"image_urls"` + RephraseResult string `json:"rephraser_result"` + RequestID string `json:"request_id"` + // Other fields are omitted for brevity + } `json:"data"` + RequestID string `json:"request_id"` + Status int `json:"status"` + TimeElapsed string `json:"time_elapsed"` +} + +func responseJimeng2OpenAIImage(_ *gin.Context, response *ImageResponse, info *relaycommon.RelayInfo) *dto.ImageResponse { + imageResponse := dto.ImageResponse{ + Created: info.StartTime.Unix(), + } + + for _, base64Data := range response.Data.BinaryDataBase64 { + imageResponse.Data = append(imageResponse.Data, dto.ImageData{ + B64Json: base64Data, + }) + } + for _, imageUrl := range response.Data.ImageUrls { + imageResponse.Data = append(imageResponse.Data, dto.ImageData{ + Url: imageUrl, + }) + } + + return &imageResponse +} + +// jimengImageHandler handles the Jimeng image generation response +func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { + var jimengResponse ImageResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + } + common.CloseResponseBodyGracefully(resp) + + err = json.Unmarshal(responseBody, &jimengResponse) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + } + + // Check if the response indicates an error + if jimengResponse.Code != 10000 { + return nil, types.WithOpenAIError(types.OpenAIError{ + Message: jimengResponse.Message, + Type: "jimeng_error", + Param: "", + Code: fmt.Sprintf("%d", jimengResponse.Code), + }, resp.StatusCode) + } + + // Convert Jimeng response to OpenAI format + fullTextResponse := responseJimeng2OpenAIImage(c, &jimengResponse, info) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + } + + return &dto.Usage{}, nil +} diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go new file mode 100644 index 00000000..c9db6630 --- /dev/null +++ b/relay/channel/jimeng/sign.go @@ -0,0 +1,176 @@ +package jimeng + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "net/url" + "one-api/common" + "sort" + "strings" + "time" +) + +// SignRequestForJimeng 对即梦 API 请求进行签名,支持 http.Request 或 header+url+body 方式 +//func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error { +// var bodyBytes []byte +// var err error +// +// if req.Body != nil { +// bodyBytes, err = io.ReadAll(req.Body) +// if err != nil { +// return fmt.Errorf("read request body failed: %w", err) +// } +// _ = req.Body.Close() +// req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind +// } else { +// bodyBytes = []byte{} +// } +// +// return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey) +//} + +const HexPayloadHashKey = "HexPayloadHash" + +func SetPayloadHash(c *gin.Context, req any) error { + body, err := json.Marshal(req) + if err != nil { + return err + } + common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) + payloadHash := sha256.Sum256(body) + hexPayloadHash := hex.EncodeToString(payloadHash[:]) + c.Set(HexPayloadHashKey, hexPayloadHash) + return nil +} +func getPayloadHash(c *gin.Context) string { + return c.GetString(HexPayloadHashKey) +} + +func Sign(c *gin.Context, req *http.Request, apiKey string) error { + header := req.Header + + var bodyBytes []byte + var err error + + if req.Body != nil { + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return err + } + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind + } + + payloadHash := sha256.Sum256(bodyBytes) + hexPayloadHash := hex.EncodeToString(payloadHash[:]) + + method := c.Request.Method + u := req.URL + keyParts := strings.Split(apiKey, "|") + if len(keyParts) != 2 { + return errors.New("invalid api key format for jimeng: expected 'ak|sk'") + } + accessKey := strings.TrimSpace(keyParts[0]) + secretKey := strings.TrimSpace(keyParts[1]) + t := time.Now().UTC() + xDate := t.Format("20060102T150405Z") + shortDate := t.Format("20060102") + + host := u.Host + header.Set("Host", host) + header.Set("X-Date", xDate) + header.Set("X-Content-Sha256", hexPayloadHash) + + // Sort and encode query parameters to create canonical query string + queryParams := u.Query() + sortedKeys := make([]string, 0, len(queryParams)) + for k := range queryParams { + sortedKeys = append(sortedKeys, k) + } + sort.Strings(sortedKeys) + var queryParts []string + for _, k := range sortedKeys { + values := queryParams[k] + sort.Strings(values) + for _, v := range values { + queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v))) + } + } + canonicalQueryString := strings.Join(queryParts, "&") + + headersToSign := map[string]string{ + "host": host, + "x-date": xDate, + "x-content-sha256": hexPayloadHash, + } + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "application/json") + } + headersToSign["content-type"] = header.Get("Content-Type") + + var signedHeaderKeys []string + for k := range headersToSign { + signedHeaderKeys = append(signedHeaderKeys, k) + } + sort.Strings(signedHeaderKeys) + + var canonicalHeaders strings.Builder + for _, k := range signedHeaderKeys { + canonicalHeaders.WriteString(k) + canonicalHeaders.WriteString(":") + canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k])) + canonicalHeaders.WriteString("\n") + } + signedHeaders := strings.Join(signedHeaderKeys, ";") + + canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", + method, + u.Path, + canonicalQueryString, + canonicalHeaders.String(), + signedHeaders, + hexPayloadHash, + ) + + hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest)) + hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:]) + + region := "cn-north-1" + serviceName := "cv" + credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName) + stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s", + xDate, + credentialScope, + hexHashedCanonicalRequest, + ) + + kDate := hmacSHA256([]byte(secretKey), []byte(shortDate)) + kRegion := hmacSHA256(kDate, []byte(region)) + kService := hmacSHA256(kRegion, []byte(serviceName)) + kSigning := hmacSHA256(kService, []byte("request")) + signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign))) + + authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", + accessKey, + credentialScope, + signedHeaders, + signature, + ) + header.Set("Authorization", authorization) + return nil +} + +// hmacSHA256 计算 HMAC-SHA256 +func hmacSHA256(key []byte, data []byte) []byte { + h := hmac.New(sha256.New, key) + h.Write(data) + return h.Sum(nil) +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 00e59eac..2ce12a87 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -15,6 +15,7 @@ import ( "one-api/relay/channel/deepseek" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" + "one-api/relay/channel/jimeng" "one-api/relay/channel/jina" "one-api/relay/channel/mistral" "one-api/relay/channel/mokaai" @@ -23,7 +24,7 @@ import ( "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" - "one-api/relay/channel/task/jimeng" + taskjimeng "one-api/relay/channel/task/jimeng" "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" @@ -93,6 +94,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &xai.Adaptor{} case constant.APITypeCoze: return &coze.Adaptor{} + case constant.APITypeJimeng: + return &jimeng.Adaptor{} } return nil } @@ -106,7 +109,7 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { case commonconstant.TaskPlatformKling: return &kling.TaskAdaptor{} case commonconstant.TaskPlatformJimeng: - return &jimeng.TaskAdaptor{} + return &taskjimeng.TaskAdaptor{} } return nil }