Merge pull request #1484 from QuantumNous/ConvertGeminiRequest
feat: Convert gemini request
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
package gemini
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -56,7 +56,7 @@ type FunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
@@ -81,7 +81,7 @@ type GeminiPart struct {
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
@@ -26,6 +26,7 @@ type Adaptor interface {
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -22,6 +22,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -43,15 +48,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 0 || keyParts[0] == "" {
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+keyParts[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ConvertAudioRequest implements channel.Adaptor.
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
BotType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -20,6 +20,10 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
|
||||
@@ -51,13 +55,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
// build gemini imagen request
|
||||
geminiRequest := GeminiImageRequest{
|
||||
Instances: []GeminiImageInstance{
|
||||
geminiRequest := dto.GeminiImageRequest{
|
||||
Instances: []dto.GeminiImageInstance{
|
||||
{
|
||||
Prompt: request.Prompt,
|
||||
},
|
||||
},
|
||||
Parameters: GeminiImageParameters{
|
||||
Parameters: dto.GeminiImageParameters{
|
||||
SampleCount: request.N,
|
||||
AspectRatio: aspectRatio,
|
||||
PersonGeneration: "allow_adult", // default allow adult
|
||||
@@ -138,9 +142,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
}
|
||||
|
||||
// only process the first input
|
||||
geminiRequest := GeminiEmbeddingRequest{
|
||||
Content: GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
geminiRequest := dto.GeminiEmbeddingRequest{
|
||||
Content: dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: inputs[0],
|
||||
},
|
||||
|
||||
@@ -29,7 +29,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
@@ -72,7 +72,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
|
||||
@@ -81,7 +81,7 @@ func clampThinkingBudget(modelName string, budget int) int {
|
||||
return budget
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
@@ -93,7 +93,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
@@ -113,11 +113,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
@@ -128,7 +128,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
@@ -137,11 +137,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
geminiRequest := dto.GeminiChatRequest{
|
||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: dto.GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
@@ -158,9 +158,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
|
||||
ThinkingAdaptor(&geminiRequest, info)
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
safetySettings = append(safetySettings, GeminiChatSafetySettings{
|
||||
safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
|
||||
Category: category,
|
||||
Threshold: model_setting.GetGeminiSafetySetting(category),
|
||||
})
|
||||
@@ -198,17 +198,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
if codeExecution {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
CodeExecution: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if googleSearch {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
})
|
||||
}
|
||||
@@ -238,7 +238,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
continue
|
||||
} else if message.Role == "tool" || message.Role == "function" {
|
||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
|
||||
Role: "user",
|
||||
})
|
||||
}
|
||||
@@ -265,18 +265,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
functionResp := &FunctionResponse{
|
||||
functionResp := &dto.GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: contentMap,
|
||||
}
|
||||
|
||||
*parts = append(*parts, GeminiPart{
|
||||
*parts = append(*parts, dto.GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
continue
|
||||
}
|
||||
var parts []GeminiPart
|
||||
content := GeminiChatContent{
|
||||
var parts []dto.GeminiPart
|
||||
content := dto.GeminiChatContent{
|
||||
Role: message.Role,
|
||||
}
|
||||
// isToolCall := false
|
||||
@@ -290,8 +290,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
|
||||
}
|
||||
}
|
||||
toolCall := GeminiPart{
|
||||
FunctionCall: &FunctionCall{
|
||||
toolCall := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: call.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
@@ -308,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if part.Text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
@@ -331,8 +331,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
||||
Data: fileData.Base64Data,
|
||||
},
|
||||
@@ -342,8 +342,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -357,8 +357,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -371,8 +371,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: "audio/" + part.GetInputAudio().Format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -392,8 +392,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
geminiRequest.SystemInstructions = &dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: strings.Join(system_content, "\n"),
|
||||
},
|
||||
@@ -636,7 +636,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
return data
|
||||
}
|
||||
|
||||
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
@@ -658,7 +658,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
@@ -725,7 +725,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||
isStop := false
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
@@ -827,7 +827,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
var imageCount int
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
@@ -928,7 +928,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
@@ -974,7 +974,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
var geminiResponse dto.GeminiEmbeddingResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
@@ -1020,7 +1020,7 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiImageResponse
|
||||
var geminiResponse dto.GeminiImageResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
@@ -13,11 +12,18 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
openaiAdaptor := openai.Adaptor{}
|
||||
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
|
||||
|
||||
@@ -34,6 +34,15 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, openaiRequest)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
//if !strings.Contains(request.Model, "claude") {
|
||||
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
@@ -64,7 +73,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
|
||||
@@ -2,6 +2,8 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -16,11 +18,14 @@ import (
|
||||
// 辅助函数
|
||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
case relaycommon.RelayFormatGemini:
|
||||
return handleGeminiFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -41,6 +46,36 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// 如果返回 nil,表示没有实际内容,跳过发送
|
||||
if geminiResponse == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// send gemini format response
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
@@ -185,6 +220,37 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
for _, resp := range claudeResponses {
|
||||
_ = helper.ClaudeData(c, *resp)
|
||||
}
|
||||
|
||||
case relaycommon.RelayFormatGemini:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
|
||||
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
|
||||
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
|
||||
// 暂不知是否有程序会不兼容。
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// openai 流响应开头的空数据
|
||||
if geminiResponse == nil {
|
||||
return
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling gemini response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 发送最终的 Gemini 响应
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -223,6 +223,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
case relaycommon.RelayFormatGemini:
|
||||
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
||||
geminiRespStr, err := common.Marshal(geminiResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = geminiRespStr
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
|
||||
@@ -25,6 +25,11 @@ type Adaptor struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -44,6 +44,10 @@ type Adaptor struct {
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
c.Set("request_model", v)
|
||||
|
||||
@@ -23,6 +23,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
|
||||
@@ -17,6 +17,11 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
|
||||
request := &gemini.GeminiChatRequest{}
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||
request := &dto.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -44,7 +44,7 @@ func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
// }
|
||||
}
|
||||
|
||||
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
|
||||
func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
|
||||
var inputTexts []string
|
||||
for _, content := range textRequest.Contents {
|
||||
for _, part := range content.Parts {
|
||||
@@ -61,7 +61,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
// 计算输入 token 数量
|
||||
var inputTexts []string
|
||||
for _, content := range req.Contents {
|
||||
@@ -78,7 +78,7 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
|
||||
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
|
||||
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget == 0
|
||||
}
|
||||
@@ -202,7 +202,12 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
requestBody = bytes.NewReader(body)
|
||||
} else {
|
||||
jsonData, err := common.Marshal(req)
|
||||
// 使用 ConvertGeminiRequest 转换请求格式
|
||||
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -448,3 +448,353 @@ func toJSONString(v interface{}) string {
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||||
openaiRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: info.UpstreamModelName,
|
||||
Stream: info.IsStream,
|
||||
}
|
||||
|
||||
// 转换 messages
|
||||
var messages []dto.Message
|
||||
for _, content := range geminiRequest.Contents {
|
||||
message := dto.Message{
|
||||
Role: convertGeminiRoleToOpenAI(content.Role),
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
var mediaContents []dto.MediaContent
|
||||
var toolCalls []dto.ToolCallRequest
|
||||
for _, part := range content.Parts {
|
||||
if part.Text != "" {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "text",
|
||||
Text: part.Text,
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.InlineData != nil {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{
|
||||
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
|
||||
Detail: "auto",
|
||||
MimeType: part.InlineData.MimeType,
|
||||
},
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.FileData != nil {
|
||||
mediaContent := dto.MediaContent{
|
||||
Type: "image_url",
|
||||
ImageUrl: &dto.MessageImageUrl{
|
||||
Url: part.FileData.FileUri,
|
||||
Detail: "auto",
|
||||
MimeType: part.FileData.MimeType,
|
||||
},
|
||||
}
|
||||
mediaContents = append(mediaContents, mediaContent)
|
||||
} else if part.FunctionCall != nil {
|
||||
// 处理 Gemini 的工具调用
|
||||
toolCall := dto.ToolCallRequest{
|
||||
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: part.FunctionCall.FunctionName,
|
||||
Arguments: toJSONString(part.FunctionCall.Arguments),
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
} else if part.FunctionResponse != nil {
|
||||
// 处理 Gemini 的工具响应,创建单独的 tool 消息
|
||||
toolMessage := dto.Message{
|
||||
Role: "tool",
|
||||
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
|
||||
}
|
||||
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
|
||||
messages = append(messages, toolMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置消息内容
|
||||
if len(toolCalls) > 0 {
|
||||
// 如果有工具调用,设置工具调用
|
||||
message.SetToolCalls(toolCalls)
|
||||
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
|
||||
// 如果只有一个文本内容,直接设置字符串
|
||||
message.Content = mediaContents[0].Text
|
||||
} else if len(mediaContents) > 0 {
|
||||
// 如果有多个内容或包含媒体,设置为数组
|
||||
message.SetMediaContent(mediaContents)
|
||||
}
|
||||
|
||||
// 只有当消息有内容或工具调用时才添加
|
||||
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
|
||||
messages = append(messages, message)
|
||||
}
|
||||
}
|
||||
|
||||
openaiRequest.Messages = messages
|
||||
|
||||
if geminiRequest.GenerationConfig.Temperature != nil {
|
||||
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
|
||||
}
|
||||
if geminiRequest.GenerationConfig.TopP > 0 {
|
||||
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
|
||||
}
|
||||
if geminiRequest.GenerationConfig.TopK > 0 {
|
||||
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
|
||||
}
|
||||
// gemini stop sequences 最多 5 个,openai stop 最多 4 个
|
||||
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
|
||||
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
|
||||
}
|
||||
if geminiRequest.GenerationConfig.CandidateCount > 0 {
|
||||
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
|
||||
}
|
||||
|
||||
// 转换工具调用
|
||||
if len(geminiRequest.Tools) > 0 {
|
||||
var tools []dto.ToolCallRequest
|
||||
for _, tool := range geminiRequest.Tools {
|
||||
if tool.FunctionDeclarations != nil {
|
||||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||
if ok {
|
||||
for _, function := range functionDeclarations {
|
||||
openAITool := dto.ToolCallRequest{
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: function.Name,
|
||||
Description: function.Description,
|
||||
Parameters: function.Parameters,
|
||||
},
|
||||
}
|
||||
tools = append(tools, openAITool)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
openaiRequest.Tools = tools
|
||||
}
|
||||
}
|
||||
|
||||
// gemini system instructions
|
||||
if geminiRequest.SystemInstructions != nil {
|
||||
// 将系统指令作为第一条消息插入
|
||||
systemMessage := dto.Message{
|
||||
Role: "system",
|
||||
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
|
||||
}
|
||||
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
|
||||
}
|
||||
|
||||
return openaiRequest, nil
|
||||
}
|
||||
|
||||
func convertGeminiRoleToOpenAI(geminiRole string) string {
|
||||
switch geminiRole {
|
||||
case "user":
|
||||
return "user"
|
||||
case "model":
|
||||
return "assistant"
|
||||
case "function":
|
||||
return "function"
|
||||
default:
|
||||
return "user"
|
||||
}
|
||||
}
|
||||
|
||||
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
|
||||
var texts []string
|
||||
for _, part := range parts {
|
||||
if part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
}
|
||||
|
||||
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
|
||||
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: openAIResponse.PromptTokens,
|
||||
CandidatesTokenCount: openAIResponse.CompletionTokens,
|
||||
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
}
|
||||
|
||||
// 设置结束原因
|
||||
var finishReason string
|
||||
switch choice.FinishReason {
|
||||
case "stop":
|
||||
finishReason = "STOP"
|
||||
case "length":
|
||||
finishReason = "MAX_TOKENS"
|
||||
case "content_filter":
|
||||
finishReason = "SAFETY"
|
||||
case "tool_calls":
|
||||
finishReason = "STOP"
|
||||
default:
|
||||
finishReason = "STOP"
|
||||
}
|
||||
candidate.FinishReason = &finishReason
|
||||
|
||||
// 转换消息内容
|
||||
content := dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: make([]dto.GeminiPart, 0),
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
toolCalls := choice.Message.ParseToolCalls()
|
||||
if len(toolCalls) > 0 {
|
||||
for _, toolCall := range toolCalls {
|
||||
// 解析参数
|
||||
var args map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||
}
|
||||
} else {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
part := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: toolCall.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
} else {
|
||||
// 处理文本内容
|
||||
textContent := choice.Message.StringContent()
|
||||
if textContent != "" {
|
||||
part := dto.GeminiPart{
|
||||
Text: textContent,
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
candidate.Content = content
|
||||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
|
||||
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
|
||||
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||
// 检查是否有实际内容或结束标志
|
||||
hasContent := false
|
||||
hasFinishReason := false
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
|
||||
hasContent = true
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
hasFinishReason = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
|
||||
if !hasContent && !hasFinishReason {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponse := &dto.GeminiChatResponse{
|
||||
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: info.PromptTokens,
|
||||
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||||
TotalTokenCount: info.PromptTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, choice := range openAIResponse.Choices {
|
||||
candidate := dto.GeminiChatCandidate{
|
||||
Index: int64(choice.Index),
|
||||
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||
}
|
||||
|
||||
// 设置结束原因
|
||||
if choice.FinishReason != nil {
|
||||
var finishReason string
|
||||
switch *choice.FinishReason {
|
||||
case "stop":
|
||||
finishReason = "STOP"
|
||||
case "length":
|
||||
finishReason = "MAX_TOKENS"
|
||||
case "content_filter":
|
||||
finishReason = "SAFETY"
|
||||
case "tool_calls":
|
||||
finishReason = "STOP"
|
||||
default:
|
||||
finishReason = "STOP"
|
||||
}
|
||||
candidate.FinishReason = &finishReason
|
||||
}
|
||||
|
||||
// 转换消息内容
|
||||
content := dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: make([]dto.GeminiPart, 0),
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
for _, toolCall := range choice.Delta.ToolCalls {
|
||||
// 解析参数
|
||||
var args map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||
}
|
||||
} else {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
part := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: toolCall.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
} else {
|
||||
// 处理文本内容
|
||||
textContent := choice.Delta.GetContentString()
|
||||
if textContent != "" {
|
||||
part := dto.GeminiPart{
|
||||
Text: textContent,
|
||||
}
|
||||
content.Parts = append(content.Parts, part)
|
||||
}
|
||||
}
|
||||
|
||||
candidate.Content = content
|
||||
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user