Files
new-api/relay/channel/gemini/adaptor.go
2025-08-09 00:27:33 +08:00

250 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gemini
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
// Adaptor is the channel adaptor for the Gemini API. It holds no fields; all
// of its behavior is provided by the methods declared on *Adaptor.
type Adaptor struct{}
// ConvertGeminiRequest normalizes a native Gemini chat request in place and
// returns that same request.
//
// Two fix-ups are applied:
//   - if the first content entry has an empty role, it is set to "user";
//   - any part carrying file data with an empty MIME type and a URI that
//     contains "www.youtube.com" is tagged as "video/webm".
//
// The returned error is always nil. c and info are not used.
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
	contents := request.Contents

	// Only the very first entry gets a default role.
	if len(contents) > 0 && contents[0].Role == "" {
		contents[0].Role = "user"
	}

	for _, content := range contents {
		for _, part := range content.Parts {
			file := part.FileData
			if file == nil {
				continue
			}
			if file.MimeType == "" && strings.Contains(file.FileUri, "www.youtube.com") {
				file.MimeType = "video/webm"
			}
		}
	}
	return request, nil
}
// ConvertClaudeRequest converts a Claude-format request into a Gemini request
// in two steps: the OpenAI adaptor first converts it, and the result is then
// passed through ConvertOpenAIRequest.
//
// It returns any error produced by either conversion step. If the OpenAI
// adaptor yields a value that is not a *dto.GeneralOpenAIRequest, an error is
// returned instead of panicking.
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
	oaiAdaptor := openai.Adaptor{}
	converted, err := oaiAdaptor.ConvertClaudeRequest(c, info, req)
	if err != nil {
		return nil, err
	}
	// Comma-ok form: a bare type assertion would panic if the intermediate
	// value is nil or has an unexpected dynamic type.
	oaiReq, ok := converted.(*dto.GeneralOpenAIRequest)
	if !ok {
		return nil, fmt.Errorf("unexpected openai request type %T", converted)
	}
	return a.ConvertOpenAIRequest(c, info, oaiReq)
}
// ConvertAudioRequest is not supported by this adaptor yet. It always returns
// a nil reader together with a "not implemented" error.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
	// TODO: implement audio request conversion.
	var body io.Reader
	return body, errors.New("not implemented")
}
// ConvertImageRequest builds a dto.GeminiImageRequest from an image
// generation request.
//
// It fails with an error unless info.UpstreamModelName starts with "imagen".
// request.Size is mapped to an aspect ratio ("1024x1792" -> "9:16",
// "1792x1024" -> "16:9", anything else -> "1:1"), request.N becomes the
// sample count, and person generation is always set to "allow_adult".
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
	if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
		return nil, errors.New("not supported model for image generation")
	}

	// Translate the pixel size into an aspect ratio; square is the fallback
	// for "1024x1024" and for every unrecognized size.
	var aspectRatio string
	switch request.Size {
	case "1024x1792":
		aspectRatio = "9:16"
	case "1792x1024":
		aspectRatio = "16:9"
	default:
		aspectRatio = "1:1"
	}

	instance := dto.GeminiImageInstance{Prompt: request.Prompt}
	params := dto.GeminiImageParameters{
		SampleCount:      request.N,
		AspectRatio:      aspectRatio,
		PersonGeneration: "allow_adult", // adults are allowed by default
	}
	return dto.GeminiImageRequest{
		Instances:  []dto.GeminiImageInstance{instance},
		Parameters: params,
	}, nil
}
// Init is a no-op; the adaptor performs no setup and ignores info.
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {}
// GetRequestURL builds the upstream URL in the form
// "<BaseUrl>/<version>/models/<model>:<action>".
//
// Side effects on info:
//   - when the thinking adapter setting is enabled, a "-thinking-<budget>",
//     "-thinking" or "-nothinking" marker is stripped from
//     info.UpstreamModelName before the URL is built;
//   - for streaming requests in native Gemini relay mode, info.DisablePing is
//     set to true.
//
// The action is "predict" for models prefixed "imagen"; "embedContent" (or
// "batchEmbedContents" for batch embedding) for models prefixed
// "text-embedding", "embedding" or "gemini-embedding"; otherwise
// "streamGenerateContent?alt=sse" when streaming and "generateContent" when
// not. The returned error is always nil.
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
		name := info.UpstreamModelName
		switch {
		case strings.Contains(name, "-thinking-"):
			// Newer "<model>-thinking-<budget>" form: keep what precedes the
			// first marker.
			base, _, _ := strings.Cut(name, "-thinking-")
			info.UpstreamModelName = base
		case strings.HasSuffix(name, "-thinking"):
			// Legacy suffix form.
			info.UpstreamModelName = strings.TrimSuffix(name, "-thinking")
		case strings.HasSuffix(name, "-nothinking"):
			info.UpstreamModelName = strings.TrimSuffix(name, "-nothinking")
		}
	}

	model := info.UpstreamModelName
	version := model_setting.GetGeminiVersionSetting(model)

	var action string
	switch {
	case strings.HasPrefix(model, "imagen"):
		action = "predict"
	case strings.HasPrefix(model, "text-embedding"),
		strings.HasPrefix(model, "embedding"),
		strings.HasPrefix(model, "gemini-embedding"):
		if info.IsGeminiBatchEmbdding {
			action = "batchEmbedContents"
		} else {
			action = "embedContent"
		}
	case info.IsStream:
		action = "streamGenerateContent?alt=sse"
		if info.RelayMode == constant.RelayModeGemini {
			info.DisablePing = true
		}
	default:
		action = "generateContent"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, model, action), nil
}
// SetupRequestHeader applies the shared API request headers through
// channel.SetupApiRequestHeader and then sets the "x-goog-api-key" header to
// info.ApiKey. It always returns nil.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
	const apiKeyHeader = "x-goog-api-key"

	channel.SetupApiRequestHeader(info, c, req)
	req.Set(apiKeyHeader, info.ApiKey)
	return nil
}
// ConvertOpenAIRequest converts an OpenAI-format request by delegating to
// CovertGemini2OpenAI.
//
// It returns an error when request is nil or when the conversion fails; in
// both cases the returned value is an untyped nil.
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}

	converted, err := CovertGemini2OpenAI(*request, info)
	if err == nil {
		return converted, nil
	}
	// Return a literal nil so the caller never sees a non-nil interface
	// wrapping a nil result.
	return nil, err
}
// ConvertRerankRequest is not supported by this adaptor. It always returns a
// nil payload together with a "not implemented" error, matching the other
// unsupported conversions in this file, so callers cannot mistake a nil
// payload for a successful conversion.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, errors.New("not implemented")
}
// ConvertEmbeddingRequest builds a dto.GeminiEmbeddingRequest from an
// embedding request.
//
// It returns an error when request.Input is nil or when parsing it yields no
// inputs. Only the first parsed input is embedded; any further inputs are
// ignored. A positive request.Dimensions is forwarded as
// OutputDimensionality for the models listed in the switch below.
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
	if request.Input == nil {
		return nil, errors.New("input is required")
	}
	inputs := request.ParseInput()
	if len(inputs) == 0 {
		return nil, errors.New("input is empty")
	}

	// Only the first input is processed.
	geminiRequest := dto.GeminiEmbeddingRequest{
		Content: dto.GeminiChatContent{
			Parts: []dto.GeminiPart{
				{
					Text: inputs[0],
				},
			},
		},
	}

	// Model-specific parameters.
	// https://ai.google.dev/api/embeddings#method:-models.embedcontent
	// Every embedding model except the legacy embedding-001 accepts
	// OutputDimensionality, so forward it for the newer models only.
	switch info.UpstreamModelName {
	case "text-embedding-004", "gemini-embedding-001", "gemini-embedding-exp-03-07":
		if request.Dimensions > 0 {
			geminiRequest.OutputDimensionality = request.Dimensions
		}
	}
	return geminiRequest, nil
}
// ConvertOpenAIResponsesRequest is not supported by this adaptor yet. It
// always returns a nil payload together with a "not implemented" error.
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
	// TODO: implement conversion of OpenAI Responses requests.
	err := errors.New("not implemented")
	return nil, err
}
// DoRequest sends the prepared request body upstream by delegating to
// channel.DoApiRequest, and returns that call's result and error unchanged.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
	resp, err := channel.DoApiRequest(a, c, info, requestBody)
	return resp, err
}
// DoResponse dispatches the upstream response to the matching handler and
// returns that handler's usage and error.
//
// In native Gemini relay mode: NativeGeminiEmbeddingHandler when the request
// URL path contains "embed", otherwise the streaming or non-streaming text
// generation handler depending on info.IsStream.
//
// In every other mode: GeminiImageHandler for models prefixed "imagen",
// GeminiEmbeddingHandler for models prefixed "text-embedding", "embedding" or
// "gemini-embedding", otherwise the streaming or non-streaming chat handler
// depending on info.IsStream.
//
// NOTE: a previously disabled idea was to bill a request as the "-thinking"
// variant of its model when it did not ask for thinking but still produced
// more than 100 reasoning tokens. That logic was commented out and
// unreachable, and is not implemented here.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
	if info.RelayMode == constant.RelayModeGemini {
		if strings.Contains(info.RequestURLPath, "embed") {
			return NativeGeminiEmbeddingHandler(c, resp, info)
		}
		if info.IsStream {
			return GeminiTextGenerationStreamHandler(c, info, resp)
		}
		return GeminiTextGenerationHandler(c, info, resp)
	}

	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
		return GeminiImageHandler(c, info, resp)
	}

	// Embedding models are recognized by name prefix.
	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
		return GeminiEmbeddingHandler(c, info, resp)
	}

	if info.IsStream {
		return GeminiChatStreamHandler(c, info, resp)
	}
	return GeminiChatHandler(c, info, resp)
}
// GetModelList returns the package-level ModelList.
func (a *Adaptor) GetModelList() []string {
	models := ModelList
	return models
}
// GetChannelName returns the package-level ChannelName.
func (a *Adaptor) GetChannelName() string {
	name := ChannelName
	return name
}