diff --git a/common/endpoint_type.go b/common/endpoint_type.go index 578fe096..a0ca73ea 100644 --- a/common/endpoint_type.go +++ b/common/endpoint_type.go @@ -33,5 +33,9 @@ func GetEndpointTypesByChannelType(channelType int, modelName string) []constant endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI} } } + if IsImageGenerationModel(modelName) { + // add to first + endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...) + } return endpointTypes } diff --git a/common/model.go b/common/model.go index 6afb1540..14ca1911 100644 --- a/common/model.go +++ b/common/model.go @@ -9,11 +9,32 @@ var ( "o3-deep-research", "o4-mini-deep-research", } + ImageGenerationModels = []string{ + "dall-e-3", + "dall-e-2", + "gpt-image-1", + "prefix:imagen-", + "flux-", + "flux.1-", + } ) func IsOpenAIResponseOnlyModel(modelName string) bool { for _, m := range OpenAIResponseOnlyModels { - if strings.Contains(m, modelName) { + if strings.Contains(modelName, m) { + return true + } + } + return false +} + +func IsImageGenerationModel(modelName string) bool { + modelName = strings.ToLower(modelName) + for _, m := range ImageGenerationModels { + if strings.Contains(modelName, m) { + return true + } + if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) { return true } } diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go index c9aac94a..ef096b75 100644 --- a/constant/endpoint_type.go +++ b/constant/endpoint_type.go @@ -3,11 +3,12 @@ package constant type EndpointType string const ( - EndpointTypeOpenAI EndpointType = "openai" - EndpointTypeOpenAIResponse EndpointType = "openai-response" - EndpointTypeAnthropic EndpointType = "anthropic" - EndpointTypeGemini EndpointType = "gemini" - EndpointTypeJinaRerank EndpointType = "jina-rerank" + EndpointTypeOpenAI EndpointType = "openai" + EndpointTypeOpenAIResponse EndpointType = "openai-response" + EndpointTypeAnthropic EndpointType = "anthropic" + EndpointTypeGemini EndpointType = "gemini" + EndpointTypeJinaRerank EndpointType = "jina-rerank" + EndpointTypeImageGeneration EndpointType = "image-generation" //EndpointTypeMidjourney EndpointType = "midjourney-proxy" //EndpointTypeSuno EndpointType = "suno-proxy" //EndpointTypeKling EndpointType = "kling"