From 43a7b59b68a4137b0021689c4c1020dae62ec235 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Sat, 21 Dec 2024 16:01:17 +0800 Subject: [PATCH] feat: support for Gemini structured output. --- dto/openai_request.go | 68 ++++++++++++++++------------ relay/channel/gemini/dto.go | 14 +++--- relay/channel/gemini/relay-gemini.go | 50 ++++++++++++++++++++ 3 files changed, 96 insertions(+), 36 deletions(-) diff --git a/dto/openai_request.go b/dto/openai_request.go index da14cdfe..e85605da 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -3,39 +3,47 @@ package dto import "encoding/json" type ResponseFormat struct { - Type string `json:"type,omitempty"` + Type string `json:"type,omitempty"` + JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"` +} + +type FormatJsonSchema struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Schema any `json:"schema,omitempty"` + Strict any `json:"strict,omitempty"` } type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *StreamOptions `json:"stream_options,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Stop any `json:"stop,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat any `json:"response_format,omitempty"` - EncodingFormat any `json:"encoding_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools 
[]ToolCall `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` - LogProbs bool `json:"logprobs,omitempty"` - TopLogProbs int `json:"top_logprobs,omitempty"` - Dimensions int `json:"dimensions,omitempty"` - Modalities any `json:"modalities,omitempty"` - Audio any `json:"audio,omitempty"` + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Stop any `json:"stop,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + EncodingFormat any `json:"encoding_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + Tools []ToolCall `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` + LogProbs bool `json:"logprobs,omitempty"` + TopLogProbs int `json:"top_logprobs,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + Modalities any `json:"modalities,omitempty"` + Audio any `json:"audio,omitempty"` } type OpenAITools struct { diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index fa461a37..16027b41 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -40,12 +40,14 @@ type GeminiChatTools struct { } type GeminiChatGenerationConfig struct { - 
Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema any `json:"responseSchema,omitempty"` } type GeminiChatCandidate struct { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 7a0a9694..17786326 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -77,6 +77,16 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque }, } } + + if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") { + geminiRequest.GenerationConfig.ResponseMimeType = "application/json" + + if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil { + cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0) + geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema + } + } + //shouldAddDummyModelMessage := false for _, message := range textRequest.Messages { @@ -165,6 +175,46 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque return &geminiRequest, nil } +func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} { + if depth >= 5 { + return schema + } + + v, ok := schema.(map[string]interface{}) + if !ok || len(v) == 0 { + return 
schema
+	}
+
+	// If type is neither "object" nor "array", return the schema unchanged
+	if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
+		return schema
+	}
+
+	switch v["type"] {
+	case "object":
+		delete(v, "additionalProperties")
+		// Recurse into properties
+		if properties, ok := v["properties"].(map[string]interface{}); ok {
+			for key, value := range properties {
+				properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
+			}
+		}
+		for _, field := range []string{"allOf", "anyOf", "oneOf"} {
+			if nested, ok := v[field].([]interface{}); ok {
+				for i, item := range nested {
+					nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
+				}
+			}
+		}
+	case "array":
+		if items, ok := v["items"].(map[string]interface{}); ok {
+			v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
+		}
+	}
+
+	return v
+}
+
 func (g *GeminiChatResponse) GetResponseText() string {
 	if g == nil {
 		return ""