From 30fb349d912a5c65fa2f68e949a3771741bf968b Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 5 Jul 2025 14:14:40 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(endpoint=20types):=20add=20sup?= =?UTF-8?q?port=20for=20image=20generation=20models=20in=20endpoint=20type?= =?UTF-8?q?=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/endpoint_type.go | 4 ++++ common/model.go | 23 ++++++++++++++++++++++- constant/endpoint_type.go | 11 ++++++----- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/common/endpoint_type.go b/common/endpoint_type.go index 578fe096..a0ca73ea 100644 --- a/common/endpoint_type.go +++ b/common/endpoint_type.go @@ -33,5 +33,9 @@ func GetEndpointTypesByChannelType(channelType int, modelName string) []constant endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI} } } + if IsImageGenerationModel(modelName) { + // add to first + endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...) + } return endpointTypes } diff --git a/common/model.go b/common/model.go index 6afb1540..14ca1911 100644 --- a/common/model.go +++ b/common/model.go @@ -9,11 +9,32 @@ var ( "o3-deep-research", "o4-mini-deep-research", } + ImageGenerationModels = []string{ + "dall-e-3", + "dall-e-2", + "gpt-image-1", + "prefix:imagen-", + "flux-", + "flux.1-", + } ) func IsOpenAIResponseOnlyModel(modelName string) bool { for _, m := range OpenAIResponseOnlyModels { - if strings.Contains(m, modelName) { + if strings.Contains(modelName, m) { + return true + } + } + return false +} + +func IsImageGenerationModel(modelName string) bool { + modelName = strings.ToLower(modelName) + for _, m := range ImageGenerationModels { + if strings.Contains(modelName, m) { + return true + } + if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) { return true } } diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go index c9aac94a..ef096b75 100644 --- a/constant/endpoint_type.go +++ b/constant/endpoint_type.go @@ -3,11 +3,12 @@ package constant type EndpointType string const ( - EndpointTypeOpenAI EndpointType = "openai" - EndpointTypeOpenAIResponse EndpointType = "openai-response" - EndpointTypeAnthropic EndpointType = "anthropic" - EndpointTypeGemini EndpointType = "gemini" - EndpointTypeJinaRerank EndpointType = "jina-rerank" + EndpointTypeOpenAI EndpointType = "openai" + EndpointTypeOpenAIResponse EndpointType = "openai-response" + EndpointTypeAnthropic EndpointType = "anthropic" + EndpointTypeGemini EndpointType = "gemini" + EndpointTypeJinaRerank EndpointType = "jina-rerank" + EndpointTypeImageGeneration EndpointType = "image-generation" //EndpointTypeMidjourney EndpointType = "midjourney-proxy" //EndpointTypeSuno EndpointType = "suno-proxy" //EndpointTypeKling EndpointType = "kling"