From 61d2a2f92d00a27e12dba06d3dc4f1ac25b32323 Mon Sep 17 00:00:00 2001 From: Sh1n3zZ Date: Tue, 18 Feb 2025 01:39:13 +0800 Subject: [PATCH] feat: add Gemini Imagen image generation support --- relay/channel/gemini/adaptor.go | 103 +++++++++++++++++++++++++++++-- relay/channel/gemini/constant.go | 2 + relay/channel/gemini/dto.go | 27 ++++++++ 3 files changed, 128 insertions(+), 4 deletions(-) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 681e9988..32513c42 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -1,15 +1,21 @@ package gemini import ( + "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" + "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/service" + + "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + if !strings.HasPrefix(info.UpstreamModelName, "imagen") { + return nil, errors.New("not supported model for image generation") + } + + // convert size to aspect ratio + aspectRatio := "1:1" // default aspect ratio + switch request.Size { + case "1024x1024": + aspectRatio = "1:1" + case "1024x1792": + aspectRatio = "9:16" + case "1792x1024": + aspectRatio = "16:9" + } + + // build gemini imagen request + geminiRequest := GeminiImageRequest{ + Instances: []GeminiImageInstance{ + { + Prompt: request.Prompt, + }, + }, + Parameters: GeminiImageParameters{ + SampleCount: request.N, + AspectRatio: aspectRatio, + PersonGeneration: "allow_adult", // default allow adult + }, + } + + return geminiRequest, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil + } + action := "generateContent" if info.IsStream { action = "streamGenerateContent?alt=sse" @@ -73,12 +111,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + return GeminiImageHandler(c, resp, info) + } + if info.IsStream { err, usage = GeminiChatStreamHandler(c, resp, info) } else { @@ -87,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return } +func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + responseBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError) + } + _ = resp.Body.Close() + + var geminiResponse GeminiImageResponse + if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { + return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + if len(geminiResponse.Predictions) == 0 { + return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest) + } + + // convert to openai format response + openAIResponse := dto.ImageResponse{ + Created: common.GetTimestamp(), + Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), + } + + for _, prediction := range geminiResponse.Predictions { + if prediction.RaiFilteredReason != "" { + continue // skip filtered image + } + openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ + B64Json: prediction.BytesBase64Encoded, + }) + } + + jsonResponse, jsonErr := json.Marshal(openAIResponse) + if jsonErr != nil { + return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb + // each image has fixed 258 tokens + const imageTokens = 258 + generatedImages := len(openAIResponse.Data) + + usage = &dto.Usage{ + PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens + CompletionTokens: 0, // image generation does not calculate completion tokens + TotalTokens: imageTokens * generatedImages, + } + + return usage, nil +} + func (a *Adaptor) GetModelList() []string { return ModelList } diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go index 9651bd60..b7c1f0cf 100644 --- a/relay/channel/gemini/constant.go +++ b/relay/channel/gemini/constant.go @@ -16,6 +16,8 @@ var ModelList = []string{ "gemini-2.0-pro-exp", // thinking exp "gemini-2.0-flash-thinking-exp", + // imagen models + "imagen-3.0-generate-002", } var ChannelName = "google gemini" diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 08a5db84..bbcb1248 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -109,3 +109,30 @@ type GeminiUsageMetadata struct { CandidatesTokenCount int `json:"candidatesTokenCount"` TotalTokenCount int `json:"totalTokenCount"` } + +// Imagen related structs +type GeminiImageRequest struct { + Instances []GeminiImageInstance `json:"instances"` + Parameters GeminiImageParameters `json:"parameters"` +} + +type GeminiImageInstance struct { + Prompt string `json:"prompt"` +} + +type GeminiImageParameters struct { + SampleCount int `json:"sampleCount,omitempty"` + AspectRatio string `json:"aspectRatio,omitempty"` + PersonGeneration string `json:"personGeneration,omitempty"` +} + +type GeminiImageResponse struct { + Predictions []GeminiImagePrediction `json:"predictions"` +} + +type GeminiImagePrediction struct { + MimeType string `json:"mimeType"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` + RaiFilteredReason string `json:"raiFilteredReason,omitempty"` + SafetyAttributes any `json:"safetyAttributes,omitempty"` +}