From 738a9a455853163f6e1e705fe1fb4c2ca5aa094f Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Mon, 26 May 2025 13:34:41 +0800 Subject: [PATCH 01/15] gemini text generation --- controller/relay.go | 2 + dto/gemini.go | 69 ++++++++++ middleware/auth.go | 12 +- middleware/distributor.go | 36 +++++ relay/channel/gemini/adaptor.go | 6 + relay/channel/gemini/relay-gemini-native.go | 77 +++++++++++ relay/constant/relay_mode.go | 4 + relay/relay-gemini.go | 141 ++++++++++++++++++++ router/relay-router.go | 8 ++ 9 files changed, 353 insertions(+), 2 deletions(-) create mode 100644 dto/gemini.go create mode 100644 relay/channel/gemini/relay-gemini-native.go create mode 100644 relay/relay-gemini.go diff --git a/controller/relay.go b/controller/relay.go index 41cb22a5..1a875dbc 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode err = relay.EmbeddingHelper(c) case relayconstant.RelayModeResponses: err = relay.ResponsesHelper(c) + case relayconstant.RelayModeGemini: + err = relay.GeminiHelper(c) default: err = relay.TextHelper(c) } diff --git a/dto/gemini.go b/dto/gemini.go new file mode 100644 index 00000000..898c966f --- /dev/null +++ b/dto/gemini.go @@ -0,0 +1,69 @@ +package dto + +import "encoding/json" + +type GeminiPart struct { + Text string `json:"text"` +} + +type GeminiContent struct { + Parts []GeminiPart `json:"parts"` + Role string `json:"role"` +} + +type GeminiCandidate struct { + Content GeminiContent `json:"content"` + FinishReason string `json:"finishReason"` + AvgLogprobs float64 `json:"avgLogprobs"` +} + +type GeminiTokenDetails struct { + Modality string `json:"modality"` + TokenCount int `json:"tokenCount"` +} + +type GeminiUsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + PromptTokensDetails []GeminiTokenDetails `json:"promptTokensDetails"` + CandidatesTokensDetails []GeminiTokenDetails `json:"candidatesTokensDetails"` +} + +type GeminiTextGenerationResponse struct { + Candidates []GeminiCandidate `json:"candidates"` + UsageMetadata GeminiUsageMetadata `json:"usageMetadata"` + ModelVersion string `json:"modelVersion"` + ResponseID string `json:"responseId"` +} + +type GeminiGenerationConfig struct { + StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema *json.RawMessage `json:"responseSchema,omitempty"` + ResponseModalities *json.RawMessage `json:"responseModalities,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` + Seed int `json:"seed,omitempty"` + PresencePenalty float64 `json:"presencePenalty,omitempty"` + FrequencyPenalty float64 `json:"frequencyPenalty,omitempty"` + ResponseLogprobs bool `json:"responseLogprobs,omitempty"` + LogProbs int `json:"logProbs,omitempty"` + EnableEnhancedCivicAnswers bool `json:"enableEnhancedCivicAnswers,omitempty"` + SpeechConfig *json.RawMessage `json:"speechConfig,omitempty"` + ThinkingConfig *json.RawMessage `json:"thinkingConfig,omitempty"` + MediaResolution *json.RawMessage `json:"mediaResolution,omitempty"` +} + +type GeminiTextGenerationRequest struct { + Contents []GeminiContent `json:"contents"` + Tools *json.RawMessage 
`json:"tools,omitempty"` + ToolConfig *json.RawMessage `json:"toolConfig,omitempty"` + SafetySettings *json.RawMessage `json:"safetySettings,omitempty"` + SystemInstruction *json.RawMessage `json:"systemInstruction,omitempty"` + GenerationConfig GeminiGenerationConfig `json:"generationConfig,omitempty"` + CachedContent *json.RawMessage `json:"cachedContent,omitempty"` +} diff --git a/middleware/auth.go b/middleware/auth.go index fece4553..ce86bb36 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,13 +1,14 @@ package middleware import ( - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "strings" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" ) func validUserInfo(username string, role int) bool { @@ -182,6 +183,13 @@ func TokenAuth() func(c *gin.Context) { c.Request.Header.Set("Authorization", "Bearer "+key) } } + // gemini api 从query中获取key + if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { + skKey := c.Query("key") + if skKey != "" { + c.Request.Header.Set("Authorization", "Bearer "+skKey) + } + } key := c.Request.Header.Get("Authorization") parts := make([]string, 0) key = strings.TrimPrefix(key, "Bearer ") diff --git a/middleware/distributor.go b/middleware/distributor.go index e7db6d77..1bfe1821 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -162,6 +162,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) + } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { + // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent + relayMode := relayconstant.RelayModeGemini + modelName := extractModelNameFromGeminiPath(c.Request.URL.Path) + if modelName != "" { + modelRequest.Model = modelName + } + c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { err = common.UnmarshalBodyReusable(c, &modelRequest) } @@ -244,3 +252,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("bot_id", channel.Other) } } + +// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名 +// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent +// 输出: gemini-2.0-flash +func extractModelNameFromGeminiPath(path string) string { + // 查找 "/models/" 的位置 + modelsPrefix := "/models/" + modelsIndex := strings.Index(path, modelsPrefix) + if modelsIndex == -1 { + return "" + } + + // 从 "/models/" 之后开始提取 + startIndex := modelsIndex + len(modelsPrefix) + if startIndex >= len(path) { + return "" + } + + // 查找 ":" 的位置,模型名在 ":" 之前 + colonIndex := strings.Index(path[startIndex:], ":") + if colonIndex == -1 { + // 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分 + return path[startIndex:] + } + + // 返回模型名部分 + return path[startIndex : startIndex+colonIndex] +} diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index c3c7b49d..12833736 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/relay/constant" "one-api/service" "one-api/setting/model_setting" "strings" @@ -165,6 +166,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage 
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index c3c7b49d..12833736 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -10,6 +10,7 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/service"
 	"one-api/setting/model_setting"
 	"strings"
@@ -165,6 +166,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }

 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.RelayMode == constant.RelayModeGemini {
+		err, usage = GeminiTextGenerationHandler(c, resp, info)
+		return usage, err
+	}
+
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
 		return GeminiImageHandler(c, resp, info)
 	}
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
new file mode 100644
index 00000000..16374ea4
--- /dev/null
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -0,0 +1,77 @@
+package gemini
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	// read the response body
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if common.DebugEnabled {
+		println(string(responseBody))
+	}
+
+	// parse as the native Gemini response format
+	var geminiResponse dto.GeminiTextGenerationResponse
+	err = common.DecodeJson(responseBody, &geminiResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	// check whether any candidates were returned
+	if len(geminiResponse.Candidates) == 0 {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: "No candidates returned",
+				Type:    "server_error",
+				Param:   "",
+				Code:    500,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	// compute usage (based on UsageMetadata)
+	usage := dto.Usage{
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
+	}
+
+	// set the model version
+	if geminiResponse.ModelVersion == "" {
+		geminiResponse.ModelVersion = info.UpstreamModelName
+	}
+
+	// return the JSON response in native Gemini format as-is
+	jsonResponse, err := json.Marshal(geminiResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	// set response headers and write the response
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
+	}
+
+	return nil, &usage
+}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 4454e815..f22a20bd 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -43,6 +43,8 @@ const (
 	RelayModeResponses

 	RelayModeRealtime
+
+	RelayModeGemini
 )

 func Path2RelayMode(path string) int {
@@ -75,6 +77,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeRerank
 	} else if strings.HasPrefix(path, "/v1/realtime") {
 		relayMode = RelayModeRealtime
+	} else if strings.HasPrefix(path, "/v1beta/models") {
+		relayMode = RelayModeGemini
 	}
 	return relayMode
 }
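The new `/v1beta/models` branch can likewise be sanity-checked. A small illustrative sketch (test placement alongside relay/constant is an assumption):

    package constant

    import "testing"

    func TestPath2RelayModeGemini(t *testing.T) {
        path := "/v1beta/models/gemini-2.0-flash:generateContent"
        if Path2RelayMode(path) != RelayModeGemini {
            t.Errorf("expected RelayModeGemini for %q", path)
        }
    }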
"one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/setting" + "strings" + + "github.com/gin-gonic/gin" +) + +func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) { + request := &dto.GeminiTextGenerationRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if len(request.Contents) == 0 { + return nil, errors.New("contents is required") + } + return request, nil +} + +func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) { + var inputTexts []string + for _, content := range textRequest.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + if len(inputTexts) == 0 { + return nil, nil + } + + sensitiveWords, err := service.CheckSensitiveInput(inputTexts) + return sensitiveWords, err +} + +func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) { + // 计算输入 token 数量 + var inputTexts []string + for _, content := range req.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + + inputText := strings.Join(inputTexts, "\n") + inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName) + info.PromptTokens = inputTokens + return inputTokens, err +} + +func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { + req, err := getAndValidateGeminiRequest(c) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) + return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) + } + + relayInfo := relaycommon.GenRelayInfo(c) + + if setting.ShouldCheckPromptSensitive() { + sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo) + if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) + return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest) + } + } + + // model mapped 模型映射 + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) + } + + if value, exists := c.Get("prompt_tokens"); exists { + promptTokens := value.(int) + relayInfo.SetPromptTokens(promptTokens) + } else { + promptTokens, err := getGeminiInputTokens(req, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) + } + c.Set("prompt_tokens", promptTokens) + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } + + // pre consume quota + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + + adaptor.Init(relayInfo) + + requestBody, err := json.Marshal(req) + 
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
+	}
+
+	resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
+	if err != nil {
+		common.LogError(c, "Do gemini request failed: "+err.Error())
+		return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	return nil
+}
diff --git a/router/relay-router.go b/router/relay-router.go
index 4cd84b41..1115a491 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -79,6 +79,14 @@ func SetRelayRouter(router *gin.Engine) {
 		relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
 	}
+
+	relayGeminiRouter := router.Group("/v1beta")
+	relayGeminiRouter.Use(middleware.TokenAuth())
+	relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
+	relayGeminiRouter.Use(middleware.Distribute())
+	{
+		// Gemini API path format: /v1beta/models/{model_name}:{action}
+		relayGeminiRouter.POST("/models/*path", controller.Relay)
+	}
 }

 func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
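With the router, auth, and distributor wiring above, this patch is usable end to end. A rough client-side sketch (host, port, model, and key are placeholders):

    package main

    import (
        "bytes"
        "fmt"
        "io"
        "log"
        "net/http"
    )

    func main() {
        body := []byte(`{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}`)
        url := "http://localhost:3000/v1beta/models/gemini-2.0-flash:generateContent?key=sk-xxxx"
        // TokenAuth lifts the "key" query parameter into the Authorization header,
        // and the distributor extracts "gemini-2.0-flash" from the path.
        resp, err := http.Post(url, "application/json", bytes.NewReader(body))
        if err != nil {
            log.Fatal(err)
        }
        defer resp.Body.Close()
        out, _ := io.ReadAll(resp.Body)
        fmt.Println(string(out)) // native Gemini JSON: candidates, usageMetadata, ...
    }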
`json:"presencePenalty,omitempty"` - FrequencyPenalty float64 `json:"frequencyPenalty,omitempty"` - ResponseLogprobs bool `json:"responseLogprobs,omitempty"` - LogProbs int `json:"logProbs,omitempty"` - EnableEnhancedCivicAnswers bool `json:"enableEnhancedCivicAnswers,omitempty"` - SpeechConfig *json.RawMessage `json:"speechConfig,omitempty"` - ThinkingConfig *json.RawMessage `json:"thinkingConfig,omitempty"` - MediaResolution *json.RawMessage `json:"mediaResolution,omitempty"` -} - -type GeminiTextGenerationRequest struct { - Contents []GeminiContent `json:"contents"` - Tools *json.RawMessage `json:"tools,omitempty"` - ToolConfig *json.RawMessage `json:"toolConfig,omitempty"` - SafetySettings *json.RawMessage `json:"safetySettings,omitempty"` - SystemInstruction *json.RawMessage `json:"systemInstruction,omitempty"` - GenerationConfig GeminiGenerationConfig `json:"generationConfig,omitempty"` - CachedContent *json.RawMessage `json:"cachedContent,omitempty"` -} diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 12833736..e6f66d5f 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -167,8 +167,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeGemini { - err, usage = GeminiTextGenerationHandler(c, resp, info) - return usage, err + if info.IsStream { + return GeminiTextGenerationStreamHandler(c, resp, info) + } else { + return GeminiTextGenerationHandler(c, resp, info) + } } if strings.HasPrefix(info.UpstreamModelName, "imagen") { diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 16374ea4..c055e299 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -7,20 +7,21 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "github.com/gin-gonic/gin" ) -func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } if common.DebugEnabled { @@ -28,15 +29,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela } // 解析为 Gemini 原生响应格式 - var geminiResponse dto.GeminiTextGenerationResponse + var geminiResponse GeminiChatResponse err = common.DecodeJson(responseBody, &geminiResponse) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } // 检查是否有候选响应 if 
 	if len(geminiResponse.Candidates) == 0 {
-		return &dto.OpenAIErrorWithStatusCode{
+		return nil, &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
 				Message: "No candidates returned",
 				Type:    "server_error",
 				Param:   "",
 				Code:    500,
 			},
 			StatusCode: resp.StatusCode,
-		}, nil
+		}
 	}

 	// compute usage (based on UsageMetadata)
 	usage := dto.Usage{
 		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
 		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
@@ -54,15 +55,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
 	}

-	// set the model version
-	if geminiResponse.ModelVersion == "" {
-		geminiResponse.ModelVersion = info.UpstreamModelName
-	}
-
 	// return the JSON response in native Gemini format as-is
 	jsonResponse, err := json.Marshal(geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 	}

 	// set response headers and write the response
@@ -70,8 +66,63 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = c.Writer.Write(jsonResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
 	}

-	return nil, &usage
+	return &usage, nil
+}
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	var usage = &dto.Usage{}
+	var imageCount int
+
+	helper.SetEventStreamHeaders(c)
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var geminiResponse GeminiChatResponse
+		err := common.DecodeJsonStr(data, &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			return false
+		}
+
+		// count images
+		for _, candidate := range geminiResponse.Candidates {
+			for _, part := range candidate.Content.Parts {
+				if part.InlineData != nil && part.InlineData.MimeType != "" {
+					imageCount++
+				}
+			}
+		}
+
+		// update usage statistics
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+		}
+
+		// send the GeminiChatResponse chunk as-is
+		err = helper.ObjectData(c, geminiResponse)
+		if err != nil {
+			common.LogError(c, err.Error())
+		}
+
+		return true
+	})
+
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
+	// compute final usage
+	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+
+	// end the streaming response
+	helper.Done(c)
+
+	return usage, nil
+}
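Each scanner callback above forwards one GeminiChatResponse chunk as a server-sent event. A rough sketch of consuming the stream from a client, continuing the earlier example (it assumes the conventional "data: {...}" SSE framing; URL and key are placeholders):

    resp, err := http.Post(
        "http://localhost:3000/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=sk-xxxx",
        "application/json", bytes.NewReader(body))
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()
    scanner := bufio.NewScanner(resp.Body)
    for scanner.Scan() {
        // every chunk is one native Gemini response object on a "data:" line
        if line := scanner.Text(); strings.HasPrefix(line, "data: ") {
            fmt.Println(strings.TrimPrefix(line, "data: "))
        }
    }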
&gemini.GeminiChatRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { return nil, err @@ -29,7 +30,19 @@ func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationReque return request, nil } -func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) { +// 流模式 +// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx +func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) { + if c.Query("alt") == "sse" { + relayInfo.IsStream = true + } + + // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") { + // relayInfo.IsStream = true + // } +} + +func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) { var inputTexts []string for _, content := range textRequest.Contents { for _, part := range content.Parts { @@ -46,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, inf return sensitiveWords, err } -func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) { +func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) { // 计算输入 token 数量 var inputTexts []string for _, content := range req.Contents { @@ -72,8 +85,11 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) + // 检查 Gemini 流式模式 + checkGeminiStreamMode(c, relayInfo) + if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo) + sensitiveWords, err := checkGeminiInputSensitive(req) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest) @@ -97,7 +113,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { c.Set("prompt_tokens", promptTokens) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens) + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) } From 156ad5c3fdc1c547a7b5905a4325bfb8a19cc869 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Mon, 26 May 2025 15:02:20 +0800 Subject: [PATCH 03/15] vertex --- relay/channel/vertex/adaptor.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index d21a3e08..e58ea762 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -12,6 +12,7 @@ import ( "one-api/relay/channel/gemini" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/constant" "one-api/setting/model_setting" "strings" @@ -192,7 +193,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case RequestModeClaude: err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: - err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + if info.RelayMode == constant.RelayModeGemini { + usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info) + } else { + err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + } case RequestModeLlama: 
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 		}
@@ -201,7 +206,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case RequestModeClaude:
 			err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
 		case RequestModeGemini:
-			err, usage = gemini.GeminiChatHandler(c, resp, info)
+			if info.RelayMode == constant.RelayModeGemini {
+				usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
+			} else {
+				err, usage = gemini.GeminiChatHandler(c, resp, info)
+			}
 		case RequestModeLlama:
 			err, usage = openai.OpenaiHandler(c, resp, info)
 		}

From d608a6f12398f2a52617951685c6990877740def Mon Sep 17 00:00:00 2001
From: Akkuman
Date: Thu, 29 May 2025 10:56:01 +0800
Subject: [PATCH 04/15] feat: streaming response for tts

---
 relay/channel/openai/relay-openai.go | 37 ++++++++++------------------
 1 file changed, 13 insertions(+), 24 deletions(-)

diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 86c47a15..2e3d8df1 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -273,36 +273,25 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 }

 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	// Reset response body
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the httpClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
+	// The status code has already been checked at this point, so a failure while
+	// reading the body is a non-recoverable error and must not return err for an
+	// external retry. As with nginx load balancing, a retry only makes sense when
+	// the request cannot be sent at all or the upstream returns a specific status
+	// code; once the upstream has written the header, any subsequent failure in
+	// the response body is non-recoverable and the transfer can simply be terminated.
+ defer resp.Body.Close() + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.TotalTokens = info.PromptTokens for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) + c.Writer.WriteHeaderNow() + _, err := io.Copy(c.Writer, resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + common.LogError(c, err.Error()) } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - - usage := &dto.Usage{} - usage.PromptTokens = info.PromptTokens - usage.TotalTokens = info.PromptTokens return nil, usage } From 1b64db55215bf3fb6e10d69b7da30126ed9d1f5a Mon Sep 17 00:00:00 2001 From: RedwindA <128586631+RedwindA@users.noreply.github.com> Date: Thu, 29 May 2025 12:33:27 +0800 Subject: [PATCH 05/15] Add `ERROR_LOG_ENABLED` description --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a807b07d..5d0014f9 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do - `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview` - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2` +- `ERROR_LOG_ENABLED=true`: 是否记录并显示错误日志,默认`false` ## 部署 From f907c25b21137e8d7a94caa9a8450913e980b941 Mon Sep 17 00:00:00 2001 From: RedwindA <128586631+RedwindA@users.noreply.github.com> Date: Thu, 29 May 2025 12:35:13 +0800 Subject: [PATCH 06/15] Add `ERROR_LOG_ENABLED` description --- README.en.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.en.md b/README.en.md index 4709bc5b..ad11f386 100644 --- a/README.en.md +++ b/README.en.md @@ -110,6 +110,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview` - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2` +- `ERROR_LOG_ENABLED=true`: Whether to record and display error logs, default is `false` ## Deployment From 1c4d7fd84b55519235cd88e48cf14cd383275281 Mon Sep 17 00:00:00 2001 From: xqx121 <78908927+xqx121@users.noreply.github.com> Date: Sat, 31 May 2025 17:50:00 +0800 Subject: [PATCH 07/15] Fix: Gemini2.5pro ThinkingConfig --- relay/channel/gemini/relay-gemini.go | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index da0bc5fc..9ab167b1 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -39,15 +39,22 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - if strings.HasSuffix(info.OriginModelName, "-thinking") { - budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) - if budgetTokens == 0 || budgetTokens > 24576 { - budgetTokens = 24576 - } - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ - ThinkingBudget: common.GetPointer(int(budgetTokens)), - IncludeThoughts: true, - } + if strings.HasSuffix(info.OriginModelName, 
"-thinking") { + // 如果模型名以 gemini-2.5-pro 开头,不设置 ThinkingBudget + if strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") { + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + IncludeThoughts: true, + } + } else { + budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) + if budgetTokens == 0 || budgetTokens > 24576 { + budgetTokens = 24576 + } + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + ThinkingBudget: common.GetPointer(int(budgetTokens)), + IncludeThoughts: true, + } + } } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ ThinkingBudget: common.GetPointer(0), From c51a30b862525aa4af9bfdd510cdd59ba301b9b5 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Sat, 31 May 2025 22:13:17 +0800 Subject: [PATCH 08/15] =?UTF-8?q?fix:=20=E6=B5=81=E5=BC=8F=E8=AF=B7?= =?UTF-8?q?=E6=B1=82ping?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/api_request.go | 114 ++++++++++++++++++++--------------- 1 file changed, 66 insertions(+), 48 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index da8d4e14..1d733bd4 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -104,6 +104,65 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody return targetConn, nil } +func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc { + pingerCtx, stopPinger := context.WithCancel(context.Background()) + + gopool.Go(func() { + defer func() { + if common2.DebugEnabled { + println("SSE ping goroutine stopped.") + } + }() + + if pingInterval <= 0 { + pingInterval = helper.DefaultPingInterval + } + + ticker := time.NewTicker(pingInterval) + // 退出时清理 ticker + defer ticker.Stop() + + var pingMutex sync.Mutex + if common2.DebugEnabled { + println("SSE ping goroutine started") + } + + for { + select { + // 发送 ping 数据 + case <-ticker.C: + if err := sendPingData(c, &pingMutex); err != nil { + return + } + // 收到退出信号 + case <-pingerCtx.Done(): + return + // request 结束 + case <-c.Request.Context().Done(): + return + } + } + }) + + return stopPinger +} + +func sendPingData(c *gin.Context, mutex *sync.Mutex) error { + mutex.Lock() + defer mutex.Unlock() + + err := helper.PingData(c) + if err != nil { + common2.LogError(c, "SSE ping error: "+err.Error()) + return err + } + + if common2.DebugEnabled { + println("SSE ping data sent.") + } + return nil +} + func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { var client *http.Client var err error @@ -115,69 +174,28 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http } else { client = service.GetHttpClient() } - // 流式请求 ping 保活 - var stopPinger func() - generalSettings := operation_setting.GetGeneralSetting() - pingEnabled := generalSettings.PingIntervalEnabled - var pingerWg sync.WaitGroup + if info.IsStream { helper.SetEventStreamHeaders(c) - if pingEnabled { + // 处理流式请求的 ping 保活 + generalSettings := operation_setting.GetGeneralSetting() + if generalSettings.PingIntervalEnabled { pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second - var pingerCtx context.Context - pingerCtx, stopPinger = context.WithCancel(c.Request.Context()) - // 退出时清理 pingerCtx 防止泄露 + stopPinger := 
From c51a30b862525aa4af9bfdd510cdd59ba301b9b5 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Sat, 31 May 2025 22:13:17 +0800
Subject: [PATCH 08/15] fix: streaming request ping
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 relay/channel/api_request.go | 114 ++++++++++++++++++++---------------
 1 file changed, 66 insertions(+), 48 deletions(-)

diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index da8d4e14..1d733bd4 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -104,6 +104,65 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	return targetConn, nil
 }

+func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
+	pingerCtx, stopPinger := context.WithCancel(context.Background())
+
+	gopool.Go(func() {
+		defer func() {
+			if common2.DebugEnabled {
+				println("SSE ping goroutine stopped.")
+			}
+		}()
+
+		if pingInterval <= 0 {
+			pingInterval = helper.DefaultPingInterval
+		}
+
+		ticker := time.NewTicker(pingInterval)
+		// clean up the ticker on exit
+		defer ticker.Stop()
+
+		var pingMutex sync.Mutex
+		if common2.DebugEnabled {
+			println("SSE ping goroutine started")
+		}
+
+		for {
+			select {
+			// send ping data
+			case <-ticker.C:
+				if err := sendPingData(c, &pingMutex); err != nil {
+					return
+				}
+			// stop signal received
+			case <-pingerCtx.Done():
+				return
+			// request finished
+			case <-c.Request.Context().Done():
+				return
+			}
+		}
+	})
+
+	return stopPinger
+}
+
+func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
+	mutex.Lock()
+	defer mutex.Unlock()
+
+	err := helper.PingData(c)
+	if err != nil {
+		common2.LogError(c, "SSE ping error: "+err.Error())
+		return err
+	}
+
+	if common2.DebugEnabled {
+		println("SSE ping data sent.")
+	}
+	return nil
+}
+
 func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error
@@ -115,69 +174,28 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	} else {
 		client = service.GetHttpClient()
 	}
-	// ping keep-alive for streaming requests
-	var stopPinger func()
-	generalSettings := operation_setting.GetGeneralSetting()
-	pingEnabled := generalSettings.PingIntervalEnabled
-	var pingerWg sync.WaitGroup
+
 	if info.IsStream {
 		helper.SetEventStreamHeaders(c)
-		if pingEnabled {
+		// handle ping keep-alive for streaming requests
+		generalSettings := operation_setting.GetGeneralSetting()
+		if generalSettings.PingIntervalEnabled {
 			pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
-			var pingerCtx context.Context
-			pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
-			// clean up pingerCtx on exit to prevent leaks
+			stopPinger := startPingKeepAlive(c, pingInterval)
 			defer stopPinger()
-			pingerWg.Add(1)
-			gopool.Go(func() {
-				defer pingerWg.Done()
-				if pingInterval <= 0 {
-					pingInterval = helper.DefaultPingInterval
-				}
-
-				ticker := time.NewTicker(pingInterval)
-				defer ticker.Stop()
-				var pingMutex sync.Mutex
-				if common2.DebugEnabled {
-					println("SSE ping goroutine started")
-				}
-
-				for {
-					select {
-					case <-ticker.C:
-						pingMutex.Lock()
-						err2 := helper.PingData(c)
-						pingMutex.Unlock()
-						if err2 != nil {
-							common2.LogError(c, "SSE ping error: "+err.Error())
-							return
-						}
-						if common2.DebugEnabled {
-							println("SSE ping data sent.")
-						}
-					case <-pingerCtx.Done():
-						if common2.DebugEnabled {
-							println("SSE ping goroutine stopped.")
-						}
-						return
-					}
-				}
-			})
 		}
 	}

 	resp, err := client.Do(req)
-	// wait for the ping goroutine to finish after the request ends
-	if info.IsStream && pingEnabled {
-		pingerWg.Wait()
-	}
+
 	if err != nil {
 		return nil, err
 	}
 	if resp == nil {
 		return nil, errors.New("resp is nil")
 	}
+
 	_ = req.Body.Close()
 	_ = c.Request.Body.Close()
 	return resp, nil

From 611d77e1a9f94a5ceacf8380d4f3513dac0fcaaf Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sun, 1 Jun 2025 01:10:10 +0800
Subject: [PATCH 09/15] feat: add ToMap method and enhance OpenAI request
 handling

---
 dto/openai_request.go             | 12 ++++++++++--
 relay/channel/baidu_v2/adaptor.go | 13 +++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/dto/openai_request.go b/dto/openai_request.go
index bda1bb17..16cdf3a2 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -2,6 +2,7 @@ package dto

 import (
 	"encoding/json"
+	"one-api/common"
 	"strings"
 )

@@ -57,6 +58,13 @@ type GeneralOpenAIRequest struct {
 	WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
 }

+func (r *GeneralOpenAIRequest) ToMap() map[string]any {
+	result := make(map[string]any)
+	data, _ := common.EncodeJson(r)
+	_ = common.DecodeJson(data, &result)
+	return result
+}
+
 type ToolCallRequest struct {
 	ID   string `json:"id,omitempty"`
 	Type string `json:"type"`
@@ -74,11 +82,11 @@ type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
 }

-func (r GeneralOpenAIRequest) GetMaxTokens() int {
+func (r *GeneralOpenAIRequest) GetMaxTokens() int {
 	return int(r.MaxTokens)
 }

-func (r GeneralOpenAIRequest) ParseInput() []string {
+func (r *GeneralOpenAIRequest) ParseInput() []string {
 	if r.Input == nil {
 		return nil
 	}
diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go
index 77afe2dd..2b8a52a2 100644
--- a/relay/channel/baidu_v2/adaptor.go
+++ b/relay/channel/baidu_v2/adaptor.go
@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	"strings"

 	"github.com/gin-gonic/gin"
 )

@@ -49,6 +50,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	if strings.HasSuffix(info.UpstreamModelName, "-search") {
+		info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
+		request.Model = info.UpstreamModelName
+		toMap := request.ToMap()
+		toMap["web_search"] = map[string]any{
+			"enable":          true,
+			"enable_citation": true,
+			"enable_trace":    true,
+			"enable_status":   false,
+		}
+		return toMap, nil
+	}
 	return request, nil
 }

From f1ee9a301d04018861f38c21f9922dcb8e4eaefb Mon Sep 17 00:00:00 2001
From: RedwindA
Date: Fri, 23 May 2025 20:02:50 +0800
Subject: [PATCH 10/15] refactor: enhance cleanFunctionParameters for improved
 handling of JSON schema, including support
for $defs and conditional keywords --- relay/channel/gemini/relay-gemini.go | 169 +++++++++++++++------------ 1 file changed, 93 insertions(+), 76 deletions(-) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 9ab167b1..c75745ad 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -297,94 +297,111 @@ func cleanFunctionParameters(params interface{}) interface{} { return nil } - paramMap, ok := params.(map[string]interface{}) - if !ok { - // Not a map, return as is (e.g., could be an array or primitive) - return params - } + switch v := params.(type) { + case map[string]interface{}: + // Create a copy to avoid modifying the original + cleanedMap := make(map[string]interface{}) + for k, val := range v { + cleanedMap[k] = val + } - // Create a copy to avoid modifying the original - cleanedMap := make(map[string]interface{}) - for k, v := range paramMap { - cleanedMap[k] = v - } + // Remove unsupported root-level fields + delete(cleanedMap, "default") + delete(cleanedMap, "exclusiveMaximum") + delete(cleanedMap, "exclusiveMinimum") + delete(cleanedMap, "$schema") + delete(cleanedMap, "additionalProperties") - // Remove unsupported root-level fields - delete(cleanedMap, "default") - delete(cleanedMap, "exclusiveMaximum") - delete(cleanedMap, "exclusiveMinimum") - delete(cleanedMap, "$schema") - delete(cleanedMap, "additionalProperties") - - // Clean properties - if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil { - cleanedProps := make(map[string]interface{}) - for propName, propValue := range props { - propMap, ok := propValue.(map[string]interface{}) - if !ok { - cleanedProps[propName] = propValue // Keep non-map properties - continue - } - - // Create a copy of the property map - cleanedPropMap := make(map[string]interface{}) - for k, v := range propMap { - cleanedPropMap[k] = v - } - - // Remove unsupported fields - delete(cleanedPropMap, "default") - delete(cleanedPropMap, "exclusiveMaximum") - delete(cleanedPropMap, "exclusiveMinimum") - delete(cleanedPropMap, "$schema") - delete(cleanedPropMap, "additionalProperties") - - // Check and clean 'format' for string types - if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" { - if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists { - if formatValue != "enum" && formatValue != "date-time" { - delete(cleanedPropMap, "format") - } + // Check and clean 'format' for string types + if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" { + if formatValue, formatExists := cleanedMap["format"].(string); formatExists { + if formatValue != "enum" && formatValue != "date-time" { + delete(cleanedMap, "format") } } + } - // Recursively clean nested properties within this property if it's an object/array - // Check the type before recursing - if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") { - cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap) - } else { - cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing + // Clean properties + if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil { + cleanedProps := make(map[string]interface{}) + for propName, propValue := range props { + cleanedProps[propName] = cleanFunctionParameters(propValue) } - + cleanedMap["properties"] = cleanedProps 
} - cleanedMap["properties"] = cleanedProps - } - // Recursively clean items in arrays if needed (e.g., type: array, items: { ... }) - if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil { - cleanedMap["items"] = cleanFunctionParameters(items) - } - // Also handle items if it's an array of schemas - if itemsArray, ok := cleanedMap["items"].([]interface{}); ok { - cleanedItemsArray := make([]interface{}, len(itemsArray)) - for i, item := range itemsArray { - cleanedItemsArray[i] = cleanFunctionParameters(item) + // Recursively clean items in arrays + if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil { + cleanedMap["items"] = cleanFunctionParameters(items) } - cleanedMap["items"] = cleanedItemsArray - } - - // Recursively clean other schema composition keywords if necessary - for _, field := range []string{"allOf", "anyOf", "oneOf"} { - if nested, ok := cleanedMap[field].([]interface{}); ok { - cleanedNested := make([]interface{}, len(nested)) - for i, item := range nested { - cleanedNested[i] = cleanFunctionParameters(item) + // Also handle items if it's an array of schemas + if itemsArray, ok := cleanedMap["items"].([]interface{}); ok { + cleanedItemsArray := make([]interface{}, len(itemsArray)) + for i, item := range itemsArray { + cleanedItemsArray[i] = cleanFunctionParameters(item) } - cleanedMap[field] = cleanedNested + cleanedMap["items"] = cleanedItemsArray } - } - return cleanedMap + // Recursively clean other schema composition keywords + for _, field := range []string{"allOf", "anyOf", "oneOf"} { + if nested, ok := cleanedMap[field].([]interface{}); ok { + cleanedNested := make([]interface{}, len(nested)) + for i, item := range nested { + cleanedNested[i] = cleanFunctionParameters(item) + } + cleanedMap[field] = cleanedNested + } + } + + // Recursively clean patternProperties + if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok { + cleanedPatternProps := make(map[string]interface{}) + for pattern, schema := range patternProps { + cleanedPatternProps[pattern] = cleanFunctionParameters(schema) + } + cleanedMap["patternProperties"] = cleanedPatternProps + } + + // Recursively clean definitions + if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok { + cleanedDefinitions := make(map[string]interface{}) + for defName, defSchema := range definitions { + cleanedDefinitions[defName] = cleanFunctionParameters(defSchema) + } + cleanedMap["definitions"] = cleanedDefinitions + } + + // Recursively clean $defs (newer JSON Schema draft) + if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok { + cleanedDefs := make(map[string]interface{}) + for defName, defSchema := range defs { + cleanedDefs[defName] = cleanFunctionParameters(defSchema) + } + cleanedMap["$defs"] = cleanedDefs + } + + // Clean conditional keywords + for _, field := range []string{"if", "then", "else", "not"} { + if nested, ok := cleanedMap[field]; ok { + cleanedMap[field] = cleanFunctionParameters(nested) + } + } + + return cleanedMap + + case []interface{}: + // Handle arrays of schemas + cleanedArray := make([]interface{}, len(v)) + for i, item := range v { + cleanedArray[i] = cleanFunctionParameters(item) + } + return cleanedArray + + default: + // Not a map or array, return as is (e.g., could be a primitive) + return params + } } func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} { From 148c9749123d6c35924ab012d89565c84c856f86 Mon Sep 17 00:00:00 2001 From: RedwindA 
Date: Mon, 2 Jun 2025 19:00:55 +0800
Subject: [PATCH 11/15] feat: add validation for Gemini MIME types
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 relay/channel/gemini/relay-gemini.go | 39 +++++++++++++++++++++++++---
 1 file changed, 36 insertions(+), 3 deletions(-)

diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 9ab167b1..5dff8ab6 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -18,6 +18,24 @@ import (
 	"github.com/gin-gonic/gin"
 )

+var geminiSupportedMimeTypes = map[string]bool{
+	"application/pdf": true,
+	"audio/mpeg":      true,
+	"audio/mp3":       true,
+	"audio/wav":       true,
+	"image/png":       true,
+	"image/jpeg":      true,
+	"text/plain":      true,
+	"video/mov":       true,
+	"video/mpeg":      true,
+	"video/mp4":       true,
+	"video/mpg":       true,
+	"video/avi":       true,
+	"video/wmv":       true,
+	"video/mpegps":    true,
+	"video/flv":       true,
+}
+
 // Setting safety to the lowest possible values since Gemini is already powerless enough

 func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
@@ -215,14 +233,20 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 			}
 			// check whether it is a URL
 			if strings.HasPrefix(part.GetImageMedia().Url, "http") {
-				// it is a URL; fetch the image's type and base64-encoded data
+				// it is a URL; fetch the file's type and base64-encoded data
 				fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
 				if err != nil {
-					return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
+					return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
 				}
+
+				// verify that the MIME type is on Gemini's supported whitelist
+				if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
+					return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
+				}
+
 				parts = append(parts, GeminiPart{
 					InlineData: &GeminiInlineData{
-						MimeType: fileData.MimeType,
+						MimeType: fileData.MimeType, // keep the original MimeType, since casing may matter to the API
 						Data:     fileData.Base64Data,
 					},
 				})
@@ -291,6 +315,15 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 	return &geminiRequest, nil
 }

+// Helper function to get a list of supported MIME types for error messages
+func getSupportedMimeTypesList() []string {
+	keys := make([]string, 0, len(geminiSupportedMimeTypes))
+	for k := range geminiSupportedMimeTypes {
+		keys = append(keys, k)
+	}
+	return keys
+}
+
 // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
 func cleanFunctionParameters(params interface{}) interface{} {
 	if params == nil {
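The whitelist above is matched case-insensitively on the fetched MIME type. A small illustration (values chosen to show both outcomes):

    for _, mt := range []string{"image/JPEG", "image/webp"} {
        if geminiSupportedMimeTypes[strings.ToLower(mt)] {
            fmt.Println(mt, "accepted") // image/JPEG matches via ToLower
        } else {
            fmt.Println(mt, "rejected") // image/webp is not on the list above
        }
    }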
imageRequest.Quality = "standard" } } + if imageRequest.N == 0 { + imageRequest.N = 1 + } default: err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { From 0af047b18c8ecbe500c87cd6f4fcda89947ab93a Mon Sep 17 00:00:00 2001 From: RedwindA Date: Thu, 5 Jun 2025 02:09:21 +0800 Subject: [PATCH 13/15] Add DeepWiki Badge in README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 5d0014f9..e9d1c154 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,9 @@ 详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/) +也可访问AI生成的DeepWiki: +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api) + ## ✨ 主要特性 New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction): From a8f4ae2a734310131d40459cd31d168cd5952204 Mon Sep 17 00:00:00 2001 From: "Apple\\Apple" Date: Thu, 5 Jun 2025 11:27:00 +0800 Subject: [PATCH 14/15] =?UTF-8?q?=F0=9F=93=95docs:=20Add=20DeepWiki=20Badg?= =?UTF-8?q?e=20in=20`README.en.md`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.en.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.en.md b/README.en.md index ad11f386..10a3cdb0 100644 --- a/README.en.md +++ b/README.en.md @@ -44,6 +44,9 @@ For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/) +You can also access the AI-generated DeepWiki: +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api) + ## ✨ Key Features New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details: From 3665ad672ef43b86b234aa4dc7a9c052e00f3dde Mon Sep 17 00:00:00 2001 From: neotf Date: Thu, 5 Jun 2025 17:35:48 +0800 Subject: [PATCH 15/15] feat: support claude cache and thinking for upstream [OpenRouter] (#983) * feat: support claude cache for upstream [OpenRouter] * feat: support claude thinking for upstream [OpenRouter] * feat: reasoning is common params for OpenRouter --- dto/claude.go | 23 +++++++-------- dto/openai_request.go | 5 +++- relay/channel/openrouter/dto.go | 9 ++++++ service/convert.go | 50 ++++++++++++++++++++++++++------- 4 files changed, 65 insertions(+), 22 deletions(-) create mode 100644 relay/channel/openrouter/dto.go diff --git a/dto/claude.go b/dto/claude.go index 8068feb8..36dfc02e 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -7,17 +7,18 @@ type ClaudeMetadata struct { } type ClaudeMediaMessage struct { - Type string `json:"type,omitempty"` - Text *string `json:"text,omitempty"` - Model string `json:"model,omitempty"` - Source *ClaudeMessageSource `json:"source,omitempty"` - Usage *ClaudeUsage `json:"usage,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` - PartialJson *string `json:"partial_json,omitempty"` - Role string `json:"role,omitempty"` - Thinking string `json:"thinking,omitempty"` - Signature string `json:"signature,omitempty"` - Delta string `json:"delta,omitempty"` + Type string `json:"type,omitempty"` + Text *string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Source *ClaudeMessageSource `json:"source,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + PartialJson *string `json:"partial_json,omitempty"` + Role string `json:"role,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` + Delta string 
`json:"delta,omitempty"` + CacheControl json.RawMessage `json:"cache_control,omitempty"` // tool_calls Id string `json:"id,omitempty"` Name string `json:"name,omitempty"` diff --git a/dto/openai_request.go b/dto/openai_request.go index 16cdf3a2..a7325fe8 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -29,7 +29,6 @@ type GeneralOpenAIRequest struct { MaxTokens uint `json:"max_tokens,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"` - //Reasoning json.RawMessage `json:"reasoning,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` @@ -56,6 +55,8 @@ type GeneralOpenAIRequest struct { EnableThinking any `json:"enable_thinking,omitempty"` // ali ExtraBody any `json:"extra_body,omitempty"` WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` + // OpenRouter Params + Reasoning json.RawMessage `json:"reasoning,omitempty"` } func (r *GeneralOpenAIRequest) ToMap() map[string]any { @@ -125,6 +126,8 @@ type MediaContent struct { InputAudio any `json:"input_audio,omitempty"` File any `json:"file,omitempty"` VideoUrl any `json:"video_url,omitempty"` + // OpenRouter Params + CacheControl json.RawMessage `json:"cache_control,omitempty"` } func (m *MediaContent) GetImageMedia() *MessageImageUrl { diff --git a/relay/channel/openrouter/dto.go b/relay/channel/openrouter/dto.go new file mode 100644 index 00000000..607f495b --- /dev/null +++ b/relay/channel/openrouter/dto.go @@ -0,0 +1,9 @@ +package openrouter + +type RequestReasoning struct { + // One of the following (not both): + Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style) + MaxTokens int `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style) + // Optional: Default is false. All models support this. 
+ Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response +} diff --git a/service/convert.go b/service/convert.go index cc462b40..cb964a46 100644 --- a/service/convert.go +++ b/service/convert.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/dto" + "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "strings" ) @@ -18,10 +19,24 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re Stream: claudeRequest.Stream, } + isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter + if claudeRequest.Thinking != nil { - if strings.HasSuffix(info.OriginModelName, "-thinking") && - !strings.HasSuffix(claudeRequest.Model, "-thinking") { - openAIRequest.Model = openAIRequest.Model + "-thinking" + if isOpenRouter { + reasoning := openrouter.RequestReasoning{ + MaxTokens: claudeRequest.Thinking.BudgetTokens, + } + reasoningJSON, err := json.Marshal(reasoning) + if err != nil { + return nil, fmt.Errorf("failed to marshal reasoning: %w", err) + } + openAIRequest.Reasoning = reasoningJSON + } else { + thinkingSuffix := "-thinking" + if strings.HasSuffix(info.OriginModelName, thinkingSuffix) && + !strings.HasSuffix(openAIRequest.Model, thinkingSuffix) { + openAIRequest.Model = openAIRequest.Model + thinkingSuffix + } } } @@ -62,16 +77,30 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re } else { systems := claudeRequest.ParseSystem() if len(systems) > 0 { - systemStr := "" openAIMessage := dto.Message{ Role: "system", } - for _, system := range systems { - if system.Text != nil { - systemStr += *system.Text + isOpenRouterClaude := isOpenRouter && strings.HasPrefix(info.UpstreamModelName, "anthropic/claude") + if isOpenRouterClaude { + systemMediaMessages := make([]dto.MediaContent, 0, len(systems)) + for _, system := range systems { + message := dto.MediaContent{ + Type: "text", + Text: system.GetText(), + CacheControl: system.CacheControl, + } + systemMediaMessages = append(systemMediaMessages, message) } + openAIMessage.SetMediaContent(systemMediaMessages) + } else { + systemStr := "" + for _, system := range systems { + if system.Text != nil { + systemStr += *system.Text + } + } + openAIMessage.SetStringContent(systemStr) } - openAIMessage.SetStringContent(systemStr) openAIMessages = append(openAIMessages, openAIMessage) } } @@ -97,8 +126,9 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re switch mediaMsg.Type { case "text": message := dto.MediaContent{ - Type: "text", - Text: mediaMsg.GetText(), + Type: "text", + Text: mediaMsg.GetText(), + CacheControl: mediaMsg.CacheControl, } mediaMessages = append(mediaMessages, message) case "image":