feat: support /images/edit
(cherry picked from commit 1c0a1238787d490f02dd9269b616580a16604180)
This commit is contained in:
@@ -24,7 +24,7 @@ import (
|
|||||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
var err *dto.OpenAIErrorWithStatusCode
|
var err *dto.OpenAIErrorWithStatusCode
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeImagesGenerations:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
err = relay.ImageHelper(c)
|
err = relay.ImageHelper(c)
|
||||||
case relayconstant.RelayModeAudioSpeech:
|
case relayconstant.RelayModeAudioSpeech:
|
||||||
fallthrough
|
fallthrough
|
||||||
|
|||||||
@@ -166,12 +166,16 @@ type CompletionsStreamResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Usage struct {
|
type Usage struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
||||||
|
|
||||||
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
||||||
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type InputTokenDetails struct {
|
type InputTokenDetails struct {
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
}
|
}
|
||||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -184,6 +184,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
||||||
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||||
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "gpt-image-1")
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||||
relayMode := relayconstant.RelayModeAudioSpeech
|
relayMode := relayconstant.RelayModeAudioSpeech
|
||||||
|
|||||||
@@ -236,11 +236,64 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
return request, nil
|
switch info.RelayMode {
|
||||||
|
case constant.RelayModeImagesEdits:
|
||||||
|
body, err := common.GetRequestBody(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("get request body fail")
|
||||||
|
}
|
||||||
|
return bytes.NewReader(body), nil
|
||||||
|
|
||||||
|
/*var requestBody bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&requestBody)
|
||||||
|
|
||||||
|
writer.WriteField("model", request.Model)
|
||||||
|
// 获取所有表单字段
|
||||||
|
formData := c.Request.PostForm
|
||||||
|
// 遍历表单字段并打印输出
|
||||||
|
for key, values := range formData {
|
||||||
|
if key == "model" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, value := range values {
|
||||||
|
writer.WriteField(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加文件字段
|
||||||
|
imageFiles := c.Request.MultipartForm.File["image[]"]
|
||||||
|
for _, file := range imageFiles {
|
||||||
|
part, err := writer.CreateFormFile("image[]", file.Filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("create form file failed")
|
||||||
|
}
|
||||||
|
// 打开文件
|
||||||
|
src, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("open file failed")
|
||||||
|
}
|
||||||
|
// 将文件数据写入 form part
|
||||||
|
_, err = io.Copy(part, src)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("copy file failed")
|
||||||
|
}
|
||||||
|
src.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭 multipart 编写器以设置分界线
|
||||||
|
writer.Close()
|
||||||
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
return bytes.NewReader(requestBody.Bytes()), nil*/
|
||||||
|
|
||||||
|
default:
|
||||||
|
return request, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
if info.RelayMode == constant.RelayModeAudioTranscription ||
|
||||||
|
info.RelayMode == constant.RelayModeAudioTranslation ||
|
||||||
|
info.RelayMode == constant.RelayModeImagesEdits {
|
||||||
return channel.DoFormRequest(a, c, info, requestBody)
|
return channel.DoFormRequest(a, c, info, requestBody)
|
||||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||||
return channel.DoWssRequest(a, c, info, requestBody)
|
return channel.DoWssRequest(a, c, info, requestBody)
|
||||||
@@ -259,8 +312,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
fallthrough
|
fallthrough
|
||||||
case constant.RelayModeAudioTranscription:
|
case constant.RelayModeAudioTranscription:
|
||||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -595,3 +595,52 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
|||||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
// Reset response body
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||||
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||||
|
// So the httpClient will be confused by the response.
|
||||||
|
// For example, Postman will report error, and we cannot check the response at all.
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
// reset content length
|
||||||
|
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var usageResp dto.SimpleResponse
|
||||||
|
err = json.Unmarshal(responseBody, &usageResp)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
// format
|
||||||
|
if usageResp.InputTokens > 0 {
|
||||||
|
usageResp.PromptTokens += usageResp.InputTokens
|
||||||
|
}
|
||||||
|
if usageResp.OutputTokens > 0 {
|
||||||
|
usageResp.CompletionTokens += usageResp.OutputTokens
|
||||||
|
}
|
||||||
|
if usageResp.InputTokensDetails != nil {
|
||||||
|
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
||||||
|
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
||||||
|
}
|
||||||
|
return nil, &usageResp.Usage
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ const (
|
|||||||
RelayModeEmbeddings
|
RelayModeEmbeddings
|
||||||
RelayModeModerations
|
RelayModeModerations
|
||||||
RelayModeImagesGenerations
|
RelayModeImagesGenerations
|
||||||
|
RelayModeImagesEdits
|
||||||
RelayModeEdits
|
RelayModeEdits
|
||||||
|
|
||||||
RelayModeMidjourneyImagine
|
RelayModeMidjourneyImagine
|
||||||
@@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
|
|||||||
relayMode = RelayModeModerations
|
relayMode = RelayModeModerations
|
||||||
} else if strings.HasPrefix(path, "/v1/images/generations") {
|
} else if strings.HasPrefix(path, "/v1/images/generations") {
|
||||||
relayMode = RelayModeImagesGenerations
|
relayMode = RelayModeImagesGenerations
|
||||||
|
} else if strings.HasPrefix(path, "/v1/images/edits") {
|
||||||
|
relayMode = RelayModeImagesEdits
|
||||||
} else if strings.HasPrefix(path, "/v1/edits") {
|
} else if strings.HasPrefix(path, "/v1/edits") {
|
||||||
relayMode = RelayModeEdits
|
relayMode = RelayModeEdits
|
||||||
} else if strings.HasPrefix(path, "/v1/audio/speech") {
|
} else if strings.HasPrefix(path, "/v1/audio/speech") {
|
||||||
|
|||||||
@@ -15,14 +15,15 @@ type PriceData struct {
|
|||||||
ModelRatio float64
|
ModelRatio float64
|
||||||
CompletionRatio float64
|
CompletionRatio float64
|
||||||
CacheRatio float64
|
CacheRatio float64
|
||||||
|
CacheCreationRatio float64
|
||||||
|
ImageRatio float64
|
||||||
GroupRatio float64
|
GroupRatio float64
|
||||||
UsePrice bool
|
UsePrice bool
|
||||||
CacheCreationRatio float64
|
|
||||||
ShouldPreConsumedQuota int
|
ShouldPreConsumedQuota int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p PriceData) ToSetting() string {
|
func (p PriceData) ToSetting() string {
|
||||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota)
|
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||||
@@ -32,6 +33,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|||||||
var modelRatio float64
|
var modelRatio float64
|
||||||
var completionRatio float64
|
var completionRatio float64
|
||||||
var cacheRatio float64
|
var cacheRatio float64
|
||||||
|
var imageRatio float64
|
||||||
var cacheCreationRatio float64
|
var cacheCreationRatio float64
|
||||||
if !usePrice {
|
if !usePrice {
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
@@ -55,6 +57,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|||||||
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
|
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
|
||||||
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
||||||
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||||
|
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
|
||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||||
} else {
|
} else {
|
||||||
@@ -68,6 +71,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|||||||
GroupRatio: groupRatio,
|
GroupRatio: groupRatio,
|
||||||
UsePrice: usePrice,
|
UsePrice: usePrice,
|
||||||
CacheRatio: cacheRatio,
|
CacheRatio: cacheRatio,
|
||||||
|
ImageRatio: imageRatio,
|
||||||
CacheCreationRatio: cacheCreationRatio,
|
CacheCreationRatio: cacheCreationRatio,
|
||||||
ShouldPreConsumedQuota: preConsumedQuota,
|
ShouldPreConsumedQuota: preConsumedQuota,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
@@ -20,13 +21,56 @@ import (
|
|||||||
|
|
||||||
func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
|
func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
|
||||||
imageRequest := &dto.ImageRequest{}
|
imageRequest := &dto.ImageRequest{}
|
||||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
|
||||||
if err != nil {
|
switch info.RelayMode {
|
||||||
return nil, err
|
case relayconstant.RelayModeImagesEdits:
|
||||||
|
_, err := c.MultipartForm()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
formData := c.Request.PostForm
|
||||||
|
imageRequest.Prompt = formData.Get("prompt")
|
||||||
|
imageRequest.Model = formData.Get("model")
|
||||||
|
imageRequest.N = common.String2Int(formData.Get("n"))
|
||||||
|
imageRequest.Quality = formData.Get("quality")
|
||||||
|
imageRequest.Size = formData.Get("size")
|
||||||
|
|
||||||
|
if imageRequest.Model == "gpt-image-1" {
|
||||||
|
if imageRequest.Quality == "" {
|
||||||
|
imageRequest.Quality = "standard"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Not "256x256", "512x512", or "1024x1024"
|
||||||
|
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
|
||||||
|
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
||||||
|
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
|
||||||
|
}
|
||||||
|
} else if imageRequest.Model == "dall-e-3" {
|
||||||
|
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
|
||||||
|
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
|
||||||
|
}
|
||||||
|
if imageRequest.Quality == "" {
|
||||||
|
imageRequest.Quality = "standard"
|
||||||
|
}
|
||||||
|
// N should between 1 and 10
|
||||||
|
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
||||||
|
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
||||||
|
//}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if imageRequest.Prompt == "" {
|
if imageRequest.Prompt == "" {
|
||||||
return nil, errors.New("prompt is required")
|
return nil, errors.New("prompt is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if imageRequest.Model == "" {
|
||||||
|
imageRequest.Model = "dall-e-2"
|
||||||
|
}
|
||||||
if strings.Contains(imageRequest.Size, "×") {
|
if strings.Contains(imageRequest.Size, "×") {
|
||||||
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
|
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
|
||||||
}
|
}
|
||||||
@@ -36,30 +80,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
|||||||
if imageRequest.Size == "" {
|
if imageRequest.Size == "" {
|
||||||
imageRequest.Size = "1024x1024"
|
imageRequest.Size = "1024x1024"
|
||||||
}
|
}
|
||||||
if imageRequest.Model == "" {
|
|
||||||
imageRequest.Model = "dall-e-2"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not "256x256", "512x512", or "1024x1024"
|
|
||||||
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
|
|
||||||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
|
||||||
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
|
|
||||||
}
|
|
||||||
} else if imageRequest.Model == "dall-e-3" {
|
|
||||||
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
|
|
||||||
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
|
|
||||||
}
|
|
||||||
if imageRequest.Quality == "" {
|
|
||||||
imageRequest.Quality = "standard"
|
|
||||||
}
|
|
||||||
//if imageRequest.N != 1 {
|
|
||||||
// return nil, errors.New("n must be 1")
|
|
||||||
//}
|
|
||||||
}
|
|
||||||
// N should between 1 and 10
|
|
||||||
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
|
||||||
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
|
||||||
//}
|
|
||||||
if setting.ShouldCheckPromptSensitive() {
|
if setting.ShouldCheckPromptSensitive() {
|
||||||
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
|
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -86,43 +107,59 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
|
|
||||||
imageRequest.Model = relayInfo.UpstreamModelName
|
imageRequest.Model = relayInfo.UpstreamModelName
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
var preConsumedQuota int
|
||||||
|
var quota int
|
||||||
|
var userQuota int
|
||||||
if !priceData.UsePrice {
|
if !priceData.UsePrice {
|
||||||
// modelRatio 16 = modelPrice $0.04
|
// modelRatio 16 = modelPrice $0.04
|
||||||
// per 1 modelRatio = $0.04 / 16
|
// per 1 modelRatio = $0.04 / 16
|
||||||
priceData.ModelPrice = 0.0025 * priceData.ModelRatio
|
// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
|
||||||
}
|
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||||
|
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
if openaiErr != nil {
|
||||||
|
return openaiErr
|
||||||
sizeRatio := 1.0
|
|
||||||
// Size
|
|
||||||
if imageRequest.Size == "256x256" {
|
|
||||||
sizeRatio = 0.4
|
|
||||||
} else if imageRequest.Size == "512x512" {
|
|
||||||
sizeRatio = 0.45
|
|
||||||
} else if imageRequest.Size == "1024x1024" {
|
|
||||||
sizeRatio = 1
|
|
||||||
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
|
|
||||||
sizeRatio = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
qualityRatio := 1.0
|
|
||||||
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
|
|
||||||
qualityRatio = 2.0
|
|
||||||
if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
|
|
||||||
qualityRatio = 1.5
|
|
||||||
}
|
}
|
||||||
}
|
defer func() {
|
||||||
|
if openaiErr != nil {
|
||||||
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
} else {
|
||||||
quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
|
sizeRatio := 1.0
|
||||||
|
// Size
|
||||||
|
if imageRequest.Size == "256x256" {
|
||||||
|
sizeRatio = 0.4
|
||||||
|
} else if imageRequest.Size == "512x512" {
|
||||||
|
sizeRatio = 0.45
|
||||||
|
} else if imageRequest.Size == "1024x1024" {
|
||||||
|
sizeRatio = 1
|
||||||
|
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
|
||||||
|
sizeRatio = 2
|
||||||
|
}
|
||||||
|
|
||||||
if userQuota-quota < 0 {
|
qualityRatio := 1.0
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
|
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
|
||||||
|
qualityRatio = 2.0
|
||||||
|
if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
|
||||||
|
qualityRatio = 1.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset model price
|
||||||
|
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||||
|
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
|
||||||
|
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if userQuota-quota < 0 {
|
||||||
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
@@ -137,12 +174,15 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
requestBody = convertedRequest.(io.Reader)
|
||||||
if err != nil {
|
} else {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
|
||||||
@@ -162,24 +202,25 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
|
|
||||||
usage := &dto.Usage{
|
if usage.(*dto.Usage).TotalTokens == 0 {
|
||||||
PromptTokens: imageRequest.N,
|
usage.(*dto.Usage).TotalTokens = imageRequest.N
|
||||||
TotalTokens: imageRequest.N,
|
}
|
||||||
|
if usage.(*dto.Usage).PromptTokens == 0 {
|
||||||
|
usage.(*dto.Usage).PromptTokens = imageRequest.N
|
||||||
}
|
}
|
||||||
|
|
||||||
quality := "standard"
|
quality := "standard"
|
||||||
if imageRequest.Quality == "hd" {
|
if imageRequest.Quality == "hd" {
|
||||||
quality = "hd"
|
quality = "hd"
|
||||||
}
|
}
|
||||||
|
|
||||||
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
||||||
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
|
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -331,12 +331,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|||||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||||
promptTokens := usage.PromptTokens
|
promptTokens := usage.PromptTokens
|
||||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||||
|
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||||||
completionTokens := usage.CompletionTokens
|
completionTokens := usage.CompletionTokens
|
||||||
modelName := relayInfo.OriginModelName
|
modelName := relayInfo.OriginModelName
|
||||||
|
|
||||||
tokenName := ctx.GetString("token_name")
|
tokenName := ctx.GetString("token_name")
|
||||||
completionRatio := priceData.CompletionRatio
|
completionRatio := priceData.CompletionRatio
|
||||||
cacheRatio := priceData.CacheRatio
|
cacheRatio := priceData.CacheRatio
|
||||||
|
imageRatio := priceData.ImageRatio
|
||||||
modelRatio := priceData.ModelRatio
|
modelRatio := priceData.ModelRatio
|
||||||
groupRatio := priceData.GroupRatio
|
groupRatio := priceData.GroupRatio
|
||||||
modelPrice := priceData.ModelPrice
|
modelPrice := priceData.ModelPrice
|
||||||
@@ -344,9 +346,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|||||||
// Convert values to decimal for precise calculation
|
// Convert values to decimal for precise calculation
|
||||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||||
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
|
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
|
||||||
|
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||||||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||||||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||||||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||||||
|
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||||||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||||||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||||||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||||||
@@ -358,7 +362,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|||||||
if !priceData.UsePrice {
|
if !priceData.UsePrice {
|
||||||
nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
|
nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
|
||||||
cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)
|
cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)
|
||||||
|
|
||||||
promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
|
promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
|
||||||
|
if imageTokens > 0 {
|
||||||
|
nonImageTokens := dPromptTokens.Sub(dImageTokens)
|
||||||
|
imageTokensWithRatio := dImageTokens.Mul(dImageRatio)
|
||||||
|
promptQuota = nonImageTokens.Add(imageTokensWithRatio)
|
||||||
|
}
|
||||||
|
|
||||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||||
|
|
||||||
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
|
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
httpRouter.POST("/chat/completions", controller.Relay)
|
httpRouter.POST("/chat/completions", controller.Relay)
|
||||||
httpRouter.POST("/edits", controller.Relay)
|
httpRouter.POST("/edits", controller.Relay)
|
||||||
httpRouter.POST("/images/generations", controller.Relay)
|
httpRouter.POST("/images/generations", controller.Relay)
|
||||||
httpRouter.POST("/images/edits", controller.RelayNotImplemented)
|
httpRouter.POST("/images/edits", controller.Relay)
|
||||||
httpRouter.POST("/images/variations", controller.RelayNotImplemented)
|
httpRouter.POST("/images/variations", controller.RelayNotImplemented)
|
||||||
httpRouter.POST("/embeddings", controller.Relay)
|
httpRouter.POST("/embeddings", controller.Relay)
|
||||||
httpRouter.POST("/engines/:model/embeddings", controller.Relay)
|
httpRouter.POST("/engines/:model/embeddings", controller.Relay)
|
||||||
|
|||||||
@@ -51,26 +51,27 @@ var defaultModelRatio = map[string]float64{
|
|||||||
"gpt-4o-realtime-preview-2024-12-17": 2.5,
|
"gpt-4o-realtime-preview-2024-12-17": 2.5,
|
||||||
"gpt-4o-mini-realtime-preview": 0.3,
|
"gpt-4o-mini-realtime-preview": 0.3,
|
||||||
"gpt-4o-mini-realtime-preview-2024-12-17": 0.3,
|
"gpt-4o-mini-realtime-preview-2024-12-17": 0.3,
|
||||||
"o1": 7.5,
|
"gpt-image-1": 2.5,
|
||||||
"o1-2024-12-17": 7.5,
|
"o1": 7.5,
|
||||||
"o1-preview": 7.5,
|
"o1-2024-12-17": 7.5,
|
||||||
"o1-preview-2024-09-12": 7.5,
|
"o1-preview": 7.5,
|
||||||
"o1-mini": 0.55,
|
"o1-preview-2024-09-12": 7.5,
|
||||||
"o1-mini-2024-09-12": 0.55,
|
"o1-mini": 0.55,
|
||||||
"o3-mini": 0.55,
|
"o1-mini-2024-09-12": 0.55,
|
||||||
"o3-mini-2025-01-31": 0.55,
|
"o3-mini": 0.55,
|
||||||
"o3-mini-high": 0.55,
|
"o3-mini-2025-01-31": 0.55,
|
||||||
"o3-mini-2025-01-31-high": 0.55,
|
"o3-mini-high": 0.55,
|
||||||
"o3-mini-low": 0.55,
|
"o3-mini-2025-01-31-high": 0.55,
|
||||||
"o3-mini-2025-01-31-low": 0.55,
|
"o3-mini-low": 0.55,
|
||||||
"o3-mini-medium": 0.55,
|
"o3-mini-2025-01-31-low": 0.55,
|
||||||
"o3-mini-2025-01-31-medium": 0.55,
|
"o3-mini-medium": 0.55,
|
||||||
"gpt-4o-mini": 0.075,
|
"o3-mini-2025-01-31-medium": 0.55,
|
||||||
"gpt-4o-mini-2024-07-18": 0.075,
|
"gpt-4o-mini": 0.075,
|
||||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
"gpt-4o-mini-2024-07-18": 0.075,
|
||||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||||
"gpt-4.5-preview": 37.5,
|
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||||
"gpt-4.5-preview-2025-02-27": 37.5,
|
"gpt-4.5-preview": 37.5,
|
||||||
|
"gpt-4.5-preview-2025-02-27": 37.5,
|
||||||
//"gpt-3.5-turbo-0301": 0.75, //deprecated
|
//"gpt-3.5-turbo-0301": 0.75, //deprecated
|
||||||
"gpt-3.5-turbo": 0.25,
|
"gpt-3.5-turbo": 0.25,
|
||||||
"gpt-3.5-turbo-0613": 0.75,
|
"gpt-3.5-turbo-0613": 0.75,
|
||||||
@@ -255,6 +256,7 @@ var defaultCompletionRatio = map[string]float64{
|
|||||||
"gpt-4-gizmo-*": 2,
|
"gpt-4-gizmo-*": 2,
|
||||||
"gpt-4o-gizmo-*": 3,
|
"gpt-4o-gizmo-*": 3,
|
||||||
"gpt-4-all": 2,
|
"gpt-4-all": 2,
|
||||||
|
"gpt-image-1": 8,
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitModelSettings initializes all model related settings maps
|
// InitModelSettings initializes all model related settings maps
|
||||||
@@ -275,9 +277,10 @@ func InitModelSettings() {
|
|||||||
CompletionRatioMutex.Unlock()
|
CompletionRatioMutex.Unlock()
|
||||||
|
|
||||||
// Initialize cacheRatioMap
|
// Initialize cacheRatioMap
|
||||||
cacheRatioMapMutex.Lock()
|
imageRatioMapMutex.Lock()
|
||||||
cacheRatioMap = defaultCacheRatio
|
imageRatioMap = defaultImageRatio
|
||||||
cacheRatioMapMutex.Unlock()
|
imageRatioMapMutex.Unlock()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetModelPriceMap() map[string]float64 {
|
func GetModelPriceMap() map[string]float64 {
|
||||||
@@ -548,3 +551,36 @@ func ModelRatio2JSONString() string {
|
|||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultImageRatio = map[string]float64{
|
||||||
|
"gpt-image-1": 2,
|
||||||
|
}
|
||||||
|
var imageRatioMap map[string]float64
|
||||||
|
var imageRatioMapMutex sync.RWMutex
|
||||||
|
|
||||||
|
func ImageRatio2JSONString() string {
|
||||||
|
imageRatioMapMutex.RLock()
|
||||||
|
defer imageRatioMapMutex.RUnlock()
|
||||||
|
jsonBytes, err := json.Marshal(imageRatioMap)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling cache ratio: " + err.Error())
|
||||||
|
}
|
||||||
|
return string(jsonBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateImageRatioByJSONString(jsonStr string) error {
|
||||||
|
imageRatioMapMutex.Lock()
|
||||||
|
defer imageRatioMapMutex.Unlock()
|
||||||
|
imageRatioMap = make(map[string]float64)
|
||||||
|
return json.Unmarshal([]byte(jsonStr), &imageRatioMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetImageRatio(name string) (float64, bool) {
|
||||||
|
imageRatioMapMutex.RLock()
|
||||||
|
defer imageRatioMapMutex.RUnlock()
|
||||||
|
ratio, ok := imageRatioMap[name]
|
||||||
|
if !ok {
|
||||||
|
return 1, false // Default to 1 if not found
|
||||||
|
}
|
||||||
|
return ratio, true
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user