gemini text generation

This commit is contained in:
creamlike1024
2025-05-26 13:34:41 +08:00
parent b564cac048
commit 738a9a4558
9 changed files with 353 additions and 2 deletions

View File

@@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
err = relay.EmbeddingHelper(c)
case relayconstant.RelayModeResponses:
err = relay.ResponsesHelper(c)
case relayconstant.RelayModeGemini:
err = relay.GeminiHelper(c)
default:
err = relay.TextHelper(c)
}

69
dto/gemini.go Normal file
View File

@@ -0,0 +1,69 @@
package dto
import "encoding/json"
type GeminiPart struct {
Text string `json:"text"`
}
type GeminiContent struct {
Parts []GeminiPart `json:"parts"`
Role string `json:"role"`
}
type GeminiCandidate struct {
Content GeminiContent `json:"content"`
FinishReason string `json:"finishReason"`
AvgLogprobs float64 `json:"avgLogprobs"`
}
type GeminiTokenDetails struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
PromptTokensDetails []GeminiTokenDetails `json:"promptTokensDetails"`
CandidatesTokensDetails []GeminiTokenDetails `json:"candidatesTokensDetails"`
}
type GeminiTextGenerationResponse struct {
Candidates []GeminiCandidate `json:"candidates"`
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
ModelVersion string `json:"modelVersion"`
ResponseID string `json:"responseId"`
}
type GeminiGenerationConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema *json.RawMessage `json:"responseSchema,omitempty"`
ResponseModalities *json.RawMessage `json:"responseModalities,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
Seed int `json:"seed,omitempty"`
PresencePenalty float64 `json:"presencePenalty,omitempty"`
FrequencyPenalty float64 `json:"frequencyPenalty,omitempty"`
ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
LogProbs int `json:"logProbs,omitempty"`
EnableEnhancedCivicAnswers bool `json:"enableEnhancedCivicAnswers,omitempty"`
SpeechConfig *json.RawMessage `json:"speechConfig,omitempty"`
ThinkingConfig *json.RawMessage `json:"thinkingConfig,omitempty"`
MediaResolution *json.RawMessage `json:"mediaResolution,omitempty"`
}
type GeminiTextGenerationRequest struct {
Contents []GeminiContent `json:"contents"`
Tools *json.RawMessage `json:"tools,omitempty"`
ToolConfig *json.RawMessage `json:"toolConfig,omitempty"`
SafetySettings *json.RawMessage `json:"safetySettings,omitempty"`
SystemInstruction *json.RawMessage `json:"systemInstruction,omitempty"`
GenerationConfig GeminiGenerationConfig `json:"generationConfig,omitempty"`
CachedContent *json.RawMessage `json:"cachedContent,omitempty"`
}

View File

@@ -1,13 +1,14 @@
package middleware
import (
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
func validUserInfo(username string, role int) bool {
@@ -182,6 +183,13 @@ func TokenAuth() func(c *gin.Context) {
c.Request.Header.Set("Authorization", "Bearer "+key)
}
}
// gemini api 从query中获取key
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
skKey := c.Query("key")
if skKey != "" {
c.Request.Header.Set("Authorization", "Bearer "+skKey)
}
}
key := c.Request.Header.Get("Authorization")
parts := make([]string, 0)
key = strings.TrimPrefix(key, "Bearer ")

View File

@@ -162,6 +162,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
relayMode := relayconstant.RelayModeGemini
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
if modelName != "" {
modelRequest.Model = modelName
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
@@ -244,3 +252,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("bot_id", channel.Other)
}
}
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent
// 输出: gemini-2.0-flash
func extractModelNameFromGeminiPath(path string) string {
// 查找 "/models/" 的位置
modelsPrefix := "/models/"
modelsIndex := strings.Index(path, modelsPrefix)
if modelsIndex == -1 {
return ""
}
// 从 "/models/" 之后开始提取
startIndex := modelsIndex + len(modelsPrefix)
if startIndex >= len(path) {
return ""
}
// 查找 ":" 的位置,模型名在 ":" 之前
colonIndex := strings.Index(path[startIndex:], ":")
if colonIndex == -1 {
// 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分
return path[startIndex:]
}
// 返回模型名部分
return path[startIndex : startIndex+colonIndex]
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
"one-api/setting/model_setting"
"strings"
@@ -165,6 +166,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeGemini {
err, usage = GeminiTextGenerationHandler(c, resp, info)
return usage, err
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, resp, info)
}

View File

@@ -0,0 +1,77 @@
package gemini
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"github.com/gin-gonic/gin"
)
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if common.DebugEnabled {
println(string(responseBody))
}
// 解析为 Gemini 原生响应格式
var geminiResponse dto.GeminiTextGenerationResponse
err = common.DecodeJson(responseBody, &geminiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// 检查是否有候选响应
if len(geminiResponse.Candidates) == 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
}
// 计算使用量(基于 UsageMetadata
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
// 设置模型版本
if geminiResponse.ModelVersion == "" {
geminiResponse.ModelVersion = info.UpstreamModelName
}
// 直接返回 Gemini 原生格式的 JSON 响应
jsonResponse, err := json.Marshal(geminiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
// 设置响应头并写入响应
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}

View File

@@ -43,6 +43,8 @@ const (
RelayModeResponses
RelayModeRealtime
RelayModeGemini
)
func Path2RelayMode(path string) int {
@@ -75,6 +77,8 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeRerank
} else if strings.HasPrefix(path, "/v1/realtime") {
relayMode = RelayModeRealtime
} else if strings.HasPrefix(path, "/v1beta/models") {
relayMode = RelayModeGemini
}
return relayMode
}

141
relay/relay-gemini.go Normal file
View File

@@ -0,0 +1,141 @@
package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
"github.com/gin-gonic/gin"
)
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) {
request := &dto.GeminiTextGenerationRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
if len(request.Contents) == 0 {
return nil, errors.New("contents is required")
}
return request, nil
}
func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) {
var inputTexts []string
for _, content := range textRequest.Contents {
for _, part := range content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
}
if len(inputTexts) == 0 {
return nil, nil
}
sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
return sensitiveWords, err
}
func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) {
// 计算输入 token 数量
var inputTexts []string
for _, content := range req.Contents {
for _, part := range content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
}
inputText := strings.Join(inputTexts, "\n")
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
info.PromptTokens = inputTokens
return inputTokens, err
}
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
req, err := getAndValidateGeminiRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
}
relayInfo := relaycommon.GenRelayInfo(c)
if setting.ShouldCheckPromptSensitive() {
sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
}
}
// model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
}
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens, err := getGeminiInputTokens(req, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
}
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre consume quota
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
requestBody, err := json.Marshal(req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
}
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
if openaiErr != nil {
return openaiErr
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}

View File

@@ -79,6 +79,14 @@ func SetRelayRouter(router *gin.Engine) {
relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
}
relayGeminiRouter := router.Group("/v1beta")
relayGeminiRouter.Use(middleware.TokenAuth())
relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
relayGeminiRouter.Use(middleware.Distribute())
{
// Gemini API 路径格式: /v1beta/models/{model_name}:{action}
relayGeminiRouter.POST("/models/*path", controller.Relay)
}
}
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {