Merge pull request #645 from MartialBE/gemini_res_format

支持gemini结构化输出
This commit is contained in:
Calcium-Ion
2024-12-21 16:40:51 +08:00
committed by GitHub
3 changed files with 96 additions and 36 deletions

View File

@@ -3,39 +3,47 @@ package dto
import "encoding/json"
type ResponseFormat struct {
Type string `json:"type,omitempty"`
Type string `json:"type,omitempty"`
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
}
type FormatJsonSchema struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema any `json:"schema,omitempty"`
Strict any `json:"strict,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
}
type OpenAITools struct {

View File

@@ -40,12 +40,14 @@ type GeminiChatTools struct {
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
}
type GeminiChatCandidate struct {

View File

@@ -77,6 +77,16 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
}
}
//shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
@@ -165,6 +175,46 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
return &geminiRequest, nil
}
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 {
return schema
}
v, ok := schema.(map[string]interface{})
if !ok || len(v) == 0 {
return schema
}
// 如果type不为object和array则直接返回
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
return schema
}
switch v["type"] {
case "object":
delete(v, "additionalProperties")
// 处理 properties
if properties, ok := v["properties"].(map[string]interface{}); ok {
for key, value := range properties {
properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
}
}
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := v[field].([]interface{}); ok {
for i, item := range nested {
nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
}
}
}
case "array":
if items, ok := v["items"].(map[string]interface{}); ok {
v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
}
}
return v
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""