feat: convert gemini format to openai chat completions
This commit is contained in:
@@ -34,6 +34,15 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, openaiRequest)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
//if !strings.Contains(request.Model, "claude") {
|
||||
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
@@ -64,7 +73,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
|
||||
@@ -2,6 +2,8 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -16,11 +18,14 @@ import (
|
||||
// 辅助函数
|
||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
case relaycommon.RelayFormatGemini:
|
||||
return handleGeminiFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -41,6 +46,46 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
// 截取前50个字符用于调试
|
||||
debugData := data
|
||||
if len(data) > 50 {
|
||||
debugData = data[:50] + "..."
|
||||
}
|
||||
common.LogInfo(c, "handleGeminiFormat called with data: "+debugData)
|
||||
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
common.LogInfo(c, "successfully unmarshaled stream response")
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// 如果返回 nil,表示没有实际内容,跳过发送
|
||||
if geminiResponse == nil {
|
||||
common.LogInfo(c, "handleGeminiFormat: no content to send, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
common.LogInfo(c, "sending gemini format response")
|
||||
// send gemini format response
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
@@ -185,6 +230,37 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
for _, resp := range claudeResponses {
|
||||
_ = helper.ClaudeData(c, *resp)
|
||||
}
|
||||
|
||||
case relaycommon.RelayFormatGemini:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
|
||||
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
|
||||
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
|
||||
// 暂不知是否有程序会不兼容。
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// openai 流响应开头的空数据
|
||||
if geminiResponse == nil {
|
||||
return
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling gemini response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 发送最终的 Gemini 响应
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -223,6 +223,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
case relaycommon.RelayFormatGemini:
|
||||
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
||||
geminiRespStr, err := common.Marshal(geminiResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = geminiRespStr
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
Reference in New Issue
Block a user