From 738a9a455853163f6e1e705fe1fb4c2ca5aa094f Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Mon, 26 May 2025 13:34:41 +0800 Subject: [PATCH 1/3] gemini text generation --- controller/relay.go | 2 + dto/gemini.go | 69 ++++++++++ middleware/auth.go | 12 +- middleware/distributor.go | 36 +++++ relay/channel/gemini/adaptor.go | 6 + relay/channel/gemini/relay-gemini-native.go | 77 +++++++++++ relay/constant/relay_mode.go | 4 + relay/relay-gemini.go | 141 ++++++++++++++++++++ router/relay-router.go | 8 ++ 9 files changed, 353 insertions(+), 2 deletions(-) create mode 100644 dto/gemini.go create mode 100644 relay/channel/gemini/relay-gemini-native.go create mode 100644 relay/relay-gemini.go diff --git a/controller/relay.go b/controller/relay.go index 41cb22a5..1a875dbc 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode err = relay.EmbeddingHelper(c) case relayconstant.RelayModeResponses: err = relay.ResponsesHelper(c) + case relayconstant.RelayModeGemini: + err = relay.GeminiHelper(c) default: err = relay.TextHelper(c) } diff --git a/dto/gemini.go b/dto/gemini.go new file mode 100644 index 00000000..898c966f --- /dev/null +++ b/dto/gemini.go @@ -0,0 +1,69 @@ +package dto + +import "encoding/json" + +type GeminiPart struct { + Text string `json:"text"` +} + +type GeminiContent struct { + Parts []GeminiPart `json:"parts"` + Role string `json:"role"` +} + +type GeminiCandidate struct { + Content GeminiContent `json:"content"` + FinishReason string `json:"finishReason"` + AvgLogprobs float64 `json:"avgLogprobs"` +} + +type GeminiTokenDetails struct { + Modality string `json:"modality"` + TokenCount int `json:"tokenCount"` +} + +type GeminiUsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + PromptTokensDetails []GeminiTokenDetails `json:"promptTokensDetails"` + CandidatesTokensDetails []GeminiTokenDetails `json:"candidatesTokensDetails"` +} + +type GeminiTextGenerationResponse struct { + Candidates []GeminiCandidate `json:"candidates"` + UsageMetadata GeminiUsageMetadata `json:"usageMetadata"` + ModelVersion string `json:"modelVersion"` + ResponseID string `json:"responseId"` +} + +type GeminiGenerationConfig struct { + StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema *json.RawMessage `json:"responseSchema,omitempty"` + ResponseModalities *json.RawMessage `json:"responseModalities,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` + Seed int `json:"seed,omitempty"` + PresencePenalty float64 `json:"presencePenalty,omitempty"` + FrequencyPenalty float64 `json:"frequencyPenalty,omitempty"` + ResponseLogprobs bool `json:"responseLogprobs,omitempty"` + LogProbs int `json:"logProbs,omitempty"` + EnableEnhancedCivicAnswers bool `json:"enableEnhancedCivicAnswers,omitempty"` + SpeechConfig *json.RawMessage `json:"speechConfig,omitempty"` + ThinkingConfig *json.RawMessage `json:"thinkingConfig,omitempty"` + MediaResolution *json.RawMessage `json:"mediaResolution,omitempty"` +} + +type GeminiTextGenerationRequest struct { + Contents []GeminiContent `json:"contents"` + Tools *json.RawMessage 
`json:"tools,omitempty"` + ToolConfig *json.RawMessage `json:"toolConfig,omitempty"` + SafetySettings *json.RawMessage `json:"safetySettings,omitempty"` + SystemInstruction *json.RawMessage `json:"systemInstruction,omitempty"` + GenerationConfig GeminiGenerationConfig `json:"generationConfig,omitempty"` + CachedContent *json.RawMessage `json:"cachedContent,omitempty"` +} diff --git a/middleware/auth.go b/middleware/auth.go index fece4553..ce86bb36 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,13 +1,14 @@ package middleware import ( - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "strings" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" ) func validUserInfo(username string, role int) bool { @@ -182,6 +183,13 @@ func TokenAuth() func(c *gin.Context) { c.Request.Header.Set("Authorization", "Bearer "+key) } } + // gemini api 从query中获取key + if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { + skKey := c.Query("key") + if skKey != "" { + c.Request.Header.Set("Authorization", "Bearer "+skKey) + } + } key := c.Request.Header.Get("Authorization") parts := make([]string, 0) key = strings.TrimPrefix(key, "Bearer ") diff --git a/middleware/distributor.go b/middleware/distributor.go index e7db6d77..1bfe1821 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -162,6 +162,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) + } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { + // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent + relayMode := relayconstant.RelayModeGemini + modelName := extractModelNameFromGeminiPath(c.Request.URL.Path) + if modelName != "" { + modelRequest.Model = modelName + } + c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { err = common.UnmarshalBodyReusable(c, &modelRequest) } @@ -244,3 +252,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("bot_id", channel.Other) } } + +// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名 +// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent +// 输出: gemini-2.0-flash +func extractModelNameFromGeminiPath(path string) string { + // 查找 "/models/" 的位置 + modelsPrefix := "/models/" + modelsIndex := strings.Index(path, modelsPrefix) + if modelsIndex == -1 { + return "" + } + + // 从 "/models/" 之后开始提取 + startIndex := modelsIndex + len(modelsPrefix) + if startIndex >= len(path) { + return "" + } + + // 查找 ":" 的位置,模型名在 ":" 之前 + colonIndex := strings.Index(path[startIndex:], ":") + if colonIndex == -1 { + // 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分 + return path[startIndex:] + } + + // 返回模型名部分 + return path[startIndex : startIndex+colonIndex] +} diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index c3c7b49d..12833736 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/relay/constant" "one-api/service" "one-api/setting/model_setting" "strings" @@ -165,6 +166,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage 
any, err *dto.OpenAIErrorWithStatusCode) { + if info.RelayMode == constant.RelayModeGemini { + err, usage = GeminiTextGenerationHandler(c, resp, info) + return usage, err + } + if strings.HasPrefix(info.UpstreamModelName, "imagen") { return GeminiImageHandler(c, resp, info) } diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go new file mode 100644 index 00000000..16374ea4 --- /dev/null +++ b/relay/channel/gemini/relay-gemini-native.go @@ -0,0 +1,77 @@ +package gemini + +import ( + "encoding/json" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + + "github.com/gin-gonic/gin" +) + +func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + // 读取响应体 + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if common.DebugEnabled { + println(string(responseBody)) + } + + // 解析为 Gemini 原生响应格式 + var geminiResponse dto.GeminiTextGenerationResponse + err = common.DecodeJson(responseBody, &geminiResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + // 检查是否有候选响应 + if len(geminiResponse.Candidates) == 0 { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: "No candidates returned", + Type: "server_error", + Param: "", + Code: 500, + }, + StatusCode: resp.StatusCode, + }, nil + } + + // 计算使用量(基于 UsageMetadata) + usage := dto.Usage{ + PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, + CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount, + TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, + } + + // 设置模型版本 + if geminiResponse.ModelVersion == "" { + geminiResponse.ModelVersion = info.UpstreamModelName + } + + // 直接返回 Gemini 原生格式的 JSON 响应 + jsonResponse, err := json.Marshal(geminiResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + + // 设置响应头并写入响应 + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil + } + + return nil, &usage +} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 4454e815..f22a20bd 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -43,6 +43,8 @@ const ( RelayModeResponses RelayModeRealtime + + RelayModeGemini ) func Path2RelayMode(path string) int { @@ -75,6 +77,8 @@ func Path2RelayMode(path string) int { relayMode = RelayModeRerank } else if strings.HasPrefix(path, "/v1/realtime") { relayMode = RelayModeRealtime + } else if strings.HasPrefix(path, "/v1beta/models") { + relayMode = RelayModeGemini } return relayMode } diff --git a/relay/relay-gemini.go b/relay/relay-gemini.go new file mode 100644 index 00000000..9aa072e1 --- /dev/null +++ b/relay/relay-gemini.go @@ -0,0 +1,141 @@ +package relay + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon 
"one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/setting" + "strings" + + "github.com/gin-gonic/gin" +) + +func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) { + request := &dto.GeminiTextGenerationRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if len(request.Contents) == 0 { + return nil, errors.New("contents is required") + } + return request, nil +} + +func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) { + var inputTexts []string + for _, content := range textRequest.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + if len(inputTexts) == 0 { + return nil, nil + } + + sensitiveWords, err := service.CheckSensitiveInput(inputTexts) + return sensitiveWords, err +} + +func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) { + // 计算输入 token 数量 + var inputTexts []string + for _, content := range req.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + + inputText := strings.Join(inputTexts, "\n") + inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName) + info.PromptTokens = inputTokens + return inputTokens, err +} + +func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { + req, err := getAndValidateGeminiRequest(c) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) + return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) + } + + relayInfo := relaycommon.GenRelayInfo(c) + + if setting.ShouldCheckPromptSensitive() { + sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo) + if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) + return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest) + } + } + + // model mapped 模型映射 + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) + } + + if value, exists := c.Get("prompt_tokens"); exists { + promptTokens := value.(int) + relayInfo.SetPromptTokens(promptTokens) + } else { + promptTokens, err := getGeminiInputTokens(req, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) + } + c.Set("prompt_tokens", promptTokens) + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } + + // pre consume quota + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + + adaptor.Init(relayInfo) + + requestBody, err := json.Marshal(req) + 
if err != nil { + return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + + resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody)) + if err != nil { + common.LogError(c, "Do gemini request failed: "+err.Error()) + return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + } + + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + if openaiErr != nil { + return openaiErr + } + + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + return nil +} diff --git a/router/relay-router.go b/router/relay-router.go index 4cd84b41..1115a491 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -79,6 +79,14 @@ func SetRelayRouter(router *gin.Engine) { relaySunoRouter.GET("/fetch/:id", controller.RelayTask) } + relayGeminiRouter := router.Group("/v1beta") + relayGeminiRouter.Use(middleware.TokenAuth()) + relayGeminiRouter.Use(middleware.ModelRequestRateLimit()) + relayGeminiRouter.Use(middleware.Distribute()) + { + // Gemini API 路径格式: /v1beta/models/{model_name}:{action} + relayGeminiRouter.POST("/models/*path", controller.Relay) + } } func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { From d90e4bef63ac262bc3190002bab90180f69acdef Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Mon, 26 May 2025 14:50:50 +0800 Subject: [PATCH 2/3] gemini stream --- dto/gemini.go | 69 ------------------ relay/channel/gemini/adaptor.go | 7 +- relay/channel/gemini/relay-gemini-native.go | 81 +++++++++++++++++---- relay/relay-gemini.go | 28 +++++-- 4 files changed, 93 insertions(+), 92 deletions(-) delete mode 100644 dto/gemini.go diff --git a/dto/gemini.go b/dto/gemini.go deleted file mode 100644 index 898c966f..00000000 --- a/dto/gemini.go +++ /dev/null @@ -1,69 +0,0 @@ -package dto - -import "encoding/json" - -type GeminiPart struct { - Text string `json:"text"` -} - -type GeminiContent struct { - Parts []GeminiPart `json:"parts"` - Role string `json:"role"` -} - -type GeminiCandidate struct { - Content GeminiContent `json:"content"` - FinishReason string `json:"finishReason"` - AvgLogprobs float64 `json:"avgLogprobs"` -} - -type GeminiTokenDetails struct { - Modality string `json:"modality"` - TokenCount int `json:"tokenCount"` -} - -type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount"` - CandidatesTokenCount int `json:"candidatesTokenCount"` - TotalTokenCount int `json:"totalTokenCount"` - PromptTokensDetails []GeminiTokenDetails `json:"promptTokensDetails"` - CandidatesTokensDetails []GeminiTokenDetails `json:"candidatesTokensDetails"` -} - -type GeminiTextGenerationResponse struct { - Candidates []GeminiCandidate `json:"candidates"` - UsageMetadata GeminiUsageMetadata `json:"usageMetadata"` - ModelVersion string `json:"modelVersion"` - ResponseID string `json:"responseId"` -} - -type GeminiGenerationConfig struct { - StopSequences []string `json:"stopSequences,omitempty"` - ResponseMimeType string `json:"responseMimeType,omitempty"` - ResponseSchema *json.RawMessage `json:"responseSchema,omitempty"` - ResponseModalities *json.RawMessage `json:"responseModalities,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` - Seed int `json:"seed,omitempty"` - PresencePenalty float64 
`json:"presencePenalty,omitempty"` - FrequencyPenalty float64 `json:"frequencyPenalty,omitempty"` - ResponseLogprobs bool `json:"responseLogprobs,omitempty"` - LogProbs int `json:"logProbs,omitempty"` - EnableEnhancedCivicAnswers bool `json:"enableEnhancedCivicAnswers,omitempty"` - SpeechConfig *json.RawMessage `json:"speechConfig,omitempty"` - ThinkingConfig *json.RawMessage `json:"thinkingConfig,omitempty"` - MediaResolution *json.RawMessage `json:"mediaResolution,omitempty"` -} - -type GeminiTextGenerationRequest struct { - Contents []GeminiContent `json:"contents"` - Tools *json.RawMessage `json:"tools,omitempty"` - ToolConfig *json.RawMessage `json:"toolConfig,omitempty"` - SafetySettings *json.RawMessage `json:"safetySettings,omitempty"` - SystemInstruction *json.RawMessage `json:"systemInstruction,omitempty"` - GenerationConfig GeminiGenerationConfig `json:"generationConfig,omitempty"` - CachedContent *json.RawMessage `json:"cachedContent,omitempty"` -} diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 12833736..e6f66d5f 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -167,8 +167,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeGemini { - err, usage = GeminiTextGenerationHandler(c, resp, info) - return usage, err + if info.IsStream { + return GeminiTextGenerationStreamHandler(c, resp, info) + } else { + return GeminiTextGenerationHandler(c, resp, info) + } } if strings.HasPrefix(info.UpstreamModelName, "imagen") { diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 16374ea4..c055e299 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -7,20 +7,21 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "github.com/gin-gonic/gin" ) -func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } if common.DebugEnabled { @@ -28,15 +29,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela } // 解析为 Gemini 原生响应格式 - var geminiResponse dto.GeminiTextGenerationResponse + var geminiResponse GeminiChatResponse err = common.DecodeJson(responseBody, &geminiResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } // 检查是否有候选响应 if 
len(geminiResponse.Candidates) == 0 { - return &dto.OpenAIErrorWithStatusCode{ + return nil, &dto.OpenAIErrorWithStatusCode{ Error: dto.OpenAIError{ Message: "No candidates returned", Type: "server_error", @@ -44,7 +45,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela Code: 500, }, StatusCode: resp.StatusCode, - }, nil + } } // 计算使用量(基于 UsageMetadata) @@ -54,15 +55,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, } - // 设置模型版本 - if geminiResponse.ModelVersion == "" { - geminiResponse.ModelVersion = info.UpstreamModelName - } - // 直接返回 Gemini 原生格式的 JSON 响应 jsonResponse, err := json.Marshal(geminiResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) } // 设置响应头并写入响应 @@ -70,8 +66,63 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela c.Writer.WriteHeader(resp.StatusCode) _, err = c.Writer.Write(jsonResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil + return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError) } - return nil, &usage + return &usage, nil +} + +func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { + var usage = &dto.Usage{} + var imageCount int + + helper.SetEventStreamHeaders(c) + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + var geminiResponse GeminiChatResponse + err := common.DecodeJsonStr(data, &geminiResponse) + if err != nil { + common.LogError(c, "error unmarshalling stream response: "+err.Error()) + return false + } + + // 统计图片数量 + for _, candidate := range geminiResponse.Candidates { + for _, part := range candidate.Content.Parts { + if part.InlineData != nil && part.InlineData.MimeType != "" { + imageCount++ + } + } + } + + // 更新使用量统计 + if geminiResponse.UsageMetadata.TotalTokenCount != 0 { + usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount + usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount + } + + // 直接发送 GeminiChatResponse 响应 + err = helper.ObjectData(c, geminiResponse) + if err != nil { + common.LogError(c, err.Error()) + } + + return true + }) + + if imageCount != 0 { + if usage.CompletionTokens == 0 { + usage.CompletionTokens = imageCount * 258 + } + } + + // 计算最终使用量 + usage.PromptTokensDetails.TextTokens = usage.PromptTokens + usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens + + // 结束流式响应 + helper.Done(c) + + return usage, nil } diff --git a/relay/relay-gemini.go b/relay/relay-gemini.go index 9aa072e1..93a2b7aa 100644 --- a/relay/relay-gemini.go +++ b/relay/relay-gemini.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/relay/channel/gemini" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -17,8 +18,8 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) { - request := &dto.GeminiTextGenerationRequest{} +func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) { + request := 
&gemini.GeminiChatRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { return nil, err @@ -29,7 +30,19 @@ func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationReque return request, nil } -func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) { +// 流模式 +// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx +func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) { + if c.Query("alt") == "sse" { + relayInfo.IsStream = true + } + + // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") { + // relayInfo.IsStream = true + // } +} + +func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) { var inputTexts []string for _, content := range textRequest.Contents { for _, part := range content.Parts { @@ -46,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, inf return sensitiveWords, err } -func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) { +func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) { // 计算输入 token 数量 var inputTexts []string for _, content := range req.Contents { @@ -72,8 +85,11 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) + // 检查 Gemini 流式模式 + checkGeminiStreamMode(c, relayInfo) + if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo) + sensitiveWords, err := checkGeminiInputSensitive(req) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest) @@ -97,7 +113,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { c.Set("prompt_tokens", promptTokens) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens) + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) } From 156ad5c3fdc1c547a7b5905a4325bfb8a19cc869 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Mon, 26 May 2025 15:02:20 +0800 Subject: [PATCH 3/3] vertex --- relay/channel/vertex/adaptor.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index d21a3e08..e58ea762 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -12,6 +12,7 @@ import ( "one-api/relay/channel/gemini" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/constant" "one-api/setting/model_setting" "strings" @@ -192,7 +193,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case RequestModeClaude: err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: - err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + if info.RelayMode == constant.RelayModeGemini { + usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info) + } else { + err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + } case RequestModeLlama: 
err, usage = openai.OaiStreamHandler(c, resp, info) } @@ -201,7 +206,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case RequestModeClaude: err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) case RequestModeGemini: - err, usage = gemini.GeminiChatHandler(c, resp, info) + if info.RelayMode == constant.RelayModeGemini { + usage, err = gemini.GeminiTextGenerationHandler(c, resp, info) + } else { + err, usage = gemini.GeminiChatHandler(c, resp, info) + } case RequestModeLlama: err, usage = openai.OpenaiHandler(c, resp, info) }
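
For reference, the path-to-model mapping introduced in middleware/distributor.go can be exercised with a small table-driven test. This is only a sketch (placed alongside the function in the middleware package); the expected values mirror the behaviour of extractModelNameFromGeminiPath as added in patch 1.

package middleware

import "testing"

// Sketch: the model name is the segment between "/models/" and the first ":",
// or the rest of the path when no ":" is present.
func TestExtractModelNameFromGeminiPath(t *testing.T) {
	cases := []struct {
		path string
		want string
	}{
		{"/v1beta/models/gemini-2.0-flash:generateContent", "gemini-2.0-flash"},
		{"/v1beta/models/gemini-2.0-flash:streamGenerateContent", "gemini-2.0-flash"},
		{"/v1beta/models/gemini-2.0-flash", "gemini-2.0-flash"}, // no ":" -> rest of the path
		{"/v1beta/models/", ""},                                 // nothing after the prefix
		{"/v1beta/other", ""},                                   // no "/models/" segment
	}
	for _, tc := range cases {
		if got := extractModelNameFromGeminiPath(tc.path); got != tc.want {
			t.Errorf("extractModelNameFromGeminiPath(%q) = %q, want %q", tc.path, got, tc.want)
		}
	}
}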
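The series wires up the native Gemini surface end to end: the /v1beta router group, key-in-query authentication in TokenAuth, model extraction from the path, and stream detection via alt=sse in checkGeminiStreamMode. Below is a minimal client sketch of that surface under assumptions not in the patch: the base URL and token are placeholders, and the request body only illustrates the contents/parts/text and generationConfig.maxOutputTokens fields that the relay reads for token counting and pre-consumed quota.

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	base := "http://localhost:3000" // placeholder deployment URL
	token := "sk-xxxx"              // placeholder one-api token

	body := []byte(`{"contents":[{"role":"user","parts":[{"text":"Hello"}]}],"generationConfig":{"maxOutputTokens":256}}`)

	// Non-streaming: TokenAuth reads the key from the "key" query parameter
	// on /v1beta/models/ paths and turns it into a Bearer header.
	url := fmt.Sprintf("%s/v1beta/models/gemini-2.0-flash:generateContent?key=%s", base, token)
	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // native Gemini-format JSON written back by GeminiTextGenerationHandler

	// Streaming: checkGeminiStreamMode marks the request as a stream when alt=sse is set.
	streamURL := fmt.Sprintf("%s/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=%s", base, token)
	streamResp, err := http.Post(streamURL, "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer streamResp.Body.Close()
	scanner := bufio.NewScanner(streamResp.Body)
	for scanner.Scan() {
		fmt.Println(scanner.Text()) // "data: {...}" SSE lines carrying GeminiChatResponse chunks
	}
}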