From f9f32a0158e5f4a0648498adf95f069ba7c9b21f Mon Sep 17 00:00:00 2001 From: xyfacai Date: Thu, 24 Apr 2025 19:25:08 +0800 Subject: [PATCH] feat: support /images/edit (cherry picked from commit 1c0a1238787d490f02dd9269b616580a16604180) --- controller/relay.go | 2 +- dto/openai_response.go | 12 +- middleware/distributor.go | 4 +- relay/channel/openai/adaptor.go | 61 +++++++- relay/channel/openai/relay-openai.go | 49 +++++++ relay/constant/relay_mode.go | 3 + relay/helper/price.go | 8 +- relay/relay-image.go | 171 ++++++++++++++--------- relay/relay-text.go | 11 ++ router/relay-router.go | 2 +- setting/operation_setting/model-ratio.go | 82 ++++++++--- 11 files changed, 304 insertions(+), 101 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index fb4c524f..e7b20d50 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,7 +24,7 @@ import ( func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { var err *dto.OpenAIErrorWithStatusCode switch relayMode { - case relayconstant.RelayModeImagesGenerations: + case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: err = relay.ImageHelper(c) case relayconstant.RelayModeAudioSpeech: fallthrough diff --git a/dto/openai_response.go b/dto/openai_response.go index ddd1a907..c2100ec8 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -166,12 +166,16 @@ type CompletionsStreamResponse struct { } type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"` + PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"` CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokensDetails *InputTokenDetails `json:"input_tokens_details"` } type InputTokenDetails struct { diff --git a/middleware/distributor.go b/middleware/distributor.go index fc9f5512..51fd8fd1 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -162,7 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) - } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { @@ -184,6 +184,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "gpt-image-1") } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { relayMode := relayconstant.RelayModeAudioSpeech diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index ef11b4fe..6614d116 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -236,11 +236,64 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - return request, nil + switch info.RelayMode { + case constant.RelayModeImagesEdits: + body, err := common.GetRequestBody(c) + if err != nil { + return nil, errors.New("get request body fail") + } + return bytes.NewReader(body), nil + + /*var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + writer.WriteField("model", request.Model) + // 获取所有表单字段 + formData := c.Request.PostForm + // 遍历表单字段并打印输出 + for key, values := range formData { + if key == "model" { + continue + } + for _, value := range values { + writer.WriteField(key, value) + } + } + + // 添加文件字段 + imageFiles := c.Request.MultipartForm.File["image[]"] + for _, file := range imageFiles { + part, err := writer.CreateFormFile("image[]", file.Filename) + if err != nil { + return nil, errors.New("create form file failed") + } + // 打开文件 + src, err := file.Open() + if err != nil { + return nil, errors.New("open file failed") + } + // 将文件数据写入 form part + _, err = io.Copy(part, src) + if err != nil { + return nil, errors.New("copy file failed") + } + src.Close() + } + + // 关闭 multipart 编写器以设置分界线 + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return bytes.NewReader(requestBody.Bytes()), nil*/ + + default: + return request, nil + } } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { + if info.RelayMode == constant.RelayModeAudioTranscription || + info.RelayMode == constant.RelayModeAudioTranslation || + info.RelayMode == constant.RelayModeImagesEdits { return channel.DoFormRequest(a, c, info, requestBody) } else if info.RelayMode == constant.RelayModeRealtime { return channel.DoWssRequest(a, c, info, requestBody) @@ -259,8 +312,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom fallthrough case constant.RelayModeAudioTranscription: err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) - case constant.RelayModeImagesGenerations: - err, usage = OpenaiTTSHandler(c, resp, info) + case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: + err, usage = OpenaiHandlerWithUsage(c, resp, info) case constant.RelayModeRerank: err, usage = common_handler.RerankHandler(c, info, resp) default: diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 7e06ea12..b9ed94e2 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -595,3 +595,52 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R err := service.PreWssConsumeQuota(ctx, info, usage) return err } + +func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + // reset content length + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody))) + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + var usageResp dto.SimpleResponse + err = json.Unmarshal(responseBody, &usageResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil + } + // format + if usageResp.InputTokens > 0 { + usageResp.PromptTokens += usageResp.InputTokens + } + if usageResp.OutputTokens > 0 { + usageResp.CompletionTokens += usageResp.OutputTokens + } + if usageResp.InputTokensDetails != nil { + usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens + usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens + } + return nil, &usageResp.Usage +} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 845166c3..e2d51098 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -12,6 +12,7 @@ const ( RelayModeEmbeddings RelayModeModerations RelayModeImagesGenerations + RelayModeImagesEdits RelayModeEdits RelayModeMidjourneyImagine @@ -56,6 +57,8 @@ func Path2RelayMode(path string) int { relayMode = RelayModeModerations } else if strings.HasPrefix(path, "/v1/images/generations") { relayMode = RelayModeImagesGenerations + } else if strings.HasPrefix(path, "/v1/images/edits") { + relayMode = RelayModeImagesEdits } else if strings.HasPrefix(path, "/v1/edits") { relayMode = RelayModeEdits } else if strings.HasPrefix(path, "/v1/audio/speech") { diff --git a/relay/helper/price.go b/relay/helper/price.go index a68cd54d..899c72b9 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -15,14 +15,15 @@ type PriceData struct { ModelRatio float64 CompletionRatio float64 CacheRatio float64 + CacheCreationRatio float64 + ImageRatio float64 GroupRatio float64 UsePrice bool - CacheCreationRatio float64 ShouldPreConsumedQuota int } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { @@ -32,6 +33,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var modelRatio float64 var completionRatio float64 var cacheRatio float64 + var imageRatio float64 var cacheCreationRatio float64 if !usePrice { preConsumedTokens := common.PreConsumedQuota @@ -55,6 +57,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName) cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName) cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName) + imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName) ratio := modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -68,6 +71,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens GroupRatio: groupRatio, UsePrice: usePrice, CacheRatio: cacheRatio, + ImageRatio: imageRatio, CacheCreationRatio: cacheCreationRatio, ShouldPreConsumedQuota: preConsumedQuota, } diff --git a/relay/relay-image.go b/relay/relay-image.go index f9f542a7..15763298 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -12,6 +12,7 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "one-api/setting" @@ -20,13 +21,56 @@ import ( func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { imageRequest := &dto.ImageRequest{} - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err + + switch info.RelayMode { + case relayconstant.RelayModeImagesEdits: + _, err := c.MultipartForm() + if err != nil { + return nil, err + } + formData := c.Request.PostForm + imageRequest.Prompt = formData.Get("prompt") + imageRequest.Model = formData.Get("model") + imageRequest.N = common.String2Int(formData.Get("n")) + imageRequest.Quality = formData.Get("quality") + imageRequest.Size = formData.Get("size") + + if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + } + default: + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + // Not "256x256", "512x512", or "1024x1024" + if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") + } + } else if imageRequest.Model == "dall-e-3" { + if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") + } + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + // N should between 1 and 10 + //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { + // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + //} + } } + if imageRequest.Prompt == "" { return nil, errors.New("prompt is required") } + + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } if strings.Contains(imageRequest.Size, "×") { return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") } @@ -36,30 +80,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. if imageRequest.Size == "" { imageRequest.Size = "1024x1024" } - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") - } - } else if imageRequest.Model == "dall-e-3" { - if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") - } - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - //if imageRequest.N != 1 { - // return nil, errors.New("n must be 1") - //} - } - // N should between 1 and 10 - //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) - //} if setting.ShouldCheckPromptSensitive() { words, err := service.CheckSensitiveInput(imageRequest.Prompt) if err != nil { @@ -86,43 +107,59 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { imageRequest.Model = relayInfo.UpstreamModelName - priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) + priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) } + var preConsumedQuota int + var quota int + var userQuota int if !priceData.UsePrice { // modelRatio 16 = modelPrice $0.04 // per 1 modelRatio = $0.04 / 16 - priceData.ModelPrice = 0.0025 * priceData.ModelRatio - } - - userQuota, err := model.GetUserQuota(relayInfo.UserId, false) - - sizeRatio := 1.0 - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 0.4 - } else if imageRequest.Size == "512x512" { - sizeRatio = 0.45 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1 - } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - sizeRatio = 2 - } - - qualityRatio := 1.0 - if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" { - qualityRatio = 2.0 - if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - qualityRatio = 1.5 + // priceData.ModelPrice = 0.0025 * priceData.ModelRatio + var openaiErr *dto.OpenAIErrorWithStatusCode + preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr } - } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() - priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) - quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit) + } else { + sizeRatio := 1.0 + // Size + if imageRequest.Size == "256x256" { + sizeRatio = 0.4 + } else if imageRequest.Size == "512x512" { + sizeRatio = 0.45 + } else if imageRequest.Size == "1024x1024" { + sizeRatio = 1 + } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { + sizeRatio = 2 + } - if userQuota-quota < 0 { - return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden) + qualityRatio := 1.0 + if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" { + qualityRatio = 2.0 + if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { + qualityRatio = 1.5 + } + } + + // reset model price + priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) + quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit) + userQuota, err = model.GetUserQuota(relayInfo.UserId, false) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) + } + if userQuota-quota < 0 { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden) + } } adaptor := GetAdaptor(relayInfo.ApiType) @@ -137,12 +174,15 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if err != nil { return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } - - jsonData, err := json.Marshal(convertedRequest) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits { + requestBody = convertedRequest.(io.Reader) + } else { + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonData) } - requestBody = bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") @@ -162,24 +202,25 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { } } - _, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - usage := &dto.Usage{ - PromptTokens: imageRequest.N, - TotalTokens: imageRequest.N, + if usage.(*dto.Usage).TotalTokens == 0 { + usage.(*dto.Usage).TotalTokens = imageRequest.N + } + if usage.(*dto.Usage).PromptTokens == 0 { + usage.(*dto.Usage).PromptTokens = imageRequest.N } - quality := "standard" if imageRequest.Quality == "hd" { quality = "hd" } logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) - postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent) + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent) return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index 7b2b7fc0..d5625409 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -331,12 +331,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens cacheTokens := usage.PromptTokensDetails.CachedTokens + imageTokens := usage.PromptTokensDetails.ImageTokens completionTokens := usage.CompletionTokens modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") completionRatio := priceData.CompletionRatio cacheRatio := priceData.CacheRatio + imageRatio := priceData.ImageRatio modelRatio := priceData.ModelRatio groupRatio := priceData.GroupRatio modelPrice := priceData.ModelPrice @@ -344,9 +346,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) dCacheTokens := decimal.NewFromInt(int64(cacheTokens)) + dImageTokens := decimal.NewFromInt(int64(imageTokens)) dCompletionTokens := decimal.NewFromInt(int64(completionTokens)) dCompletionRatio := decimal.NewFromFloat(completionRatio) dCacheRatio := decimal.NewFromFloat(cacheRatio) + dImageRatio := decimal.NewFromFloat(imageRatio) dModelRatio := decimal.NewFromFloat(modelRatio) dGroupRatio := decimal.NewFromFloat(groupRatio) dModelPrice := decimal.NewFromFloat(modelPrice) @@ -358,7 +362,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if !priceData.UsePrice { nonCachedTokens := dPromptTokens.Sub(dCacheTokens) cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio) + promptQuota := nonCachedTokens.Add(cachedTokensWithRatio) + if imageTokens > 0 { + nonImageTokens := dPromptTokens.Sub(dImageTokens) + imageTokensWithRatio := dImageTokens.Mul(dImageRatio) + promptQuota = nonImageTokens.Add(imageTokensWithRatio) + } + completionQuota := dCompletionTokens.Mul(dCompletionRatio) quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio) diff --git a/router/relay-router.go b/router/relay-router.go index 3a9122d4..85000beb 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -40,7 +40,7 @@ func SetRelayRouter(router *gin.Engine) { httpRouter.POST("/chat/completions", controller.Relay) httpRouter.POST("/edits", controller.Relay) httpRouter.POST("/images/generations", controller.Relay) - httpRouter.POST("/images/edits", controller.RelayNotImplemented) + httpRouter.POST("/images/edits", controller.Relay) httpRouter.POST("/images/variations", controller.RelayNotImplemented) httpRouter.POST("/embeddings", controller.Relay) httpRouter.POST("/engines/:model/embeddings", controller.Relay) diff --git a/setting/operation_setting/model-ratio.go b/setting/operation_setting/model-ratio.go index 6a80ef1a..2b57c4d0 100644 --- a/setting/operation_setting/model-ratio.go +++ b/setting/operation_setting/model-ratio.go @@ -51,26 +51,27 @@ var defaultModelRatio = map[string]float64{ "gpt-4o-realtime-preview-2024-12-17": 2.5, "gpt-4o-mini-realtime-preview": 0.3, "gpt-4o-mini-realtime-preview-2024-12-17": 0.3, - "o1": 7.5, - "o1-2024-12-17": 7.5, - "o1-preview": 7.5, - "o1-preview-2024-09-12": 7.5, - "o1-mini": 0.55, - "o1-mini-2024-09-12": 0.55, - "o3-mini": 0.55, - "o3-mini-2025-01-31": 0.55, - "o3-mini-high": 0.55, - "o3-mini-2025-01-31-high": 0.55, - "o3-mini-low": 0.55, - "o3-mini-2025-01-31-low": 0.55, - "o3-mini-medium": 0.55, - "o3-mini-2025-01-31-medium": 0.55, - "gpt-4o-mini": 0.075, - "gpt-4o-mini-2024-07-18": 0.075, - "gpt-4-turbo": 5, // $0.01 / 1K tokens - "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens - "gpt-4.5-preview": 37.5, - "gpt-4.5-preview-2025-02-27": 37.5, + "gpt-image-1": 2.5, + "o1": 7.5, + "o1-2024-12-17": 7.5, + "o1-preview": 7.5, + "o1-preview-2024-09-12": 7.5, + "o1-mini": 0.55, + "o1-mini-2024-09-12": 0.55, + "o3-mini": 0.55, + "o3-mini-2025-01-31": 0.55, + "o3-mini-high": 0.55, + "o3-mini-2025-01-31-high": 0.55, + "o3-mini-low": 0.55, + "o3-mini-2025-01-31-low": 0.55, + "o3-mini-medium": 0.55, + "o3-mini-2025-01-31-medium": 0.55, + "gpt-4o-mini": 0.075, + "gpt-4o-mini-2024-07-18": 0.075, + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens + "gpt-4.5-preview": 37.5, + "gpt-4.5-preview-2025-02-27": 37.5, //"gpt-3.5-turbo-0301": 0.75, //deprecated "gpt-3.5-turbo": 0.25, "gpt-3.5-turbo-0613": 0.75, @@ -255,6 +256,7 @@ var defaultCompletionRatio = map[string]float64{ "gpt-4-gizmo-*": 2, "gpt-4o-gizmo-*": 3, "gpt-4-all": 2, + "gpt-image-1": 8, } // InitModelSettings initializes all model related settings maps @@ -275,9 +277,10 @@ func InitModelSettings() { CompletionRatioMutex.Unlock() // Initialize cacheRatioMap - cacheRatioMapMutex.Lock() - cacheRatioMap = defaultCacheRatio - cacheRatioMapMutex.Unlock() + imageRatioMapMutex.Lock() + imageRatioMap = defaultImageRatio + imageRatioMapMutex.Unlock() + } func GetModelPriceMap() map[string]float64 { @@ -548,3 +551,36 @@ func ModelRatio2JSONString() string { } return string(jsonBytes) } + +var defaultImageRatio = map[string]float64{ + "gpt-image-1": 2, +} +var imageRatioMap map[string]float64 +var imageRatioMapMutex sync.RWMutex + +func ImageRatio2JSONString() string { + imageRatioMapMutex.RLock() + defer imageRatioMapMutex.RUnlock() + jsonBytes, err := json.Marshal(imageRatioMap) + if err != nil { + common.SysError("error marshalling cache ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateImageRatioByJSONString(jsonStr string) error { + imageRatioMapMutex.Lock() + defer imageRatioMapMutex.Unlock() + imageRatioMap = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &imageRatioMap) +} + +func GetImageRatio(name string) (float64, bool) { + imageRatioMapMutex.RLock() + defer imageRatioMapMutex.RUnlock() + ratio, ok := imageRatioMap[name] + if !ok { + return 1, false // Default to 1 if not found + } + return ratio, true +}