From efc9d200b1e21f5bc156560b1c92182f265e6393 Mon Sep 17 00:00:00 2001
From: tbphp
Date: Tue, 29 Apr 2025 13:30:03 +0800
Subject: [PATCH 001/105] feat: support thinking suffix for vertex gemini
 channel

---
 relay/channel/vertex/adaptor.go | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index 77f29620..75d86677 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -12,6 +12,7 @@ import (
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/openai"
+	"one-api/setting/model_setting"
 	relaycommon "one-api/relay/common"
 	"strings"
 )
@@ -77,6 +78,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	a.AccountCredentials = *adc
 	suffix := ""
 	if a.RequestMode == RequestModeGemini {
+		if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
+			// suffix -thinking and -nothinking
+			if strings.HasSuffix(info.OriginModelName, "-thinking") {
+				info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
+			} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+				info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
+			}
+		}
+
 		if info.IsStream {
 			suffix = "streamGenerateContent?alt=sse"
 		} else {

From 425feb88d80be216963933b058c23dbed9cebed3 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Fri, 2 May 2025 13:59:46 +0800
Subject: [PATCH 002/105] feat: support /v1/responses API

---
 controller/relay.go                  |   7 +-
 dto/openai_request.go                |  46 +++++++
 dto/openai_response.go               |  45 +++++++
 relay/channel/adapter.go             |   4 +-
 relay/channel/ali/adaptor.go         |   8 +-
 relay/channel/aws/adaptor.go         |   8 +-
 relay/channel/baidu/adaptor.go       |   8 +-
 relay/channel/baidu_v2/adaptor.go    |   8 +-
 relay/channel/claude/adaptor.go      |   8 +-
 relay/channel/cloudflare/adaptor.go  |   5 +
 relay/channel/cohere/adaptor.go      |   8 +-
 relay/channel/deepseek/adaptor.go    |   8 +-
 relay/channel/dify/adaptor.go        |   8 +-
 relay/channel/gemini/adaptor.go      |   5 +
 relay/channel/jina/adaptor.go        |   8 +-
 relay/channel/mistral/adaptor.go     |   8 +-
 relay/channel/mokaai/adaptor.go      |   8 +-
 relay/channel/ollama/adaptor.go      |   8 +-
 relay/channel/openai/adaptor.go      |  27 ++++-
 relay/channel/openai/relay-openai.go |  50 ++++++++
 relay/channel/palm/adaptor.go        |   8 +-
 relay/channel/perplexity/adaptor.go  |   8 +-
 relay/channel/siliconflow/adaptor.go |   8 +-
 relay/channel/tencent/adaptor.go     |   8 +-
 relay/channel/vertex/adaptor.go      |   8 +-
 relay/channel/volcengine/adaptor.go  |   8 +-
 relay/channel/xai/adaptor.go         |   8 +-
 relay/channel/xunfei/adaptor.go      |   8 +-
 relay/channel/zhipu/adaptor.go       |   8 +-
 relay/channel/zhipu_4v/adaptor.go    |   8 +-
 relay/common/relay_info.go           |   4 +
 relay/constant/relay_mode.go         |   4 +
 relay/relay-responses.go             | 171 +++++++++++++++++++++++++++
 router/relay-router.go               |   4 +-
 34 files changed, 521 insertions(+), 27 deletions(-)
 create mode 100644 relay/relay-responses.go

diff --git a/controller/relay.go b/controller/relay.go
index 91477665..41cb22a5 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -4,8 +4,6 @@ import (
 	"bytes"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
-	"github.com/gorilla/websocket"
 	"io"
 	"log"
 	"net/http"
@@ -20,6 +18,9 @@ import (
 	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
+
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 )

 func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
@@ -37,6 +38,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 		err = relay.RerankHelper(c, relayMode)
 	case relayconstant.RelayModeEmbeddings:
 		err = relay.EmbeddingHelper(c)
+	case relayconstant.RelayModeResponses:
+		err = relay.ResponsesHelper(c)
 	default:
 		err = relay.TextHelper(c)
 	}
diff --git a/dto/openai_request.go b/dto/openai_request.go
index 652d8cce..f8804ca5 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -355,3 +355,49 @@ func (m *Message) ParseContent() []MediaContent {
 	}
 	return contentList
 }
+
+type OpenAIResponsesRequest struct {
+	Model              string               `json:"model"`
+	Input              json.RawMessage      `json:"input,omitempty"`
+	Include            json.RawMessage      `json:"include,omitempty"`
+	Instructions       json.RawMessage      `json:"instructions,omitempty"`
+	MaxOutputTokens    uint                 `json:"max_output_tokens,omitempty"`
+	Metadata           json.RawMessage      `json:"metadata,omitempty"`
+	ParallelToolCalls  bool                 `json:"parallel_tool_calls,omitempty"`
+	PreviousResponseID string               `json:"previous_response_id,omitempty"`
+	Reasoning          *Reasoning           `json:"reasoning,omitempty"`
+	ServiceTier        string               `json:"service_tier,omitempty"`
+	Store              bool                 `json:"store,omitempty"`
+	Stream             bool                 `json:"stream,omitempty"`
+	Temperature        float64              `json:"temperature,omitempty"`
+	Text               json.RawMessage      `json:"text,omitempty"`
+	ToolChoice         json.RawMessage      `json:"tool_choice,omitempty"`
+	Tools              []ResponsesToolsCall `json:"tools,omitempty"`
+	TopP               float64              `json:"top_p,omitempty"`
+	Truncation         string               `json:"truncation,omitempty"`
+	User               string               `json:"user,omitempty"`
+}
+
+type Reasoning struct {
+	Effort  string `json:"effort,omitempty"`
+	Summary string `json:"summary,omitempty"`
+}
+
+type ResponsesToolsCall struct {
+	Type string `json:"type"`
+	// Web Search
+	UserLocation      json.RawMessage `json:"user_location,omitempty"`
+	SearchContextSize string          `json:"search_context_size,omitempty"`
+	// File Search
+	VectorStoreIds []string        `json:"vector_store_ids,omitempty"`
+	MaxNumResults  uint            `json:"max_num_results,omitempty"`
+	Filters        json.RawMessage `json:"filters,omitempty"`
+	// Computer Use
+	DisplayWidth  uint   `json:"display_width,omitempty"`
+	DisplayHeight uint   `json:"display_height,omitempty"`
+	Environment   string `json:"environment,omitempty"`
+	// Function
+	Name        string          `json:"name,omitempty"`
+	Description string          `json:"description,omitempty"`
+	Parameters  json.RawMessage `json:"parameters,omitempty"`
+}
diff --git a/dto/openai_response.go b/dto/openai_response.go
index c2100ec8..2f858d26 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -1,5 +1,7 @@
 package dto

+import "encoding/json"
+
 type SimpleResponse struct {
 	Usage `json:"usage"`
 	Error *OpenAIError `json:"error"`
@@ -191,3 +193,46 @@ type OutputTokenDetails struct {
 	AudioTokens     int `json:"audio_tokens"`
 	ReasoningTokens int `json:"reasoning_tokens"`
 }
+
+type OpenAIResponsesResponse struct {
+	ID                 string             `json:"id"`
+	Object             string             `json:"object"`
+	CreatedAt          int                `json:"created_at"`
+	Status             string             `json:"status"`
+	Error              *OpenAIError       `json:"error,omitempty"`
+	IncompleteDetails  *IncompleteDetails `json:"incomplete_details,omitempty"`
+	Instructions       string             `json:"instructions"`
+	MaxOutputTokens    int                `json:"max_output_tokens"`
+	Model              string             `json:"model"`
+	Output             []ResponsesOutput  `json:"output"`
+	ParallelToolCalls  bool               `json:"parallel_tool_calls"`
+	PreviousResponseID string             `json:"previous_response_id"`
+	Reasoning          *Reasoning         `json:"reasoning"`
+	Store              bool               `json:"store"`
+	Temperature        float64            `json:"temperature"`
+	ToolChoice         string             `json:"tool_choice"`
+	Tools              []interface{}      `json:"tools"`
+	TopP               float64            `json:"top_p"`
+	Truncation         string             `json:"truncation"`
+	Usage              Usage              `json:"usage"`
+	User               json.RawMessage    `json:"user"`
+	Metadata           json.RawMessage    `json:"metadata"`
+}
+
+type IncompleteDetails struct {
+	Reasoning string `json:"reasoning"`
+}
+
+type ResponsesOutput struct {
+	Type    string                   `json:"type"`
+	ID      string                   `json:"id"`
+	Status  string                   `json:"status"`
+	Role    string                   `json:"role"`
+	Content []ResponsesOutputContent `json:"content"`
+}
+
+type ResponsesOutputContent struct {
+	Type        string        `json:"type"`
+	Text        string        `json:"text"`
+	Annotations []interface{} `json:"annotations"`
+}
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index e097dbe6..50255d0a 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -1,11 +1,12 @@
 package channel

 import (
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor interface {
@@ -18,6 +19,7 @@ type Adaptor interface {
 	ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
 	ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
 	ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
+	ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
 	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index 0cbcef44..8e34fd80 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -3,7 +3,6 @@ package ali
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -11,6 +10,8 @@ import (
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -79,6 +80,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go
index ceed39a2..9c879399 100644
--- a/relay/channel/aws/adaptor.go
+++ b/relay/channel/aws/adaptor.go
@@ -2,13 +2,14 @@ package aws

 import (
 	"errors"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
 	"one-api/setting/model_setting"
+
+	"github.com/gin-gonic/gin"
 )

 const (
@@ -74,6 +75,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return nil, nil
 }
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
index eecb0bac..396c31ab 100644
--- a/relay/channel/baidu/adaptor.go
+++ b/relay/channel/baidu/adaptor.go
@@ -3,7 +3,6 @@ package baidu
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -11,6 +10,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -130,6 +131,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return baiduEmbeddingRequest, nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go
index ec7936dc..77afe2dd 100644
--- a/relay/channel/baidu_v2/adaptor.go
+++ b/relay/channel/baidu_v2/adaptor.go
@@ -3,13 +3,14 @@ package baidu_v2
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -60,6 +61,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go
index 6d65d6d4..4b071712 100644
--- a/relay/channel/claude/adaptor.go
+++ b/relay/channel/claude/adaptor.go
@@ -3,7 +3,6 @@ package claude
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -11,6 +10,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/setting/model_setting"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 const (
@@ -84,6 +85,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go
index 3d5a5a8a..06f4ca34 100644
--- a/relay/channel/cloudflare/adaptor.go
+++ b/relay/channel/cloudflare/adaptor.go
@@ -55,6 +55,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	}
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go
index 53a357ad..a93b10f6 100644
--- a/relay/channel/cohere/adaptor.go
+++ b/relay/channel/cohere/adaptor.go
@@ -3,13 +3,14 @@ package cohere
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -52,6 +53,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	return requestOpenAI2Cohere(*request), nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go
index f6e910e8..76e7fa8d 100644
--- a/relay/channel/deepseek/adaptor.go
+++ b/relay/channel/deepseek/adaptor.go
@@ -3,7 +3,6 @@ package deepseek
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -12,6 +11,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -71,6 +72,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go
index dddcb994..51dbee71 100644
--- a/relay/channel/dify/adaptor.go
+++ b/relay/channel/dify/adaptor.go
@@ -3,12 +3,13 @@ package dify
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+
+	"github.com/gin-gonic/gin"
 )

 const (
@@ -86,6 +87,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index feaed8f4..c3c7b49d 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -155,6 +155,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return geminiRequest, nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go
index 3faac243..85b6a83f 100644
--- a/relay/channel/jina/adaptor.go
+++ b/relay/channel/jina/adaptor.go
@@ -3,7 +3,6 @@ package jina
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -12,6 +11,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/common_handler"
 	"one-api/relay/constant"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -55,6 +56,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	return request, nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go
index 82c82496..44f57e61 100644
--- a/relay/channel/mistral/adaptor.go
+++ b/relay/channel/mistral/adaptor.go
@@ -2,13 +2,14 @@ package mistral

 import (
 	"errors"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -59,6 +60,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go
index 304351fd..b889f225 100644
--- a/relay/channel/mokaai/adaptor.go
+++ b/relay/channel/mokaai/adaptor.go
@@ -3,7 +3,6 @@ package mokaai
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -11,6 +10,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -74,6 +75,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 39e408ab..18069311 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -2,7 +2,6 @@ package ollama

 import (
 	"errors"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -10,6 +9,8 @@ import (
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -64,6 +65,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return requestOpenAI2Embeddings(request), nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 502cee69..dc5098c4 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -25,8 +25,9 @@ import (
 	"path/filepath"
 	"strings"

-	"github.com/gin-gonic/gin"
 	"net/textproto"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -67,6 +68,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
 	}
+	if info.RelayMode == constant.RelayModeResponses {
+		return fmt.Sprintf("%s/v1/responses", info.BaseUrl), nil
+	}
 	if info.RelayMode == constant.RelayModeRealtime {
 		if strings.HasPrefix(info.BaseUrl, "https://") {
 			baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
@@ -380,6 +384,21 @@ func detectImageMimeType(filename string) string {
 	}
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// convert model name suffix to reasoning effort
+	if strings.HasSuffix(request.Model, "-high") {
+		request.Reasoning.Effort = "high"
+		request.Model = strings.TrimSuffix(request.Model, "-high")
+	} else if strings.HasSuffix(request.Model, "-low") {
+		request.Reasoning.Effort = "low"
+		request.Model = strings.TrimSuffix(request.Model, "-low")
+	} else if strings.HasSuffix(request.Model, "-medium") {
+		request.Reasoning.Effort = "medium"
+		request.Model = strings.TrimSuffix(request.Model, "-medium")
+	}
+	return request, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	if info.RelayMode == constant.RelayModeAudioTranscription ||
 		info.RelayMode == constant.RelayModeAudioTranslation ||
@@ -406,6 +425,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, usage = OpenaiHandlerWithUsage(c, resp, info)
 	case constant.RelayModeRerank:
 		err, usage = common_handler.RerankHandler(c, info, resp)
+	case constant.RelayModeResponses:
+		if info.IsStream {
+			err, usage = OaiStreamHandler(c, resp, info)
+		} else {
+			err, usage = OpenaiResponsesHandler(c, resp, info)
+		}
 	default:
 		if info.IsStream {
 			err, usage = OaiStreamHandler(c, resp, info)
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index b9ed94e2..269a76f7 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -644,3 +644,53 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
 	}
 	return nil, &usageResp.Usage
 }
+
+func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	// read response body
+	var responsesResponse dto.OpenAIResponsesResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = common.DecodeJson(responseBody, &responsesResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if responsesResponse.Error != nil {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: responsesResponse.Error.Message,
+				Type:    "openai_error",
+				Code:    responsesResponse.Error.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	// reset response body
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	// copy response body
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		common.SysError("error copying response body: " + err.Error())
+	}
+	resp.Body.Close()
+	// compute usage
+	usage := dto.Usage{}
+	usage.PromptTokens = responsesResponse.Usage.InputTokens
+	usage.CompletionTokens = responsesResponse.Usage.OutputTokens
+	usage.TotalTokens = responsesResponse.Usage.TotalTokens
+	return nil, &usage
+}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index f0220f4f..3a06e7ee 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -3,13 +3,14 @@ package palm
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -60,6 +61,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index 5727cac7..ca206503 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -3,13 +3,14 @@ package perplexity
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -63,6 +64,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go
index cf38c15e..89236ea3 100644
--- a/relay/channel/siliconflow/adaptor.go
+++ b/relay/channel/siliconflow/adaptor.go
@@ -3,7 +3,6 @@ package siliconflow
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -11,6 +10,8 @@ import (
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -58,6 +59,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	return request, nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index f2b51ee9..44718a25 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -3,7 +3,6 @@ package tencent
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -13,6 +12,8 @@ import (
 	"one-api/service"
 	"strconv"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -84,6 +85,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index 77f29620..a1425315 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -14,6 +13,8 @@ import (
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 const (
@@ -164,6 +165,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go
index 277285b7..a4a48ee9 100644
--- a/relay/channel/volcengine/adaptor.go
+++ b/relay/channel/volcengine/adaptor.go
@@ -3,7 +3,6 @@ package volcengine
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -12,6 +11,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -71,6 +72,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return request, nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go
index 669b8c68..12634c84 100644
--- a/relay/channel/xai/adaptor.go
+++ b/relay/channel/xai/adaptor.go
@@ -3,13 +3,14 @@ package xai
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -78,6 +79,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not available")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
index 9521bb47..7591e0e7 100644
--- a/relay/channel/xunfei/adaptor.go
+++ b/relay/channel/xunfei/adaptor.go
@@ -2,7 +2,6 @@ package xunfei

 import (
 	"errors"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -10,6 +9,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -61,6 +62,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	// xunfei's request is not http request, so we don't need to do anything here
 	dummyResp := &http.Response{}
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index 04369001..b4d8fb30 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -3,12 +3,13 @@ package zhipu
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -71,6 +72,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		err, usage = zhipuStreamHandler(c, resp)
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index e13a7ad2..222cdff8 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -3,7 +3,6 @@ package zhipu_4v
 import (
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
@@ -11,6 +10,8 @@ import (
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
+
+	"github.com/gin-gonic/gin"
 )

 type Adaptor struct {
@@ -70,6 +71,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return request, nil
 }

+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	// TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index a07ec316..915474e1 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -200,6 +200,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if streamSupportedChannels[info.ChannelType] {
 		info.SupportStreamOptions = true
 	}
+	// the responses mode does not support StreamOptions
+	if relayconstant.RelayModeResponses == info.RelayMode {
+		info.SupportStreamOptions = false
+	}
 	return info
 }
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index e2d51098..4454e815 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -40,6 +40,8 @@ const (

 	RelayModeRerank

+	RelayModeResponses
+
 	RelayModeRealtime
 )

@@ -61,6 +63,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeImagesEdits
 	} else if strings.HasPrefix(path, "/v1/edits") {
 		relayMode = RelayModeEdits
+	} else if strings.HasPrefix(path, "/v1/responses") {
+		relayMode = RelayModeResponses
 	} else if strings.HasPrefix(path, "/v1/audio/speech") {
 		relayMode = RelayModeAudioSpeech
 	} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
diff --git a/relay/relay-responses.go b/relay/relay-responses.go
new file mode 100644
index 00000000..cdb37ae7
--- /dev/null
+++ b/relay/relay-responses.go
@@ -0,0 +1,171 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+	"one-api/setting"
+	"one-api/setting/model_setting"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.OpenAIResponsesRequest, error) {
+	request := &dto.OpenAIResponsesRequest{}
+	err := common.UnmarshalBodyReusable(c, request)
+	if err != nil {
+		return nil, err
+	}
+	if request.Model == "" {
+		return nil, errors.New("model is required")
+	}
+	if len(request.Input) == 0 {
+		return nil, errors.New("input is required")
+	}
+	relayInfo.IsStream = request.Stream
+	return request, nil
+
+}
+
+func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
+
+	sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
+	return sensitiveWords, err
+}
+
+func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
+	inputTokens, err := service.CountTokenInput(req.Input, req.Model)
+	info.PromptTokens = inputTokens
+	return inputTokens, err
+}
+
+func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
+	relayInfo := relaycommon.GenRelayInfo(c)
+	req, err := getAndValidateResponsesRequest(c, relayInfo)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
+		return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest)
+	}
+	if setting.ShouldCheckPromptSensitive() {
+		sensitiveWords, err := checkInputSensitive(req, relayInfo)
+		if err != nil {
+			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
+			return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
+		}
+	}
+
+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
+	}
+	req.Model = relayInfo.UpstreamModelName
+	if value, exists := c.Get("prompt_tokens"); exists {
+		promptTokens := value.(int)
+		relayInfo.SetPromptTokens(promptTokens)
+	} else {
+		promptTokens, err := getInputTokens(req, relayInfo)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
+		}
+		c.Set("prompt_tokens", promptTokens)
+	}
+
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
+	}
+	// pre consume quota
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.Init(relayInfo)
+	var requestBody io.Reader
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
+		body, err := common.GetRequestBody(c)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "get_request_body_error", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(body)
+	} else {
+		convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "convert_request_error", http.StatusBadRequest)
+		}
+		jsonData, err := json.Marshal(convertedRequest)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "marshal_request_error", http.StatusInternalServerError)
+		}
+		// apply param override
+		if len(relayInfo.ParamOverride) > 0 {
+			reqMap := make(map[string]interface{})
+			err = json.Unmarshal(jsonData, &reqMap)
+			if err != nil {
+				return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
+			}
+			for key, value := range relayInfo.ParamOverride {
+				reqMap[key] = value
+			}
+			jsonData, err = json.Marshal(reqMap)
+			if err != nil {
+				return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
+			}
+		}
+
+		if common.DebugEnabled {
+			println("requestBody: ", string(jsonData))
+		}
+		requestBody = bytes.NewBuffer(jsonData)
+	}
+
+	var httpResp *http.Response
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+
+		if httpResp.StatusCode != http.StatusOK {
+			openaiErr = service.RelayErrorHandler(httpResp, false)
+			// reset status code
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
+	if openaiErr != nil {
+		// reset status code
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+
+	if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
+		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	} else {
+		postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	}
+	return nil
+}
diff --git a/router/relay-router.go b/router/relay-router.go
index 85000beb..4cd84b41 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -1,10 +1,11 @@
 package router

 import (
-	"github.com/gin-gonic/gin"
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/relay"
+
+	"github.com/gin-gonic/gin"
 )

 func SetRelayRouter(router *gin.Engine) {
@@ -47,6 +48,7 @@ func SetRelayRouter(router *gin.Engine) {
 		httpRouter.POST("/audio/transcriptions", controller.Relay)
 		httpRouter.POST("/audio/translations", controller.Relay)
 		httpRouter.POST("/audio/speech", controller.Relay)
+		httpRouter.POST("/responses", controller.Relay)
 		httpRouter.GET("/files", controller.RelayNotImplemented)
 		httpRouter.POST("/files", controller.RelayNotImplemented)
 		httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)

From e097d5a538c55f4347bb31aed9d57e221905de8c Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sat, 3 May 2025 21:12:07 +0800
Subject: [PATCH 003/105] feat: add video URL support in MediaContent and
 update token counting logic

---
 dto/openai_request.go    | 15 +++++++++++++++
 main.go                  |  4 ++--
 service/token_counter.go |  2 ++
 3 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/dto/openai_request.go b/dto/openai_request.go
index 652d8cce..28903ed7 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -114,6 +114,7 @@ type MediaContent struct {
 	ImageUrl   any `json:"image_url,omitempty"`
 	InputAudio any `json:"input_audio,omitempty"`
 	File       any `json:"file,omitempty"`
+	VideoUrl   any `json:"video_url,omitempty"`
 }

 func (m *MediaContent) GetImageMedia() *MessageImageUrl {
@@ -158,11 +159,16 @@ type MessageFile struct {
 	FileId string `json:"file_id,omitempty"`
 }

+type MessageVideoUrl struct {
+	Url string `json:"url"`
+}
+
 const (
 	ContentTypeText       = "text"
 	ContentTypeImageURL   = "image_url"
 	ContentTypeInputAudio = "input_audio"
 	ContentTypeFile       = "file"
+	ContentTypeVideoUrl   = "video_url" // Alibaba Bailian video recognition
 )

 func (m *Message) GetPrefix() bool {
@@ -346,6 +352,15 @@ func (m *Message) ParseContent() []MediaContent {
 					}
 				}
 			}
+		case ContentTypeVideoUrl:
+			if videoUrl, ok := contentItem["video_url"].(string); ok {
+				contentList = append(contentList, MediaContent{
+					Type: ContentTypeVideoUrl,
+					VideoUrl: &MessageVideoUrl{
+						Url: videoUrl,
+					},
+				})
+			}
 		}
 	}
 }
diff --git a/main.go b/main.go
index 4bdc97bd..95c6820d 100644
--- a/main.go
+++ b/main.go
@@ -80,6 +80,8 @@ func main() {
 	// Initialize options
 	model.InitOptionMap()

+	service.InitTokenEncoders()
+
 	if common.RedisEnabled {
 		// for compatibility with old versions
 		common.MemoryCacheEnabled = true
@@ -133,8 +135,6 @@ func main() {
 		common.SysLog("pprof enabled")
 	}

-	service.InitTokenEncoders()
-
 	// Initialize HTTP server
 	server := gin.New()
 	server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
diff --git a/service/token_counter.go b/service/token_counter.go
index f3c3b6b0..21b882af 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -400,6 +400,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
 				tokenNum += 100
 			} else if m.Type == dto.ContentTypeFile {
 				tokenNum += 5000
+			} else if m.Type == dto.ContentTypeVideoUrl {
+				tokenNum += 5000
 			} else {
 				tokenNum += getTokenNum(tokenEncoder, m.Text)
 			}

From 1236fa8fe42f8fc736c4450836bc5b93b05e0f85 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Sat, 3 May 2025 19:19:19 +0800
Subject: [PATCH 004/105] add OaiResponsesStreamHandler

---
 dto/openai_response.go               |  8 +++-
 relay/channel/openai/adaptor.go      |  2 +-
 relay/channel/openai/relay-openai.go | 62 ++++++++++++++++++++++++++++
 3 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/dto/openai_response.go b/dto/openai_response.go
index 2f858d26..02befd79 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -214,7 +214,7 @@ type OpenAIResponsesResponse struct {
 	Tools      []interface{} `json:"tools"`
 	TopP       float64       `json:"top_p"`
 	Truncation string        `json:"truncation"`
-	Usage      Usage         `json:"usage"`
+	Usage      *Usage        `json:"usage"`
 	User       json.RawMessage `json:"user"`
 	Metadata   json.RawMessage `json:"metadata"`
 }
@@ -236,3 +236,9 @@ type ResponsesOutputContent struct {
 	Text        string        `json:"text"`
 	Annotations []interface{} `json:"annotations"`
 }
+
+// ResponsesStreamResponse handles the /v1/responses streaming response
+type ResponsesStreamResponse struct {
+	Type     string                   `json:"type"`
+	Response *OpenAIResponsesResponse `json:"response"`
+}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index dc5098c4..7740c498 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -427,7 +427,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, usage = common_handler.RerankHandler(c, info, resp)
 	case constant.RelayModeResponses:
 		if info.IsStream {
-			err, usage = OaiStreamHandler(c, resp, info)
+			err, usage = OaiResponsesStreamHandler(c, resp, info)
 		} else {
 			err, usage = OpenaiResponsesHandler(c, resp, info)
 		}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 269a76f7..bfeed2cf 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -694,3 +694,65 @@ func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycomm
 	usage.TotalTokens = responsesResponse.Usage.TotalTokens
 	return nil, &usage
 }
+
+func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	if resp == nil || resp.Body == nil {
+		common.LogError(c, "invalid response or response body")
+		return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
+	}
+
+	var usage = &dto.Usage{}
+	var streamItems []string // store streamed data items
+	// var responseTextBuilder strings.Builder
+	// var toolCount int
+	var forceFormat bool
+
+	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
+		forceFormat = forceFmt
+	}
+
+	var lastStreamData string
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		if lastStreamData != "" {
+			// process the previous data item
+			sendResponsesStreamData(c, lastStreamData, forceFormat)
+		}
+		lastStreamData = data
+		streamItems = append(streamItems, data)
+
+		// check whether the current data contains the completed status and usage info
+		var streamResponse dto.ResponsesStreamResponse
+		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
+			if streamResponse.Type == "response.completed" {
+				// handle the completed status
+				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+			}
+		}
+		return true
+	})
+
+	// process the last data item
+	sendResponsesStreamData(c, lastStreamData, forceFormat)
+
+	// token accounting
+	// if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
+	// 	common.SysError("error processing tokens: " + err.Error())
+	// }
+
+	return nil, usage
+}
+
+func sendResponsesStreamData(c *gin.Context, data string, forceFormat bool) error {
+	if data == "" {
+		return nil
+	}
+
+	if forceFormat {
+		return helper.ObjectData(c, data)
+	} else {
+		return helper.StringData(c, data)
+	}
+}

From fe3232bf23484297ca57cc4fefb29a566e0f3f3e Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Sun, 4 May 2025 17:09:37 +0800
Subject: [PATCH 005/105] feat: enhance OaiResponsesStreamHandler to handle
 output text and improve response streaming

---
 dto/openai_response.go               | 11 ++++++
 relay/channel/openai/relay-openai.go | 53 +++++++++++-----------------
 relay/helper/common.go               |  8 +++++
 3 files changed, 40 insertions(+), 32 deletions(-)

diff --git a/dto/openai_response.go b/dto/openai_response.go
index 02befd79..1508d1f6 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -237,8 +237,19 @@ type ResponsesOutputContent struct {
 	Text        string        `json:"text"`
 	Annotations []interface{} `json:"annotations"`
 }
+const (
+	BuildInTools_WebSearch  = "web_search_preview"
+	BuildInTools_FileSearch = "file_search"
+)
+
+const (
+	ResponsesOutputTypeItemAdded = "response.output_item.added"
+	ResponsesOutputTypeItemDone  = "response.output_item.done"
+)
+
 // ResponsesStreamResponse handles the /v1/responses streaming response
 type ResponsesStreamResponse struct {
 	Type     string                   `json:"type"`
 	Response *OpenAIResponsesResponse `json:"response"`
+	Delta    string                   `json:"delta,omitempty"`
 }
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index bfeed2cf..f10ebc1b 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -702,57 +702,46 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 	}

 	var usage = &dto.Usage{}
-	var streamItems []string // store streamed data items
-	// var responseTextBuilder strings.Builder
-	// var toolCount int
-	var forceFormat bool
-
-	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
-		forceFormat = forceFmt
-	}
-
-	var lastStreamData string
+	var responseTextBuilder strings.Builder

 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		if lastStreamData != "" {
-			// process the previous data item
-			sendResponsesStreamData(c, lastStreamData, forceFormat)
-		}
-		lastStreamData = data
-		streamItems = append(streamItems, data)

 		// check whether the current data contains the completed status and usage info
 		var streamResponse dto.ResponsesStreamResponse
 		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
-			if streamResponse.Type == "response.completed" {
-				// handle the completed status
+			sendResponsesStreamData(c, streamResponse, data)
+			switch streamResponse.Type {
+			case "response.completed":
 				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
 				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
 				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+			case "response.output_text.delta":
+				// handle the output text
+				responseTextBuilder.WriteString(streamResponse.Delta)
 			}
 		}
 		return true
 	})

-	// process the last data item
-	sendResponsesStreamData(c, lastStreamData, forceFormat)
+	helper.Done(c)

-	// token accounting
-	// if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
-	// 	common.SysError("error processing tokens: " + err.Error())
-	// }
+	if usage.CompletionTokens == 0 {
+		// count the tokens of the output text
+		tempStr := responseTextBuilder.String()
+		if len(tempStr) > 0 {
+			// abnormal end of stream, fall back to the output text token count
+			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			usage.CompletionTokens = completionTokens
+		}
+	}

 	return nil, usage
 }

-func sendResponsesStreamData(c *gin.Context, data string, forceFormat bool) error {
+func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
 	if data == "" {
-		return nil
-	}
-
-	if forceFormat {
-		return helper.ObjectData(c, data)
-	} else {
-		return helper.StringData(c, data)
+		return
 	}
+	helper.ResponseChunkData(c, streamResponse, data)
 }
diff --git a/relay/helper/common.go b/relay/helper/common.go
index ebfb6d58..43e8b92c 100644
--- a/relay/helper/common.go
+++ b/relay/helper/common.go
@@ -43,6 +43,14 @@ func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
 	}
 }

+func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
+	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
+	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
 func StringData(c *gin.Context, str string) error {
 	//str = strings.TrimPrefix(str, "data: ")
 	//str = strings.TrimSuffix(str, "\r")

From 419a056fbfdebfdad0fe5f78892e2e6ec7bf5144 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sun, 4 May 2025 17:35:45 +0800
Subject: [PATCH 006/105] refactor: remove unnecessary call to helper.Done and
 adjust data rendering in ClaudeChunkData

---
 relay/channel/openai/relay-openai.go | 2 --
 relay/helper/common.go               | 2 +-
 relay/helper/stream_scanner.go       | 4 ++--
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index f10ebc1b..ef660564 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -724,8 +724,6 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 		return true
 	})

-	helper.Done(c)
-
 	if usage.CompletionTokens == 0 {
 		// count the tokens of the output text
 		tempStr := responseTextBuilder.String()
diff --git a/relay/helper/common.go b/relay/helper/common.go
index 43e8b92c..6a8ca2d7 100644
--- a/relay/helper/common.go
+++ b/relay/helper/common.go
@@ -37,7 +37,7 @@ func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {

 func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
-	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
+	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
 	if flusher, ok := c.Writer.(http.Flusher); ok {
 		flusher.Flush()
 	}
diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go
index abb98f42..2738ce2a 100644
--- a/relay/helper/stream_scanner.go
+++ b/relay/helper/stream_scanner.go
@@ -32,7 +32,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 	defer resp.Body.Close()

 	streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
-	if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
+	if strings.HasPrefix(info.UpstreamModelName, "o") {
 		// twice timeout for thinking model
 		streamingTimeout *= 2
 	}
@@ -115,7 +115,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 			}
 			data = data[5:]
 			data = strings.TrimLeft(data, " ")
-			data = strings.TrimSuffix(data, "\"")
+			data = strings.TrimSuffix(data, "\r")
 			if !strings.HasPrefix(data, "[DONE]") {
 				info.SetFirstResponseTime()
 				writeMutex.Lock() // Lock before writing

From 3def2bbd30abf37766411f7a294688546057e311 Mon Sep 17 00:00:00 2001
From: tbphp
Date: Sun, 4 May 2025 18:26:18 +0800
Subject: [PATCH 007/105] fix: EditUser text error

---
 web/src/pages/User/EditUser.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/src/pages/User/EditUser.js b/web/src/pages/User/EditUser.js
index 9f903142..36e54727 100644
--- a/web/src/pages/User/EditUser.js
+++ b/web/src/pages/User/EditUser.js
@@ -239,7 +239,7 @@ const EditUser = (props) => {
               readonly
             />
-                {t('`已绑定的 OIDC 账户')}
+                {t('已绑定的 OIDC 账户')}
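A note on the relay/helper/stream_scanner.go hunks in PATCH 006 above: widening the prefix check to strings.HasPrefix(info.UpstreamModelName, "o") now doubles the streaming timeout for every o-series (thinking) model rather than just o1/o3, and replacing TrimSuffix(data, "\"") with TrimSuffix(data, "\r") matters because upstreams that terminate SSE lines with \r\n leave a carriage return attached to the JSON payload after the scanner splits on \n; stripping a trailing quote instead silently corrupts well-formed JSON. Below is a minimal, self-contained sketch of that normalization; the helper name and sample payload are invented, and the real StreamScannerHandler does considerably more (locking, first-response timing):

package main

import (
	"fmt"
	"strings"
)

// normalizeSSELine mirrors the trimming done in stream_scanner.go:
// drop the "data:" prefix and any leading spaces, then strip a
// trailing '\r' left over from \r\n line endings. more == false
// once the [DONE] sentinel is seen.
func normalizeSSELine(line string) (payload string, more bool) {
	if !strings.HasPrefix(line, "data:") {
		return "", true // skip "event:" lines, comments and keep-alives
	}
	payload = strings.TrimLeft(line[5:], " ")
	payload = strings.TrimSuffix(payload, "\r") // not TrimSuffix(payload, "\"")
	if strings.HasPrefix(payload, "[DONE]") {
		return "", false
	}
	return payload, true
}

func main() {
	raw := "data: {\"delta\":\"hi\"}\r\ndata: [DONE]\r\n"
	for _, line := range strings.Split(raw, "\n") { // '\r' stays attached
		payload, more := normalizeSSELine(line)
		if !more {
			break
		}
		if payload != "" {
			fmt.Println(payload) // {"delta":"hi"}, quotes intact
		}
	}
}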
Date: Mon, 5 May 2025 00:40:16 +0800 Subject: [PATCH 008/105] feat: implement OpenAI responses handling and streaming support with built-in tool tracking --- dto/openai_response.go | 11 ++- relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/helper.go | 7 ++ relay/channel/openai/relay-openai.go | 99 -------------------- relay/channel/openai/relay_responses.go | 114 ++++++++++++++++++++++++ relay/channel/vertex/adaptor.go | 2 +- relay/common/relay_info.go | 37 ++++++++ relay/relay-responses.go | 10 +-- 8 files changed, 173 insertions(+), 109 deletions(-) create mode 100644 relay/channel/openai/relay_responses.go diff --git a/dto/openai_response.go b/dto/openai_response.go index 1508d1f6..c8f61b9d 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -238,8 +238,12 @@ type ResponsesOutputContent struct { } const ( - BuildInTools_WebSearch = "web_search_preview" - BuildInTools_FileSearch = "file_search" + BuildInToolWebSearchPreview = "web_search_preview" + BuildInToolFileSearch = "file_search" +) + +const ( + BuildInCallWebSearchCall = "web_search_call" ) const ( @@ -250,6 +254,7 @@ const ( // ResponsesStreamResponse 用于处理 /v1/responses 流式响应 type ResponsesStreamResponse struct { Type string `json:"type"` - Response *OpenAIResponsesResponse `json:"response"` + Response *OpenAIResponsesResponse `json:"response,omitempty"` Delta string `json:"delta,omitempty"` + Item *ResponsesOutput `json:"item,omitempty"` } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 7740c498..eb12a22a 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -429,7 +429,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = OaiResponsesStreamHandler(c, resp, info) } else { - err, usage = OpenaiResponsesHandler(c, resp, info) + err, usage = OaiResponsesHandler(c, resp, info) } default: if info.IsStream { diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index e7ba2e7b..a068c544 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -187,3 +187,10 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } } } + +func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { + if data == "" { + return + } + helper.ResponseChunkData(c, streamResponse, data) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index ef660564..b9ed94e2 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -644,102 +644,3 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm } return nil, &usageResp.Usage } - -func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - // read response body - var responsesResponse dto.OpenAIResponsesResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = common.DecodeJson(responseBody, &responsesResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if responsesResponse.Error != 
nil { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: responsesResponse.Error.Message, - Type: "openai_error", - Code: responsesResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - - // reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - // copy response body - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - common.SysError("error copying response body: " + err.Error()) - } - resp.Body.Close() - // compute usage - usage := dto.Usage{} - usage.PromptTokens = responsesResponse.Usage.InputTokens - usage.CompletionTokens = responsesResponse.Usage.OutputTokens - usage.TotalTokens = responsesResponse.Usage.TotalTokens - return nil, &usage -} - -func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") - return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil - } - - var usage = &dto.Usage{} - var responseTextBuilder strings.Builder - - helper.StreamScannerHandler(c, resp, info, func(data string) bool { - - // 检查当前数据是否包含 completed 状态和 usage 信息 - var streamResponse dto.ResponsesStreamResponse - if err := common.DecodeJsonStr(data, &streamResponse); err == nil { - sendResponsesStreamData(c, streamResponse, data) - switch streamResponse.Type { - case "response.completed": - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens - case "response.output_text.delta": - // 处理输出文本 - responseTextBuilder.WriteString(streamResponse.Delta) - - } - } - return true - }) - - if usage.CompletionTokens == 0 { - // 计算输出文本的 token 数量 - tempStr := responseTextBuilder.String() - if len(tempStr) > 0 { - // 非正常结束,使用输出文本的 token 数量 - completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) - usage.CompletionTokens = completionTokens - } - } - - return nil, usage -} - -func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { - if data == "" { - return - } - helper.ResponseChunkData(c, streamResponse, data) -} diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go new file mode 100644 index 00000000..6af8c676 --- /dev/null +++ b/relay/channel/openai/relay_responses.go @@ -0,0 +1,114 @@ +package openai + +import ( + "bytes" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "strings" +) + +func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + // read response body + var responsesResponse dto.OpenAIResponsesResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return 
service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = common.DecodeJson(responseBody, &responsesResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if responsesResponse.Error != nil { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: responsesResponse.Error.Message, + Type: "openai_error", + Code: responsesResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + // reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + // copy response body + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + common.SysError("error copying response body: " + err.Error()) + } + resp.Body.Close() + // compute usage + usage := dto.Usage{} + usage.PromptTokens = responsesResponse.Usage.InputTokens + usage.CompletionTokens = responsesResponse.Usage.OutputTokens + usage.TotalTokens = responsesResponse.Usage.TotalTokens + return nil, &usage +} + +func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + if resp == nil || resp.Body == nil { + common.LogError(c, "invalid response or response body") + return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil + } + + var usage = &dto.Usage{} + var responseTextBuilder strings.Builder + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + + // 检查当前数据是否包含 completed 状态和 usage 信息 + var streamResponse dto.ResponsesStreamResponse + if err := common.DecodeJsonStr(data, &streamResponse); err == nil { + sendResponsesStreamData(c, streamResponse, data) + switch streamResponse.Type { + case "response.completed": + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + case "response.output_text.delta": + // 处理输出文本 + responseTextBuilder.WriteString(streamResponse.Delta) + case dto.ResponsesOutputTypeItemDone: + // 函数调用处理 + if streamResponse.Item != nil { + switch streamResponse.Item.Type { + case dto.BuildInCallWebSearchCall: + info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++ + } + } + } + } + return true + }) + + if usage.CompletionTokens == 0 { + // 计算输出文本的 token 数量 + tempStr := responseTextBuilder.String() + if len(tempStr) > 0 { + // 非正常结束,使用输出文本的 token 数量 + completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) + usage.CompletionTokens = completionTokens + } + } + + return nil, usage +} diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index c1b64f11..7daf9a61 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -11,8 
+11,8 @@ import ( "one-api/relay/channel/claude" "one-api/relay/channel/gemini" "one-api/relay/channel/openai" - "one-api/setting/model_setting" relaycommon "one-api/relay/common" + "one-api/setting/model_setting" "strings" "github.com/gin-gonic/gin" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 915474e1..99c6d12b 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -36,6 +36,7 @@ type ClaudeConvertInfo struct { const ( RelayFormatOpenAI = "openai" RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" ) type RerankerInfo struct { @@ -43,6 +44,16 @@ type RerankerInfo struct { ReturnDocuments bool } +type BuildInToolInfo struct { + ToolName string + CallCount int + SearchContextSize string +} + +type ResponsesUsageInfo struct { + BuiltInTools map[string]*BuildInToolInfo +} + type RelayInfo struct { ChannelType int ChannelId int @@ -90,6 +101,7 @@ type RelayInfo struct { ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo + *ResponsesUsageInfo } // 定义支持流式选项的通道类型 @@ -134,6 +146,31 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { return info } +func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { + info := GenRelayInfo(c) + info.RelayMode = relayconstant.RelayModeResponses + info.ResponsesUsageInfo = &ResponsesUsageInfo{ + BuiltInTools: make(map[string]*BuildInToolInfo), + } + if len(req.Tools) > 0 { + for _, tool := range req.Tools { + info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{ + ToolName: tool.Type, + CallCount: 0, + } + switch tool.Type { + case dto.BuildInToolWebSearchPreview: + if tool.SearchContextSize == "" { + tool.SearchContextSize = "medium" + } + info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize + } + } + } + info.IsStream = req.Stream + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") diff --git a/relay/relay-responses.go b/relay/relay-responses.go index cdb37ae7..fd3ddb5a 100644 --- a/relay/relay-responses.go +++ b/relay/relay-responses.go @@ -19,7 +19,7 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.OpenAIResponsesRequest, error) { +func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { request := &dto.OpenAIResponsesRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { @@ -31,13 +31,11 @@ func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.Relay if len(request.Input) == 0 { return nil, errors.New("input is required") } - relayInfo.IsStream = request.Stream return request, nil } func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { - sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) return sensitiveWords, err } @@ -49,12 +47,14 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo } func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) - req, err := getAndValidateResponsesRequest(c, relayInfo) + req, err := getAndValidateResponsesRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest) } + + relayInfo := 
relaycommon.GenRelayInfoResponses(c, req) + if setting.ShouldCheckPromptSensitive() { sensitiveWords, err := checkInputSensitive(req, relayInfo) if err != nil { From 6c3fb7777ec3fe4874b249251120e68b5e22642f Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 07:31:54 +0800 Subject: [PATCH 009/105] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E5=88=86?= =?UTF-8?q?=E7=BB=84=E9=80=9F=E7=8E=87=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 37 +++++++-- model/option.go | 47 +++++++++++- setting/rate_limit.go | 70 +++++++++++++++++ web/src/components/RateLimitSetting.js | 1 + .../RateLimit/SettingsRequestRateLimit.js | 76 +++++++++++++++++-- 5 files changed, 214 insertions(+), 17 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 581dc451..d4199ece 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) { return } - // 计算限流参数 + // 计算通用限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - totalMaxCount := setting.ModelRequestRateLimitCount - successMaxCount := setting.ModelRequestRateLimitSuccessCount - // 根据存储类型选择并执行限流处理器 - if common.RedisEnabled { - redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + // 获取用户组 + group := c.GetString("token_group") + if group == "" { + group = c.GetString("group") + } + if group == "" { + group = "default" // 默认组 + } + + // 尝试获取用户组特定的限制 + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) + + // 确定最终的限制值 + finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制 + finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制 + + if found { + // 如果找到用户组特定限制,则使用它们 + finalTotalCount = groupTotalCount + finalSuccessCount = groupSuccessCount + common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } else { - memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) + } + + // 根据存储类型选择并执行限流处理器,传入最终确定的限制值 + if common.RedisEnabled { + redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + } else { + memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } } } diff --git a/model/option.go b/model/option.go index d575742f..1f5fb3aa 100644 --- a/model/option.go +++ b/model/option.go @@ -1,6 +1,8 @@ package model import ( + "encoding/json" + "fmt" "one-api/common" "one-api/setting" "one-api/setting/config" @@ -96,6 +98,7 @@ func InitOptionMap() { common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值 common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -150,7 +153,32 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - // Save to database first + originalValue := value // 保存原始值以备后用 + + // 
Validate and format specific keys before saving + if key == setting.ModelRequestRateLimitGroupKey { + var cfg map[string][2]int + // Validate the JSON structure first using the original value + err := json.Unmarshal([]byte(originalValue), &cfg) + if err != nil { + // 提供更具体的错误信息 + return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) + } + // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。 + // if !isValidModelRequestRateLimitGroupConfig(cfg) { + // return fmt.Errorf("无效的配置值 for %s", key) + // } + + // If valid, format the JSON before saving + formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") + if marshalErr != nil { + // This should ideally not happen if validation passed, but handle defensively + return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) + } + value = string(formattedValueBytes) // Use formatted JSON for saving and memory update + } + + // Save to database option := Option{ Key: key, } @@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error { // Save is a combination function. // If save value does not contain primary key, it will execute Create, // otherwise it will execute Update (with all fields). - DB.Save(&option) - // Update OptionMap + if err := DB.Save(&option).Error; err != nil { + return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文 + } + + // Update OptionMap in memory using the potentially formatted value + // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新 return updateOptionMap(key, value) } @@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + case setting.ModelRequestRateLimitGroupKey: + // Use the (potentially formatted) value passed from UpdateOption + // to update the actual configuration in memory. + // This is the single point where the memory state for this specific setting is updated. 
+ err = setting.UpdateModelRequestRateLimitGroupConfig(value) + if err != nil { + // 添加错误上下文 + err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err) + } } return err } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 4b216948..c83885a6 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -1,6 +1,76 @@ package setting +import ( + "encoding/json" + "fmt" + "one-api/common" + "sync" +) + var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 + +// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键 +const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup" + +// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置 +// map[groupName][2]int{totalCount, successCount} +var ModelRequestRateLimitGroupConfig map[string][2]int +var ModelRequestRateLimitGroupMutex sync.RWMutex + +// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置 +func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { + ModelRequestRateLimitGroupMutex.Lock() + defer ModelRequestRateLimitGroupMutex.Unlock() + + var newConfig map[string][2]int + if jsonStr == "" || jsonStr == "{}" { + // 如果配置为空或空JSON对象,则清空内存配置 + ModelRequestRateLimitGroupConfig = make(map[string][2]int) + common.SysLog("Model request rate limit group config cleared") + return nil + } + + err := json.Unmarshal([]byte(jsonStr), &newConfig) + if err != nil { + return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) + } + + // 校验配置值 + for group, limits := range newConfig { + if len(limits) != 2 { + return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group) + } + if limits[1] <= 0 { // successCount must be greater than 0 + return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group) + } + if limits[0] < 0 { // totalCount can be 0 (no limit) or positive + return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group) + } + if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount + return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group) + } + } + + ModelRequestRateLimitGroupConfig = newConfig + common.SysLog("Model request rate limit group config updated") + return nil +} + +// GetGroupRateLimit 安全地获取指定用户组的速率限制值 +func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { + ModelRequestRateLimitGroupMutex.RLock() + defer ModelRequestRateLimitGroupMutex.RUnlock() + + if ModelRequestRateLimitGroupConfig == nil { + return 0, 0, false // 配置尚未初始化 + } + + limits, found := ModelRequestRateLimitGroupConfig[group] + if !found { + return 0, 0, false + } + return limits[0], limits[1], true +} diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index e06038d6..ad6b53da 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -13,6 +13,7 @@ const RateLimitSetting = () => { ModelRequestRateLimitCount: 0, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: {}, }); let [loading, setLoading] = useState(false); diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 800e9636..ec1c2158 100644 --- 
a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -18,6 +18,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值 }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -32,25 +33,49 @@ export default function RequestRateLimit(props) { } else { value = inputs[item.key]; } + // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串 + if (item.key === 'ModelRequestRateLimitGroup') { + try { + JSON.parse(value); + } catch (e) { + showError(t('用户组速率限制配置不是有效的 JSON 格式!')); + // 阻止请求发送 + return Promise.reject('Invalid JSON format'); + } + } return API.put('/api/option/', { key: item.key, value, }); }); + + // 过滤掉无效的请求(例如,无效的 JSON) + const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); + + if (validRequests.length === 0 && requestQueue.length > 0) { + // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行 + return; + } + setLoading(true); - Promise.all(requestQueue) + Promise.all(validRequests) .then((res) => { - if (requestQueue.length === 1) { + if (validRequests.length === 1) { if (res.includes(undefined)) return; - } else if (requestQueue.length > 1) { + } else if (validRequests.length > 1) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } showSuccess(t('保存成功')); props.refresh(); + // 更新 inputsRow 以反映保存后的状态 + setInputsRow(structuredClone(inputs)); }) - .catch(() => { - showError(t('保存失败,请重试')); + .catch((error) => { + // 检查是否是由于无效 JSON 导致的错误 + if (error !== 'Invalid JSON format') { + showError(t('保存失败,请重试')); + } }) .finally(() => { setLoading(false); @@ -66,8 +91,11 @@ export default function RequestRateLimit(props) { } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); - }, [props.options]); + // 检查 refForm.current 是否存在 + if (refForm.current) { + refForm.current.setValues(currentInputs); + } + }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定 return ( <> @@ -147,7 +175,41 @@ export default function RequestRateLimit(props) { /> + {/* 用户组速率限制配置项 */} + + +

+                extra={
+                  <div>
+                    {t('说明:')}
+                    <br />
+                    • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
+                    <br />
+                    • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
+                    <br />
+                    • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
+                    <br />
+                    • {t('此配置将优先于上方的全局限制设置。')}
+                    <br />
+                    • {t('未在此处配置的用户组将使用全局限制。')}
+                    <br />
+                    • {t('限制周期统一使用上方配置的“限制周期”值。')}
+                    <br />
+                    • {t('输入无效的 JSON 将无法保存。')}
+                  </div>
+                }
+                autosize={{ minRows: 5, maxRows: 15 }}
+                style={{ fontFamily: 'monospace' }}
+                onChange={(value) => {
+                  setInputs({
+                    ...inputs,
+                    ModelRequestRateLimitGroup: value, // 直接更新字符串值
+                  });
+                }}
+              />
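The ModelRequestRateLimitGroup option wired up above stores a JSON object that maps a group name to [totalCount, successCount], where totalCount == 0 means the total limit is disabled; the middleware prefers a group-specific entry and otherwise falls back to the global ModelRequestRateLimitCount / ModelRequestRateLimitSuccessCount values. A standalone sketch of that parse-and-resolve flow, with made-up group names and limits (illustrative only, not shipped defaults):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Same shape UpdateModelRequestRateLimitGroupConfig unmarshals into.
	raw := `{"default": [120, 100], "vip": [0, 1000]}`
	var cfg map[string][2]int
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		panic(err) // invalid JSON is rejected before the option is saved
	}

	globalTotal, globalSuccess := 60, 50 // stand-ins for the global options
	for _, group := range []string{"vip", "free"} {
		total, success := globalTotal, globalSuccess
		if limits, ok := cfg[group]; ok { // a group-specific entry wins
			total, success = limits[0], limits[1]
		}
		fmt.Printf("group=%s total=%d success=%d\n", group, total, success)
	}
	// group=vip total=0 success=1000
	// group=free total=60 success=50
}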
+ From 7e7d6112ca460be5c30a6c89fb4165346a6d5651 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 11:34:57 +0800 Subject: [PATCH 010/105] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E5=8E=BB=E9=99=A4=E5=A4=9A=E4=BD=99=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E5=92=8C=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .old/option.go | 402 ++++++++++++++++++ middleware/model-rate-limit.go | 43 +- model/option.go | 40 +- setting/rate_limit.go | 39 +- web/src/components/RateLimitSetting.js | 92 ++-- .../RateLimit/SettingsRequestRateLimit.js | 388 +++++++++-------- 6 files changed, 663 insertions(+), 341 deletions(-) create mode 100644 .old/option.go diff --git a/.old/option.go b/.old/option.go new file mode 100644 index 00000000..f80f5cb3 --- /dev/null +++ b/.old/option.go @@ -0,0 +1,402 @@ +package model + +import ( + "one-api/common" + "one-api/setting" + "one-api/setting/config" + "one-api/setting/operation_setting" + "strconv" + "strings" + "time" +) + +type Option struct { + Key string `json:"key" gorm:"primaryKey"` + Value string `json:"value"` +} + +func AllOption() ([]*Option, error) { + var options []*Option + var err error + err = DB.Find(&options).Error + return options, err +} + +func InitOptionMap() { + common.OptionMapRWMutex.Lock() + common.OptionMap = make(map[string]string) + + // 添加原有的系统配置 + common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) + common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) + common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) + common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) + common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) + common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) + common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) + common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) + common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled) + common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) + common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) + common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) + common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) + common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) + common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) + common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) + common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) + common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) + common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) + common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) + common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) + common.OptionMap["EmailDomainRestrictionEnabled"] = 
strconv.FormatBool(common.EmailDomainRestrictionEnabled) + common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled) + common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") + common.OptionMap["SMTPServer"] = "" + common.OptionMap["SMTPFrom"] = "" + common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) + common.OptionMap["SMTPAccount"] = "" + common.OptionMap["SMTPToken"] = "" + common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled) + common.OptionMap["Notice"] = "" + common.OptionMap["About"] = "" + common.OptionMap["HomePageContent"] = "" + common.OptionMap["Footer"] = common.Footer + common.OptionMap["SystemName"] = common.SystemName + common.OptionMap["Logo"] = common.Logo + common.OptionMap["ServerAddress"] = "" + common.OptionMap["WorkerUrl"] = setting.WorkerUrl + common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey + common.OptionMap["PayAddress"] = "" + common.OptionMap["CustomCallbackAddress"] = "" + common.OptionMap["EpayId"] = "" + common.OptionMap["EpayKey"] = "" + common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) + common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) + common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() + common.OptionMap["Chats"] = setting.Chats2JsonString() + common.OptionMap["GitHubClientId"] = "" + common.OptionMap["GitHubClientSecret"] = "" + common.OptionMap["TelegramBotToken"] = "" + common.OptionMap["TelegramBotName"] = "" + common.OptionMap["WeChatServerAddress"] = "" + common.OptionMap["WeChatServerToken"] = "" + common.OptionMap["WeChatAccountQRCodeImageURL"] = "" + common.OptionMap["TurnstileSiteKey"] = "" + common.OptionMap["TurnstileSecretKey"] = "" + common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) + common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) + common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) + common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) + common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) + common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) + common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() + common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() + common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() + common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() + common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() + common.OptionMap["TopUpLink"] = common.TopUpLink + //common.OptionMap["ChatLink"] = common.ChatLink + //common.OptionMap["ChatLink2"] = common.ChatLink2 + common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) + common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) + common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) + common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime + common.OptionMap["DefaultCollapseSidebar"] = 
strconv.FormatBool(common.DefaultCollapseSidebar) + common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled) + common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled) + common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled) + common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) + common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) + common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) + common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled) + common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled) + common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) + common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) + common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) + common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() + common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) + common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + + // 自动添加所有注册的模型配置 + modelConfigs := config.GlobalConfig.ExportAllConfigs() + for k, v := range modelConfigs { + common.OptionMap[k] = v + } + + common.OptionMapRWMutex.Unlock() + loadOptionsFromDatabase() +} + +func loadOptionsFromDatabase() { + options, _ := AllOption() + for _, option := range options { + err := updateOptionMap(option.Key, option.Value) + if err != nil { + common.SysError("failed to update option map: " + err.Error()) + } + } +} + +func SyncOptions(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + common.SysLog("syncing options from database") + loadOptionsFromDatabase() + } +} + +func UpdateOption(key string, value string) error { + // Save to database first + option := Option{ + Key: key, + } + // https://gorm.io/docs/update.html#Save-All-Fields + DB.FirstOrCreate(&option, Option{Key: key}) + option.Value = value + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap + return updateOptionMap(key, value) +} + +func updateOptionMap(key string, value string) (err error) { + common.OptionMapRWMutex.Lock() + defer common.OptionMapRWMutex.Unlock() + common.OptionMap[key] = value + + // 检查是否是模型配置 - 使用更规范的方式处理 + if handleConfigUpdate(key, value) { + return nil // 已由配置系统处理 + } + + // 处理传统配置项... 
+ if strings.HasSuffix(key, "Permission") { + intValue, _ := strconv.Atoi(value) + switch key { + case "FileUploadPermission": + common.FileUploadPermission = intValue + case "FileDownloadPermission": + common.FileDownloadPermission = intValue + case "ImageUploadPermission": + common.ImageUploadPermission = intValue + case "ImageDownloadPermission": + common.ImageDownloadPermission = intValue + } + } + if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { + boolValue := value == "true" + switch key { + case "PasswordRegisterEnabled": + common.PasswordRegisterEnabled = boolValue + case "PasswordLoginEnabled": + common.PasswordLoginEnabled = boolValue + case "EmailVerificationEnabled": + common.EmailVerificationEnabled = boolValue + case "GitHubOAuthEnabled": + common.GitHubOAuthEnabled = boolValue + case "LinuxDOOAuthEnabled": + common.LinuxDOOAuthEnabled = boolValue + case "WeChatAuthEnabled": + common.WeChatAuthEnabled = boolValue + case "TelegramOAuthEnabled": + common.TelegramOAuthEnabled = boolValue + case "TurnstileCheckEnabled": + common.TurnstileCheckEnabled = boolValue + case "RegisterEnabled": + common.RegisterEnabled = boolValue + case "EmailDomainRestrictionEnabled": + common.EmailDomainRestrictionEnabled = boolValue + case "EmailAliasRestrictionEnabled": + common.EmailAliasRestrictionEnabled = boolValue + case "AutomaticDisableChannelEnabled": + common.AutomaticDisableChannelEnabled = boolValue + case "AutomaticEnableChannelEnabled": + common.AutomaticEnableChannelEnabled = boolValue + case "LogConsumeEnabled": + common.LogConsumeEnabled = boolValue + case "DisplayInCurrencyEnabled": + common.DisplayInCurrencyEnabled = boolValue + case "DisplayTokenStatEnabled": + common.DisplayTokenStatEnabled = boolValue + case "DrawingEnabled": + common.DrawingEnabled = boolValue + case "TaskEnabled": + common.TaskEnabled = boolValue + case "DataExportEnabled": + common.DataExportEnabled = boolValue + case "DefaultCollapseSidebar": + common.DefaultCollapseSidebar = boolValue + case "MjNotifyEnabled": + setting.MjNotifyEnabled = boolValue + case "MjAccountFilterEnabled": + setting.MjAccountFilterEnabled = boolValue + case "MjModeClearEnabled": + setting.MjModeClearEnabled = boolValue + case "MjForwardUrlEnabled": + setting.MjForwardUrlEnabled = boolValue + case "MjActionCheckSuccessEnabled": + setting.MjActionCheckSuccessEnabled = boolValue + case "CheckSensitiveEnabled": + setting.CheckSensitiveEnabled = boolValue + case "DemoSiteEnabled": + operation_setting.DemoSiteEnabled = boolValue + case "SelfUseModeEnabled": + operation_setting.SelfUseModeEnabled = boolValue + case "CheckSensitiveOnPromptEnabled": + setting.CheckSensitiveOnPromptEnabled = boolValue + case "ModelRequestRateLimitEnabled": + setting.ModelRequestRateLimitEnabled = boolValue + case "StopOnSensitiveEnabled": + setting.StopOnSensitiveEnabled = boolValue + case "SMTPSSLEnabled": + common.SMTPSSLEnabled = boolValue + } + } + switch key { + case "EmailDomainWhitelist": + common.EmailDomainWhitelist = strings.Split(value, ",") + case "SMTPServer": + common.SMTPServer = value + case "SMTPPort": + intValue, _ := strconv.Atoi(value) + common.SMTPPort = intValue + case "SMTPAccount": + common.SMTPAccount = value + case "SMTPFrom": + common.SMTPFrom = value + case "SMTPToken": + common.SMTPToken = value + case "ServerAddress": + setting.ServerAddress = value + case "WorkerUrl": + setting.WorkerUrl = value + case "WorkerValidKey": + setting.WorkerValidKey = value + case "PayAddress": + setting.PayAddress = value + 
case "Chats": + err = setting.UpdateChatsByJsonString(value) + case "CustomCallbackAddress": + setting.CustomCallbackAddress = value + case "EpayId": + setting.EpayId = value + case "EpayKey": + setting.EpayKey = value + case "Price": + setting.Price, _ = strconv.ParseFloat(value, 64) + case "MinTopUp": + setting.MinTopUp, _ = strconv.Atoi(value) + case "TopupGroupRatio": + err = common.UpdateTopupGroupRatioByJSONString(value) + case "GitHubClientId": + common.GitHubClientId = value + case "GitHubClientSecret": + common.GitHubClientSecret = value + case "LinuxDOClientId": + common.LinuxDOClientId = value + case "LinuxDOClientSecret": + common.LinuxDOClientSecret = value + case "Footer": + common.Footer = value + case "SystemName": + common.SystemName = value + case "Logo": + common.Logo = value + case "WeChatServerAddress": + common.WeChatServerAddress = value + case "WeChatServerToken": + common.WeChatServerToken = value + case "WeChatAccountQRCodeImageURL": + common.WeChatAccountQRCodeImageURL = value + case "TelegramBotToken": + common.TelegramBotToken = value + case "TelegramBotName": + common.TelegramBotName = value + case "TurnstileSiteKey": + common.TurnstileSiteKey = value + case "TurnstileSecretKey": + common.TurnstileSecretKey = value + case "QuotaForNewUser": + common.QuotaForNewUser, _ = strconv.Atoi(value) + case "QuotaForInviter": + common.QuotaForInviter, _ = strconv.Atoi(value) + case "QuotaForInvitee": + common.QuotaForInvitee, _ = strconv.Atoi(value) + case "QuotaRemindThreshold": + common.QuotaRemindThreshold, _ = strconv.Atoi(value) + case "PreConsumedQuota": + common.PreConsumedQuota, _ = strconv.Atoi(value) + case "ModelRequestRateLimitCount": + setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitDurationMinutes": + setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) + case "ModelRequestRateLimitSuccessCount": + setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "RetryTimes": + common.RetryTimes, _ = strconv.Atoi(value) + case "DataExportInterval": + common.DataExportInterval, _ = strconv.Atoi(value) + case "DataExportDefaultTime": + common.DataExportDefaultTime = value + case "ModelRatio": + err = operation_setting.UpdateModelRatioByJSONString(value) + case "GroupRatio": + err = setting.UpdateGroupRatioByJSONString(value) + case "UserUsableGroups": + err = setting.UpdateUserUsableGroupsByJSONString(value) + case "CompletionRatio": + err = operation_setting.UpdateCompletionRatioByJSONString(value) + case "ModelPrice": + err = operation_setting.UpdateModelPriceByJSONString(value) + case "CacheRatio": + err = operation_setting.UpdateCacheRatioByJSONString(value) + case "TopUpLink": + common.TopUpLink = value + //case "ChatLink": + // common.ChatLink = value + //case "ChatLink2": + // common.ChatLink2 = value + case "ChannelDisableThreshold": + common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + case "QuotaPerUnit": + common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + case "SensitiveWords": + setting.SensitiveWordsFromString(value) + case "AutomaticDisableKeywords": + operation_setting.AutomaticDisableKeywordsFromString(value) + case "StreamCacheQueueLength": + setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + } + return err +} + +// handleConfigUpdate 处理分层配置更新,返回是否已处理 +func handleConfigUpdate(key, value string) bool { + parts := strings.SplitN(key, ".", 2) + if len(parts) != 2 { + return false // 不是分层配置 + } + + configName := parts[0] + configKey := parts[1] + + // 
获取配置对象 + cfg := config.GlobalConfig.Get(configName) + if cfg == nil { + return false // 未注册的配置 + } + + // 更新配置 + configMap := map[string]string{ + configKey: value, + } + config.UpdateConfigFromMap(cfg, configMap) + + return true // 已处理 +} \ No newline at end of file diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index d4199ece..b0047b70 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -19,25 +19,20 @@ const ( ModelRequestRateLimitSuccessCountMark = "MRRLS" ) -// 检查Redis中的请求限制 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { - // 如果maxCount为0,表示不限制 if maxCount == 0 { return true, nil } - // 获取当前计数 length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } - // 如果未达到限制,允许请求 if length < int64(maxCount) { return true, nil } - // 检查时间窗口 oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { @@ -49,7 +44,6 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max if err != nil { return false, err } - // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) @@ -59,9 +53,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max return true, nil } -// 记录Redis请求 func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { - // 如果maxCount为0,不记录请求 if maxCount == 0 { return } @@ -72,14 +64,12 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } -// Redis限流处理器 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB - // 1. 检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { @@ -92,9 +82,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } - //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) - // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, @@ -114,17 +102,14 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } - // 4. 处理请求 c.Next() - // 5. 如果请求成功,记录成功请求 if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } -// 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) @@ -133,15 +118,12 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId - // 1. 
检查总请求数限制(当totalMaxCount为0时跳过) if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } - // 2. 检查成功请求数限制 - // 使用一个临时key来检查限制,这样可以避免实际记录 checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) @@ -149,54 +131,47 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) return } - // 3. 处理请求 c.Next() - // 4. 如果请求成功,记录到实际的成功请求计数中 if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } -// ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { - // 在每个请求时检查是否启用限流 if !setting.ModelRequestRateLimitEnabled { c.Next() return } - // 计算通用限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - // 获取用户组 group := c.GetString("token_group") if group == "" { group = c.GetString("group") } if group == "" { - group = "default" // 默认组 + group = "default" } - // 尝试获取用户组特定的限制 + finalTotalCount := setting.ModelRequestRateLimitCount + finalSuccessCount := setting.ModelRequestRateLimitSuccessCount + foundGroupLimit := false + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) - - // 确定最终的限制值 - finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制 - finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制 - if found { - // 如果找到用户组特定限制,则使用它们 finalTotalCount = groupTotalCount finalSuccessCount = groupSuccessCount + foundGroupLimit = true common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) - } else { + } + + if !foundGroupLimit { common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } - // 根据存储类型选择并执行限流处理器,传入最终确定的限制值 if common.RedisEnabled { redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } else { diff --git a/model/option.go b/model/option.go index 1f5fb3aa..79556737 100644 --- a/model/option.go +++ b/model/option.go @@ -94,11 +94,12 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + jsonBytes, _ := json.Marshal(map[string][2]int{}) + common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes) common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() - common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值 common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -153,47 +154,31 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - originalValue := value // 保存原始值以备后用 + originalValue := value - // Validate and format specific keys before 
saving - if key == setting.ModelRequestRateLimitGroupKey { + if key == "ModelRequestRateLimitGroup" { var cfg map[string][2]int - // Validate the JSON structure first using the original value err := json.Unmarshal([]byte(originalValue), &cfg) if err != nil { - // 提供更具体的错误信息 return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) } - // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。 - // if !isValidModelRequestRateLimitGroupConfig(cfg) { - // return fmt.Errorf("无效的配置值 for %s", key) - // } - // If valid, format the JSON before saving formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") if marshalErr != nil { - // This should ideally not happen if validation passed, but handle defensively return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) } - value = string(formattedValueBytes) // Use formatted JSON for saving and memory update + value = string(formattedValueBytes) } - // Save to database option := Option{ Key: key, } - // https://gorm.io/docs/update.html#Save-All-Fields DB.FirstOrCreate(&option, Option{Key: key}) option.Value = value - // Save is a combination function. - // If save value does not contain primary key, it will execute Create, - // otherwise it will execute Update (with all fields). if err := DB.Save(&option).Error; err != nil { - return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文 + return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) } - // Update OptionMap in memory using the potentially formatted value - // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新 return updateOptionMap(key, value) } @@ -370,6 +355,8 @@ func updateOptionMap(key string, value string) (err error) { setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitGroup": + err = setting.UpdateModelRequestRateLimitGroup(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": @@ -404,15 +391,6 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) - case setting.ModelRequestRateLimitGroupKey: - // Use the (potentially formatted) value passed from UpdateOption - // to update the actual configuration in memory. - // This is the single point where the memory state for this specific setting is updated. 
- err = setting.UpdateModelRequestRateLimitGroupConfig(value) - if err != nil { - // 添加错误上下文 - err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err) - } } return err } @@ -440,4 +418,4 @@ func handleConfigUpdate(key, value string) bool { config.UpdateConfigFromMap(cfg, configMap) return true // 已处理 -} +} \ No newline at end of file diff --git a/setting/rate_limit.go b/setting/rate_limit.go index c83885a6..5be75cc1 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -11,24 +11,17 @@ var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 +var ModelRequestRateLimitGroup map[string][2]int -// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键 -const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup" - -// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置 -// map[groupName][2]int{totalCount, successCount} -var ModelRequestRateLimitGroupConfig map[string][2]int var ModelRequestRateLimitGroupMutex sync.RWMutex -// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置 -func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { +func UpdateModelRequestRateLimitGroup(jsonStr string) error { ModelRequestRateLimitGroupMutex.Lock() defer ModelRequestRateLimitGroupMutex.Unlock() var newConfig map[string][2]int if jsonStr == "" || jsonStr == "{}" { - // 如果配置为空或空JSON对象,则清空内存配置 - ModelRequestRateLimitGroupConfig = make(map[string][2]int) + ModelRequestRateLimitGroup = make(map[string][2]int) common.SysLog("Model request rate limit group config cleared") return nil } @@ -38,37 +31,19 @@ func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) } - // 校验配置值 - for group, limits := range newConfig { - if len(limits) != 2 { - return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group) - } - if limits[1] <= 0 { // successCount must be greater than 0 - return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group) - } - if limits[0] < 0 { // totalCount can be 0 (no limit) or positive - return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group) - } - if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount - return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group) - } - } - - ModelRequestRateLimitGroupConfig = newConfig - common.SysLog("Model request rate limit group config updated") + ModelRequestRateLimitGroup = newConfig return nil } -// GetGroupRateLimit 安全地获取指定用户组的速率限制值 func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { ModelRequestRateLimitGroupMutex.RLock() defer ModelRequestRateLimitGroupMutex.RUnlock() - if ModelRequestRateLimitGroupConfig == nil { - return 0, 0, false // 配置尚未初始化 + if ModelRequestRateLimitGroup == nil { + return 0, 0, false } - limits, found := ModelRequestRateLimitGroupConfig[group] + limits, found := ModelRequestRateLimitGroup[group] if !found { return 0, 0, false } diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index ad6b53da..7e206672 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -9,59 +9,59 @@ import RequestRateLimit from 
'../pages/Setting/RateLimit/SettingsRequestRateLimi const RateLimitSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: 0, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: {}, + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', }); - + let [loading, setLoading] = useState(false); - + const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true' ? true : false; - } else { - newInputs[item.key] = item.value; - } - }); - - setInputs(newInputs); - } else { - showError(message); - } + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + // 检查 key 是否在初始 inputs 中定义 + if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true'; + } else { + newInputs[item.key] = item.value; + } + } + }); + setInputs(newInputs); + } else { + showError(message); + } }; async function onRefresh() { - try { - setLoading(true); - await getOptions(); - // showSuccess('刷新成功'); - } catch (error) { - showError('刷新失败'); - } finally { - setLoading(false); - } + try { + setLoading(true); + await getOptions(); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } } - + useEffect(() => { - onRefresh(); + onRefresh(); }, []); - + return ( - <> - - {/* AI请求速率限制 */} - - - - - + <> + + + + + + ); -}; - -export default RateLimitSetting; + }; + + export default RateLimitSetting; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index ec1c2158..2434020e 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -14,209 +14,201 @@ export default function RequestRateLimit(props) { const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: -1, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值 + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); - + function onSubmit() { - const updateArray = compareObjects(inputs, inputsRow); - if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); - const requestQueue = updateArray.map((item) => { - let value = ''; - if (typeof inputs[item.key] === 'boolean') { - value = String(inputs[item.key]); - } else { - value = inputs[item.key]; - } - // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串 - if (item.key === 'ModelRequestRateLimitGroup') { - try { - JSON.parse(value); - } catch (e) { - showError(t('用户组速率限制配置不是有效的 JSON 格式!')); - // 阻止请求发送 - return Promise.reject('Invalid JSON format'); - } - } - return 
API.put('/api/option/', { - key: item.key, - value, - }); - }); - - // 过滤掉无效的请求(例如,无效的 JSON) - const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); - - if (validRequests.length === 0 && requestQueue.length > 0) { - // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行 - return; - } - - setLoading(true); - Promise.all(validRequests) - .then((res) => { - if (validRequests.length === 1) { - if (res.includes(undefined)) return; - } else if (validRequests.length > 1) { - if (res.includes(undefined)) - return showError(t('部分保存失败,请重试')); - } - showSuccess(t('保存成功')); - props.refresh(); - // 更新 inputsRow 以反映保存后的状态 - setInputsRow(structuredClone(inputs)); - }) - .catch((error) => { - // 检查是否是由于无效 JSON 导致的错误 - if (error !== 'Invalid JSON format') { - showError(t('保存失败,请重试')); - } - }) - .finally(() => { - setLoading(false); - }); + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + if (item.key === 'ModelRequestRateLimitGroup') { + try { + JSON.parse(value); + } catch (e) { + showError(t('用户组速率限制配置不是有效的 JSON 格式!')); + return Promise.reject('Invalid JSON format'); + } + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + + const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); + + if (validRequests.length === 0 && requestQueue.length > 0) { + return; + } + + setLoading(true); + Promise.all(validRequests) + .then((res) => { + if (validRequests.length === 1) { + if (res.includes(undefined)) return; + } else if (validRequests.length > 1) { + if (res.includes(undefined)) + return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + setInputsRow(structuredClone(inputs)); + }) + .catch((error) => { + if (error !== 'Invalid JSON format') { + showError(t('保存失败,请重试')); + } + }) + .finally(() => { + setLoading(false); + }); } - + useEffect(() => { - const currentInputs = {}; - for (let key in props.options) { - if (Object.keys(inputs).includes(key)) { - currentInputs[key] = props.options[key]; - } - } - setInputs(currentInputs); - setInputsRow(structuredClone(currentInputs)); - // 检查 refForm.current 是否存在 - if (refForm.current) { - refForm.current.setValues(currentInputs); - } - }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定 - + const currentInputs = {}; + for (let key in props.options) { + if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查 + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + if (refForm.current) { + refForm.current.setValues(currentInputs); + } + }, [props.options]); + return ( - <> - -
(refForm.current = formAPI)} - style={{ marginBottom: 15 }} - > - - - - { - setInputs({ - ...inputs, - ModelRequestRateLimitEnabled: value, - }); - }} - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitDurationMinutes: String(value), - }) - } - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitCount: String(value), - }) - } - /> - - - - setInputs({ - ...inputs, - ModelRequestRateLimitSuccessCount: String(value), - }) - } - /> - - - {/* 用户组速率限制配置项 */} - - - -

-              {t('说明:')}
-              • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
-              • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
-              • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
-              • {t('此配置将优先于上方的全局限制设置。')}
-              • {t('未在此处配置的用户组将使用全局限制。')}
-              • {t('限制周期统一使用上方配置的“限制周期”值。')}
-              • {t('输入无效的 JSON 将无法保存。')}
- - } - autosize={{ minRows: 5, maxRows: 15 }} - style={{ fontFamily: 'monospace' }} - onChange={(value) => { - setInputs({ - ...inputs, - ModelRequestRateLimitGroup: value, // 直接更新字符串值 - }); - }} - /> - -
- - - -
-
-
- + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + +

+              {t('说明:')}
+              • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
+              • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
+              • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
+              • {t('此配置将优先于上方的全局限制设置。')}
+              • {t('未在此处配置的用户组将使用全局限制。')}
+              • {t('限制周期统一使用上方配置的“限制周期”值。')}
+              • {t('输入无效的 JSON 将无法保存。')}
+ + } + autosize={{ minRows: 5, maxRows: 15 }} + style={{ fontFamily: 'monospace' }} + onChange={(value) => { + setInputs({ + ...inputs, + ModelRequestRateLimitGroup: value, + }); + }} + /> + +
+ + + +
+
+
+ ); -} + } From b7fd1e4a203fb24d2b5a332ea4ec8abe3cdcecac Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 12:55:48 +0800 Subject: [PATCH 011/105] fix: Redis limit ignoring max eq 0 --- middleware/model-rate-limit.go | 36 ++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 581dc451..f81160fc 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -93,25 +93,27 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g } //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 - totalKey := fmt.Sprintf("rateLimit:%s", userId) - // 初始化 - tb := limiter.New(ctx, rdb) - allowed, err = tb.Allow( - ctx, - totalKey, - limiter.WithCapacity(int64(totalMaxCount)*duration), - limiter.WithRate(int64(totalMaxCount)), - limiter.WithRequested(duration), - ) + if totalMaxCount > 0 { + totalKey := fmt.Sprintf("rateLimit:%s", userId) + // 初始化 + tb := limiter.New(ctx, rdb) + allowed, err = tb.Allow( + ctx, + totalKey, + limiter.WithCapacity(int64(totalMaxCount)*duration), + limiter.WithRate(int64(totalMaxCount)), + limiter.WithRequested(duration), + ) - if err != nil { - fmt.Println("检查总请求数限制失败:", err.Error()) - abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") - return - } + if err != nil { + fmt.Println("检查总请求数限制失败:", err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") + return + } - if !allowed { - abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) + if !allowed { + abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) + } } // 4. 
处理请求 From 1e1d24d1b075042473902991cbc3610f6c8bfff8 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 17:57:02 +0800 Subject: [PATCH 012/105] fix: rm debug file --- .old/option.go | 402 ------------------------------------------------- 1 file changed, 402 deletions(-) delete mode 100644 .old/option.go diff --git a/.old/option.go b/.old/option.go deleted file mode 100644 index f80f5cb3..00000000 --- a/.old/option.go +++ /dev/null @@ -1,402 +0,0 @@ -package model - -import ( - "one-api/common" - "one-api/setting" - "one-api/setting/config" - "one-api/setting/operation_setting" - "strconv" - "strings" - "time" -) - -type Option struct { - Key string `json:"key" gorm:"primaryKey"` - Value string `json:"value"` -} - -func AllOption() ([]*Option, error) { - var options []*Option - var err error - err = DB.Find(&options).Error - return options, err -} - -func InitOptionMap() { - common.OptionMapRWMutex.Lock() - common.OptionMap = make(map[string]string) - - // 添加原有的系统配置 - common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) - common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) - common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) - common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) - common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) - common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) - common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) - common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) - common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled) - common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) - common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) - common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) - common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) - common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) - common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) - common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) - common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) - common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) - common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) - common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) - common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) - common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) - common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) - common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled) - common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") - common.OptionMap["SMTPServer"] = "" - common.OptionMap["SMTPFrom"] = "" - common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) - common.OptionMap["SMTPAccount"] = "" - 
common.OptionMap["SMTPToken"] = "" - common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled) - common.OptionMap["Notice"] = "" - common.OptionMap["About"] = "" - common.OptionMap["HomePageContent"] = "" - common.OptionMap["Footer"] = common.Footer - common.OptionMap["SystemName"] = common.SystemName - common.OptionMap["Logo"] = common.Logo - common.OptionMap["ServerAddress"] = "" - common.OptionMap["WorkerUrl"] = setting.WorkerUrl - common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey - common.OptionMap["PayAddress"] = "" - common.OptionMap["CustomCallbackAddress"] = "" - common.OptionMap["EpayId"] = "" - common.OptionMap["EpayKey"] = "" - common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) - common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) - common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() - common.OptionMap["Chats"] = setting.Chats2JsonString() - common.OptionMap["GitHubClientId"] = "" - common.OptionMap["GitHubClientSecret"] = "" - common.OptionMap["TelegramBotToken"] = "" - common.OptionMap["TelegramBotName"] = "" - common.OptionMap["WeChatServerAddress"] = "" - common.OptionMap["WeChatServerToken"] = "" - common.OptionMap["WeChatAccountQRCodeImageURL"] = "" - common.OptionMap["TurnstileSiteKey"] = "" - common.OptionMap["TurnstileSecretKey"] = "" - common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) - common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) - common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) - common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) - common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) - common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) - common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() - common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() - common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() - common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() - common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() - common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() - common.OptionMap["TopUpLink"] = common.TopUpLink - //common.OptionMap["ChatLink"] = common.ChatLink - //common.OptionMap["ChatLink2"] = common.ChatLink2 - common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) - common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) - common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) - common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime - common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) - common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled) - common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled) - common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled) - common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) - common.OptionMap["MjActionCheckSuccessEnabled"] = 
strconv.FormatBool(setting.MjActionCheckSuccessEnabled) - common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) - common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled) - common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled) - common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) - common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) - common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) - common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() - common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) - common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() - - // 自动添加所有注册的模型配置 - modelConfigs := config.GlobalConfig.ExportAllConfigs() - for k, v := range modelConfigs { - common.OptionMap[k] = v - } - - common.OptionMapRWMutex.Unlock() - loadOptionsFromDatabase() -} - -func loadOptionsFromDatabase() { - options, _ := AllOption() - for _, option := range options { - err := updateOptionMap(option.Key, option.Value) - if err != nil { - common.SysError("failed to update option map: " + err.Error()) - } - } -} - -func SyncOptions(frequency int) { - for { - time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") - loadOptionsFromDatabase() - } -} - -func UpdateOption(key string, value string) error { - // Save to database first - option := Option{ - Key: key, - } - // https://gorm.io/docs/update.html#Save-All-Fields - DB.FirstOrCreate(&option, Option{Key: key}) - option.Value = value - // Save is a combination function. - // If save value does not contain primary key, it will execute Create, - // otherwise it will execute Update (with all fields). - DB.Save(&option) - // Update OptionMap - return updateOptionMap(key, value) -} - -func updateOptionMap(key string, value string) (err error) { - common.OptionMapRWMutex.Lock() - defer common.OptionMapRWMutex.Unlock() - common.OptionMap[key] = value - - // 检查是否是模型配置 - 使用更规范的方式处理 - if handleConfigUpdate(key, value) { - return nil // 已由配置系统处理 - } - - // 处理传统配置项... 
- if strings.HasSuffix(key, "Permission") { - intValue, _ := strconv.Atoi(value) - switch key { - case "FileUploadPermission": - common.FileUploadPermission = intValue - case "FileDownloadPermission": - common.FileDownloadPermission = intValue - case "ImageUploadPermission": - common.ImageUploadPermission = intValue - case "ImageDownloadPermission": - common.ImageDownloadPermission = intValue - } - } - if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { - boolValue := value == "true" - switch key { - case "PasswordRegisterEnabled": - common.PasswordRegisterEnabled = boolValue - case "PasswordLoginEnabled": - common.PasswordLoginEnabled = boolValue - case "EmailVerificationEnabled": - common.EmailVerificationEnabled = boolValue - case "GitHubOAuthEnabled": - common.GitHubOAuthEnabled = boolValue - case "LinuxDOOAuthEnabled": - common.LinuxDOOAuthEnabled = boolValue - case "WeChatAuthEnabled": - common.WeChatAuthEnabled = boolValue - case "TelegramOAuthEnabled": - common.TelegramOAuthEnabled = boolValue - case "TurnstileCheckEnabled": - common.TurnstileCheckEnabled = boolValue - case "RegisterEnabled": - common.RegisterEnabled = boolValue - case "EmailDomainRestrictionEnabled": - common.EmailDomainRestrictionEnabled = boolValue - case "EmailAliasRestrictionEnabled": - common.EmailAliasRestrictionEnabled = boolValue - case "AutomaticDisableChannelEnabled": - common.AutomaticDisableChannelEnabled = boolValue - case "AutomaticEnableChannelEnabled": - common.AutomaticEnableChannelEnabled = boolValue - case "LogConsumeEnabled": - common.LogConsumeEnabled = boolValue - case "DisplayInCurrencyEnabled": - common.DisplayInCurrencyEnabled = boolValue - case "DisplayTokenStatEnabled": - common.DisplayTokenStatEnabled = boolValue - case "DrawingEnabled": - common.DrawingEnabled = boolValue - case "TaskEnabled": - common.TaskEnabled = boolValue - case "DataExportEnabled": - common.DataExportEnabled = boolValue - case "DefaultCollapseSidebar": - common.DefaultCollapseSidebar = boolValue - case "MjNotifyEnabled": - setting.MjNotifyEnabled = boolValue - case "MjAccountFilterEnabled": - setting.MjAccountFilterEnabled = boolValue - case "MjModeClearEnabled": - setting.MjModeClearEnabled = boolValue - case "MjForwardUrlEnabled": - setting.MjForwardUrlEnabled = boolValue - case "MjActionCheckSuccessEnabled": - setting.MjActionCheckSuccessEnabled = boolValue - case "CheckSensitiveEnabled": - setting.CheckSensitiveEnabled = boolValue - case "DemoSiteEnabled": - operation_setting.DemoSiteEnabled = boolValue - case "SelfUseModeEnabled": - operation_setting.SelfUseModeEnabled = boolValue - case "CheckSensitiveOnPromptEnabled": - setting.CheckSensitiveOnPromptEnabled = boolValue - case "ModelRequestRateLimitEnabled": - setting.ModelRequestRateLimitEnabled = boolValue - case "StopOnSensitiveEnabled": - setting.StopOnSensitiveEnabled = boolValue - case "SMTPSSLEnabled": - common.SMTPSSLEnabled = boolValue - } - } - switch key { - case "EmailDomainWhitelist": - common.EmailDomainWhitelist = strings.Split(value, ",") - case "SMTPServer": - common.SMTPServer = value - case "SMTPPort": - intValue, _ := strconv.Atoi(value) - common.SMTPPort = intValue - case "SMTPAccount": - common.SMTPAccount = value - case "SMTPFrom": - common.SMTPFrom = value - case "SMTPToken": - common.SMTPToken = value - case "ServerAddress": - setting.ServerAddress = value - case "WorkerUrl": - setting.WorkerUrl = value - case "WorkerValidKey": - setting.WorkerValidKey = value - case "PayAddress": - setting.PayAddress = value - 
case "Chats": - err = setting.UpdateChatsByJsonString(value) - case "CustomCallbackAddress": - setting.CustomCallbackAddress = value - case "EpayId": - setting.EpayId = value - case "EpayKey": - setting.EpayKey = value - case "Price": - setting.Price, _ = strconv.ParseFloat(value, 64) - case "MinTopUp": - setting.MinTopUp, _ = strconv.Atoi(value) - case "TopupGroupRatio": - err = common.UpdateTopupGroupRatioByJSONString(value) - case "GitHubClientId": - common.GitHubClientId = value - case "GitHubClientSecret": - common.GitHubClientSecret = value - case "LinuxDOClientId": - common.LinuxDOClientId = value - case "LinuxDOClientSecret": - common.LinuxDOClientSecret = value - case "Footer": - common.Footer = value - case "SystemName": - common.SystemName = value - case "Logo": - common.Logo = value - case "WeChatServerAddress": - common.WeChatServerAddress = value - case "WeChatServerToken": - common.WeChatServerToken = value - case "WeChatAccountQRCodeImageURL": - common.WeChatAccountQRCodeImageURL = value - case "TelegramBotToken": - common.TelegramBotToken = value - case "TelegramBotName": - common.TelegramBotName = value - case "TurnstileSiteKey": - common.TurnstileSiteKey = value - case "TurnstileSecretKey": - common.TurnstileSecretKey = value - case "QuotaForNewUser": - common.QuotaForNewUser, _ = strconv.Atoi(value) - case "QuotaForInviter": - common.QuotaForInviter, _ = strconv.Atoi(value) - case "QuotaForInvitee": - common.QuotaForInvitee, _ = strconv.Atoi(value) - case "QuotaRemindThreshold": - common.QuotaRemindThreshold, _ = strconv.Atoi(value) - case "PreConsumedQuota": - common.PreConsumedQuota, _ = strconv.Atoi(value) - case "ModelRequestRateLimitCount": - setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) - case "ModelRequestRateLimitDurationMinutes": - setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) - case "ModelRequestRateLimitSuccessCount": - setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) - case "RetryTimes": - common.RetryTimes, _ = strconv.Atoi(value) - case "DataExportInterval": - common.DataExportInterval, _ = strconv.Atoi(value) - case "DataExportDefaultTime": - common.DataExportDefaultTime = value - case "ModelRatio": - err = operation_setting.UpdateModelRatioByJSONString(value) - case "GroupRatio": - err = setting.UpdateGroupRatioByJSONString(value) - case "UserUsableGroups": - err = setting.UpdateUserUsableGroupsByJSONString(value) - case "CompletionRatio": - err = operation_setting.UpdateCompletionRatioByJSONString(value) - case "ModelPrice": - err = operation_setting.UpdateModelPriceByJSONString(value) - case "CacheRatio": - err = operation_setting.UpdateCacheRatioByJSONString(value) - case "TopUpLink": - common.TopUpLink = value - //case "ChatLink": - // common.ChatLink = value - //case "ChatLink2": - // common.ChatLink2 = value - case "ChannelDisableThreshold": - common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) - case "QuotaPerUnit": - common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) - case "SensitiveWords": - setting.SensitiveWordsFromString(value) - case "AutomaticDisableKeywords": - operation_setting.AutomaticDisableKeywordsFromString(value) - case "StreamCacheQueueLength": - setting.StreamCacheQueueLength, _ = strconv.Atoi(value) - } - return err -} - -// handleConfigUpdate 处理分层配置更新,返回是否已处理 -func handleConfigUpdate(key, value string) bool { - parts := strings.SplitN(key, ".", 2) - if len(parts) != 2 { - return false // 不是分层配置 - } - - configName := parts[0] - configKey := parts[1] - - // 
获取配置对象 - cfg := config.GlobalConfig.Get(configName) - if cfg == nil { - return false // 未注册的配置 - } - - // 更新配置 - configMap := map[string]string{ - configKey: value, - } - config.UpdateConfigFromMap(cfg, configMap) - - return true // 已处理 -} \ No newline at end of file From 1513ed78477044999e066d5eb3b1fc1762dce531 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 19:32:22 +0800 Subject: [PATCH 013/105] =?UTF-8?q?refactor:=20=E8=B0=83=E6=95=B4=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E7=AC=A6=E5=90=88=E9=A1=B9=E7=9B=AE=E7=8E=B0?= =?UTF-8?q?=E6=9C=89=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 54 +++++++++++++++++--------- model/option.go | 34 +++++----------- setting/rate_limit.go | 37 ++++++++---------- web/src/components/RateLimitSetting.js | 6 ++- 4 files changed, 65 insertions(+), 66 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index b0047b70..1ca5ace6 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/limiter" + "one-api/constant" "one-api/setting" "strconv" "time" @@ -19,20 +20,25 @@ const ( ModelRequestRateLimitSuccessCountMark = "MRRLS" ) +// 检查Redis中的请求限制 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { + // 如果maxCount为0,表示不限制 if maxCount == 0 { return true, nil } + // 获取当前计数 length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } + // 如果未达到限制,允许请求 if length < int64(maxCount) { return true, nil } + // 检查时间窗口 oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { @@ -44,6 +50,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max if err != nil { return false, err } + // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) @@ -53,7 +60,9 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max return true, nil } +// 记录Redis请求 func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { + // 如果maxCount为0,不记录请求 if maxCount == 0 { return } @@ -64,12 +73,14 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } +// Redis限流处理器 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB + // 1. 
检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { @@ -82,7 +93,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } + //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) + // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, @@ -102,14 +115,17 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } + // 4. 处理请求 c.Next() + // 5. 如果请求成功,记录成功请求 if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } +// 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) @@ -118,12 +134,15 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId + // 1. 检查总请求数限制(当totalMaxCount为0时跳过) if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } + // 2. 检查成功请求数限制 + // 使用一个临时key来检查限制,这样可以避免实际记录 checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) @@ -131,51 +150,48 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) return } + // 3. 处理请求 c.Next() + // 4. 
如果请求成功,记录到实际的成功请求计数中 if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } +// ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { + // 在每个请求时检查是否启用限流 if !setting.ModelRequestRateLimitEnabled { c.Next() return } + // 计算限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) + totalMaxCount := setting.ModelRequestRateLimitCount + successMaxCount := setting.ModelRequestRateLimitSuccessCount + // 获取分组 group := c.GetString("token_group") if group == "" { - group = c.GetString("group") - } - if group == "" { - group = "default" + group = c.GetString(constant.ContextKeyUserGroup) } - finalTotalCount := setting.ModelRequestRateLimitCount - finalSuccessCount := setting.ModelRequestRateLimitSuccessCount - foundGroupLimit := false - + //获取分组的限流配置 groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) if found { - finalTotalCount = groupTotalCount - finalSuccessCount = groupSuccessCount - foundGroupLimit = true - common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) - } - - if !foundGroupLimit { - common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) + totalMaxCount = groupTotalCount + successMaxCount = groupSuccessCount } + // 根据存储类型选择并执行限流处理器 if common.RedisEnabled { - redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } else { - memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } } -} +} \ No newline at end of file diff --git a/model/option.go b/model/option.go index 79556737..e9c129e1 100644 --- a/model/option.go +++ b/model/option.go @@ -1,8 +1,6 @@ package model import ( - "encoding/json" - "fmt" "one-api/common" "one-api/setting" "one-api/setting/config" @@ -94,8 +92,7 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - jsonBytes, _ := json.Marshal(map[string][2]int{}) - common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes) + common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString() common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() @@ -154,31 +151,18 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - originalValue := value - - if key == "ModelRequestRateLimitGroup" { - var cfg map[string][2]int - err := json.Unmarshal([]byte(originalValue), &cfg) - if err != nil { - return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) - } - - formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") - if marshalErr != nil { - return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) - } - value = string(formattedValueBytes) - } - + // Save to database first option := 
Option{ Key: key, } + // https://gorm.io/docs/update.html#Save-All-Fields DB.FirstOrCreate(&option, Option{Key: key}) option.Value = value - if err := DB.Save(&option).Error; err != nil { - return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) - } - + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap return updateOptionMap(key, value) } @@ -356,7 +340,7 @@ func updateOptionMap(key string, value string) (err error) { case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) case "ModelRequestRateLimitGroup": - err = setting.UpdateModelRequestRateLimitGroup(value) + err = setting.UpdateModelRequestRateLimitGroupByJSONString(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 5be75cc1..aab030cd 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -2,7 +2,6 @@ package setting import ( "encoding/json" - "fmt" "one-api/common" "sync" ) @@ -11,33 +10,31 @@ var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 -var ModelRequestRateLimitGroup map[string][2]int +var ModelRequestRateLimitGroup = map[string][2]int{} +var ModelRequestRateLimitMutex sync.RWMutex -var ModelRequestRateLimitGroupMutex sync.RWMutex +func ModelRequestRateLimitGroup2JSONString() string { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() -func UpdateModelRequestRateLimitGroup(jsonStr string) error { - ModelRequestRateLimitGroupMutex.Lock() - defer ModelRequestRateLimitGroupMutex.Unlock() - - var newConfig map[string][2]int - if jsonStr == "" || jsonStr == "{}" { - ModelRequestRateLimitGroup = make(map[string][2]int) - common.SysLog("Model request rate limit group config cleared") - return nil - } - - err := json.Unmarshal([]byte(jsonStr), &newConfig) + jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) + common.SysError("error marshalling model ratio: " + err.Error()) } + return string(jsonBytes) +} - ModelRequestRateLimitGroup = newConfig - return nil +func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() + + ModelRequestRateLimitGroup = make(map[string][2]int) + return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup) } func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { - ModelRequestRateLimitGroupMutex.RLock() - defer ModelRequestRateLimitGroupMutex.RUnlock() + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() if ModelRequestRateLimitGroup == nil { return 0, 0, false diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 7e206672..309b94de 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -24,7 +24,6 @@ const RateLimitSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - // 检查 key 是否在初始 inputs 中定义 if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { if (item.key.endsWith('Enabled')) { newInputs[item.key] = item.value === 'true'; @@ -33,6 +32,7 @@ const 
RateLimitSetting = () => { } } }); + setInputs(newInputs); } else { showError(message); @@ -42,6 +42,7 @@ const RateLimitSetting = () => { try { setLoading(true); await getOptions(); + // showSuccess('刷新成功'); } catch (error) { showError('刷新失败'); } finally { @@ -56,6 +57,7 @@ const RateLimitSetting = () => { return ( <> + {/* AI请求速率限制 */} @@ -64,4 +66,4 @@ const RateLimitSetting = () => { ); }; - export default RateLimitSetting; + export default RateLimitSetting; \ No newline at end of file From 88ed83f41927eacc43526b5739592016d2ae4c10 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 20:00:06 +0800 Subject: [PATCH 014/105] feat: Modellimitgroup check --- controller/option.go | 9 +++++++++ setting/rate_limit.go | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/controller/option.go b/controller/option.go index 81ef463c..250f16bb 100644 --- a/controller/option.go +++ b/controller/option.go @@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) { }) return } + case "ModelRequestRateLimitGroup": + err = setting.CheckModelRequestRateLimitGroup(option.Value) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } } err = model.UpdateOption(option.Key, option.Value) diff --git a/setting/rate_limit.go b/setting/rate_limit.go index aab030cd..14680791 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -2,6 +2,7 @@ package setting import ( "encoding/json" + "fmt" "one-api/common" "sync" ) @@ -46,3 +47,18 @@ func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) } return limits[0], limits[1], true } + +func CheckModelRequestRateLimitGroup(jsonStr string) error { + checkModelRequestRateLimitGroup := make(map[string][2]int) + err := json.Unmarshal([]byte(jsonStr), &checkModelRequestRateLimitGroup) + if err != nil { + return err + } + for group, limits := range checkModelRequestRateLimitGroup { + if limits[0] < 0 || limits[1] < 0 { + return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) + } + } + + return nil +} From 1cb4d750e471649da8fa5824942c43bffdc4705e Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 22:06:16 +0800 Subject: [PATCH 015/105] =?UTF-8?q?feat:=20=E5=88=86=E7=BB=84=E9=80=9F?= =?UTF-8?q?=E7=8E=87=E5=89=8D=E7=AB=AF=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/RateLimitSetting.js | 16 ++-- .../RateLimit/SettingsRequestRateLimit.js | 83 ++++++++----------- 2 files changed, 45 insertions(+), 54 deletions(-) diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 309b94de..4671317f 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -13,7 +13,7 @@ const RateLimitSetting = () => { ModelRequestRateLimitCount: 0, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', + ModelRequestRateLimitGroup: '', }); let [loading, setLoading] = useState(false); @@ -24,12 +24,14 @@ const RateLimitSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true'; - } else { - newInputs[item.key] = item.value; - } + if (item.key === 'ModelRequestRateLimitGroup') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + + if 
(item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; } }); diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 2434020e..b77c1e6a 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -6,6 +6,7 @@ import { showError, showSuccess, showWarning, + verifyJSON, } from '../../../helpers'; import { useTranslation } from 'react-i18next'; @@ -18,7 +19,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', + ModelRequestRateLimitGroup: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -33,43 +34,32 @@ export default function RequestRateLimit(props) { } else { value = inputs[item.key]; } - if (item.key === 'ModelRequestRateLimitGroup') { - try { - JSON.parse(value); - } catch (e) { - showError(t('用户组速率限制配置不是有效的 JSON 格式!')); - return Promise.reject('Invalid JSON format'); - } - } return API.put('/api/option/', { key: item.key, value, }); }); - - const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); - - if (validRequests.length === 0 && requestQueue.length > 0) { - return; - } - setLoading(true); - Promise.all(validRequests) + Promise.all(requestQueue) .then((res) => { - if (validRequests.length === 1) { + if (requestQueue.length === 1) { if (res.includes(undefined)) return; - } else if (validRequests.length > 1) { + } else if (requestQueue.length > 1) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } + + for (let i = 0; i < res.length; i++) { + if (!res[i].data.success) { + return showError(res[i].data.message); + } + } + showSuccess(t('保存成功')); props.refresh(); - setInputsRow(structuredClone(inputs)); }) - .catch((error) => { - if (error !== 'Invalid JSON format') { - showError(t('保存失败,请重试')); - } + .catch(() => { + showError(t('保存失败,请重试')); }) .finally(() => { setLoading(false); @@ -79,15 +69,13 @@ export default function RequestRateLimit(props) { useEffect(() => { const currentInputs = {}; for (let key in props.options) { - if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查 + if (Object.keys(inputs).includes(key)) { currentInputs[key] = props.options[key]; } } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - if (refForm.current) { refForm.current.setValues(currentInputs); - } }, [props.options]); return ( @@ -168,40 +156,41 @@ export default function RequestRateLimit(props) { />
- - + + verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} extraText={
-

-              {t('说明:')}
+              {t('说明:')}
-              • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
-              • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
-              • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
-              • {t('此配置将优先于上方的全局限制设置。')}
-              • {t('未在此处配置的用户组将使用全局限制。')}
+              • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
+              • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
+              • {t('分组速率配置优先级高于全局速率限制。')}
               • {t('限制周期统一使用上方配置的“限制周期”值。')}
-              • {t('输入无效的 JSON 将无法保存。')}
} - autosize={{ minRows: 5, maxRows: 15 }} - style={{ fontFamily: 'monospace' }} onChange={(value) => { - setInputs({ - ...inputs, - ModelRequestRateLimitGroup: value, - }); + setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); }} />
- + @@ -211,4 +200,4 @@ export default function RequestRateLimit(props) { ); - } + } \ No newline at end of file From 0be3678c9ca8d687920ba52ff7d17d65afba23ca Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:41:43 +0800 Subject: [PATCH 016/105] =?UTF-8?q?fix:=20=E8=AF=B7=E6=B1=82=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E6=95=B0=E5=BF=85=E9=A1=BB=E5=A4=A7=E4=BA=8E=E7=AD=89?= =?UTF-8?q?=E4=BA=8E1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setting/rate_limit.go | 2 +- web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 14680791..53b53f88 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -55,7 +55,7 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error { return err } for group, limits := range checkModelRequestRateLimitGroup { - if limits[0] < 0 || limits[1] < 0 { + if limits[0] < 0 || limits[1] < 1 { return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) } } diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index b77c1e6a..ae54b1ef 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -179,6 +179,7 @@ export default function RequestRateLimit(props) {
               • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
               • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
+              • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1')}
               • {t('分组速率配置优先级高于全局速率限制。')}
               • {t('限制周期统一使用上方配置的“限制周期”值。')}
From 1c67dd3c31b010b808c2db63117cb00243f2f544 Mon Sep 17 00:00:00 2001 From: "Apple\\Apple" Date: Mon, 5 May 2025 23:44:30 +0800 Subject: [PATCH 017/105] =?UTF-8?q?=F0=9F=93=95docs:=20Update=20the=20cont?= =?UTF-8?q?ent=20in=20`README.en.md`=20and=20the=20structure=20of=20the=20?= =?UTF-8?q?docs=20directory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.en.md | 237 ++++++++++----------- README.md | 2 +- BT.md => docs/installation/BT.md | 6 +- Midjourney.md => docs/models/Midjourney.md | 0 Rerank.md => docs/models/Rerank.md | 0 Suno.md => docs/models/Suno.md | 0 6 files changed, 116 insertions(+), 129 deletions(-) rename BT.md => docs/installation/BT.md (98%) rename Midjourney.md => docs/models/Midjourney.md (100%) rename Rerank.md => docs/models/Rerank.md (100%) rename Suno.md => docs/models/Suno.md (100%) diff --git a/README.en.md b/README.en.md index c3be8381..23fdbe1f 100644 --- a/README.en.md +++ b/README.en.md @@ -1,10 +1,13 @@ +

+ 中文 | English +

![new-api](/web/public/logo.png) # New API -🍥 Next Generation LLM Gateway and AI Asset Management System +🍥 Next-Generation Large Model Gateway and AI Asset Management System Calcium-Ion%2Fnew-api | Trendshift @@ -33,171 +36,155 @@ > This is an open-source project developed based on [One API](https://github.com/songquanpeng/one-api) > [!IMPORTANT] -> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and relevant laws and regulations. Not to be used for illegal purposes. -> - This project is for personal learning only. Stability is not guaranteed, and no technical support is provided. +> - This project is for personal learning purposes only, with no guarantee of stability or technical support. +> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes. +> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China. + +## 📚 Documentation + +For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/) ## ✨ Key Features -1. 🎨 New UI interface (some interfaces pending update) -2. 🌍 Multi-language support (work in progress) -3. 🎨 Added [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface support, [Integration Guide](Midjourney.md) -4. 💰 Online recharge support, configurable in system settings: - - [x] EasyPay -5. 🔍 Query usage quota by key: - - Works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) -6. 📑 Configurable items per page in pagination -7. 🔄 Compatible with original One API database (one-api.db) -8. 💵 Support per-request model pricing, configurable in System Settings - Operation Settings -9. ⚖️ Support channel **weighted random** selection -10. 📈 Data dashboard (console) -11. 🔒 Configurable model access per token -12. 🤖 Telegram authorization login support: - 1. System Settings - Configure Login Registration - Allow Telegram Login - 2. Send /setdomain command to [@Botfather](https://t.me/botfather) - 3. Select your bot, then enter http(s)://your-website/login - 4. Telegram Bot name is the bot username without @ -13. 🎵 Added [Suno API](https://github.com/Suno-API/Suno-API) interface support, [Integration Guide](Suno.md) -14. 🔄 Support for Rerank models, compatible with Cohere and Jina, can integrate with Dify, [Integration Guide](Rerank.md) -15. ⚡ **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - Support for OpenAI's Realtime API, including Azure channels -16. 🧠 Support for setting reasoning effort through model name suffix: - - Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`) - - Add suffix `-medium` to set medium reasoning effort - - Add suffix `-low` to set low reasoning effort -17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `` tags and concatenated to the content returned. -18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings` -19. 💰 Cache billing support, when enabled can charge a configurable ratio for cache hits: - 1. 
Set `Prompt Cache Ratio` in `System Settings -> Operation Settings` - 2. Set `Prompt Cache Ratio` in channel settings, range 0-1 (e.g., 0.5 means 50% charge on cache hits) +New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details: + +1. 🎨 Brand new UI interface +2. 🌍 Multi-language support +3. 💰 Online recharge functionality (YiPay) +4. 🔍 Support for querying usage quotas with keys (works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)) +5. 🔄 Compatible with the original One API database +6. 💵 Support for pay-per-use model pricing +7. ⚖️ Support for weighted random channel selection +8. 📈 Data dashboard (console) +9. 🔒 Token grouping and model restrictions +10. 🤖 Support for more authorization login methods (LinuxDO, Telegram, OIDC) +11. 🔄 Support for Rerank models (Cohere and Jina), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank) +12. ⚡ Support for OpenAI Realtime API (including Azure channels), [API Documentation](https://docs.newapi.pro/api/openai-realtime) +13. ⚡ Support for Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat) +14. Support for entering chat interface via /chat2link route +15. 🧠 Support for setting reasoning effort through model name suffixes: + 1. OpenAI o-series models + - Add `-high` suffix for high reasoning effort (e.g.: `o3-mini-high`) + - Add `-medium` suffix for medium reasoning effort (e.g.: `o3-mini-medium`) + - Add `-low` suffix for low reasoning effort (e.g.: `o3-mini-low`) + 2. Claude thinking models + - Add `-thinking` suffix to enable thinking mode (e.g.: `claude-3-7-sonnet-20250219-thinking`) +16. 🔄 Thinking-to-content functionality +17. 🔄 Model rate limiting for users +18. 💰 Cache billing support, which allows billing at a set ratio when cache is hit: + 1. Set the `Prompt Cache Ratio` option in `System Settings-Operation Settings` + 2. Set `Prompt Cache Ratio` in the channel, range 0-1, e.g., setting to 0.5 means billing at 50% when cache is hit 3. Supported channels: - [x] OpenAI - - [x] Azure + - [x] Azure - [x] DeepSeek - - [ ] Claude + - [x] Claude ## Model Support -This version additionally supports: -1. Third-party model **gpts** (gpt-4-gizmo-*) -2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md) -3. Custom channels with full API URL support -4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md) -5. Rerank models, supporting [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [Integration Guide](Rerank.md) -6. Dify -You can add custom models gpt-4-gizmo-* in channels. These are third-party models and cannot be called with official OpenAI keys. 
+This version supports multiple models, please refer to [API Documentation-Relay Interface](https://docs.newapi.pro/api) for details: -## Additional Configurations Beyond One API -- `GENERATE_DEFAULT_TOKEN`: Generate initial token for new users, default `false` -- `STREAMING_TIMEOUT`: Set streaming response timeout, default 60 seconds -- `DIFY_DEBUG`: Output workflow and node info to client for Dify channel, default `true` -- `FORCE_STREAM_OPTION`: Override client stream_options parameter, default `true` -- `GET_MEDIA_TOKEN`: Calculate image tokens, default `true` -- `GET_MEDIA_TOKEN_NOT_STREAM`: Calculate image tokens in non-stream mode, default `true` -- `UPDATE_TASK`: Update async tasks (Midjourney, Suno), default `true` -- `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated -- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE` -- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable -- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20` -- `CRYPTO_SECRET`: Encryption key for encrypting database content -- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview` -- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10` -- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2` +1. Third-party models **gpts** (gpt-4-gizmo-*) +2. Third-party channel [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [API Documentation](https://docs.newapi.pro/api/midjourney-proxy-image) +3. Third-party channel [Suno API](https://github.com/Suno-API/Suno-API) interface, [API Documentation](https://docs.newapi.pro/api/suno-music) +4. Custom channels, supporting full call address input +5. Rerank models ([Cohere](https://cohere.ai/) and [Jina](https://jina.ai/)), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank) +6. Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat) +7. 
Dify, currently only supports chatflow + +## Environment Variable Configuration + +For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables): + +- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false` +- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds +- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true` +- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true` +- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true` +- `GET_MEDIA_TOKEN_NOT_STREAM`: Whether to count image tokens in non-streaming cases, default is `true` +- `UPDATE_TASK`: Whether to update asynchronous tasks (Midjourney, Suno), default is `true` +- `COHERE_SAFETY_SETTING`: Cohere model safety settings, options are `NONE`, `CONTEXTUAL`, `STRICT`, default is `NONE` +- `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16` +- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20` +- `CRYPTO_SECRET`: Encryption key used for encrypting database content +- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2024-12-01-preview` +- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes +- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2` ## Deployment +For detailed deployment guides, please refer to [Installation Guide-Deployment Methods](https://docs.newapi.pro/installation): + > [!TIP] -> Latest Docker image: `calciumion/new-api:latest` -> Default account: root, password: 123456 +> Latest Docker image: `calciumion/new-api:latest` -### Multi-Server Deployment -- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers. -- If using a public Redis, must set `CRYPTO_SECRET` environment variable, otherwise Redis content will not be able to be obtained in multi-server deployment. +### Multi-machine Deployment Considerations +- Environment variable `SESSION_SECRET` must be set, otherwise login status will be inconsistent across multiple machines +- If sharing Redis, `CRYPTO_SECRET` must be set, otherwise Redis content cannot be accessed across multiple machines -### Requirements -- Local database (default): SQLite (Docker deployment must mount `/data` directory) -- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6 +### Deployment Requirements +- Local database (default): SQLite (Docker deployment must mount the `/data` directory) +- Remote database: MySQL version >= 5.7.8, PgSQL version >= 9.6 -### Deployment with BT Panel -Install BT Panel (**version 9.2.0** or above) from [BT Panel Official Website](https://www.bt.cn/new/download.html), choose the stable version script to download and install. -After installation, log in to BT Panel and click Docker in the menu bar. First-time access will prompt to install Docker service. Click Install Now and follow the prompts to complete installation. -After installation, find **New-API** in the app store, click install, configure basic options to complete installation. -[Pictorial Guide](BT.md) +### Deployment Methods -### Docker Deployment +#### Using BaoTa Panel Docker Feature +Install BaoTa Panel (version **9.2.0** or above), find **New-API** in the application store and install it. 
+[Tutorial with images](./docs/BT.md)
 
-### Using Docker Compose (Recommended)
+#### Using Docker Compose (Recommended)
 ```shell
-# Clone project
+# Download the project
 git clone https://github.com/Calcium-Ion/new-api.git
 cd new-api
 # Edit docker-compose.yml as needed
-# nano docker-compose.yml
-# vim docker-compose.yml
 # Start
 docker-compose up -d
 ```
 
-#### Update Version
+#### Using Docker Image Directly
 ```shell
-docker-compose pull
-docker-compose up -d
-```
-
-### Direct Docker Image Usage
-```shell
-# SQLite deployment:
+# Using SQLite
 docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-# MySQL deployment (add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"), modify database connection parameters as needed
-# Example:
+# Using MySQL
 docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
 ```
 
-#### Update Version
-```shell
-# Pull the latest image
-docker pull calciumion/new-api:latest
-# Stop and remove the old container
-docker stop new-api
-docker rm new-api
-# Run the new container with the same parameters as before
-docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-```
+## Channel Retry and Cache
+Channel retry functionality has been implemented; the number of retries can be set in `Settings->Operation Settings->General Settings`. It is **recommended to enable caching**.
 
-Alternatively, you can use Watchtower for automatic updates (not recommended, may cause database incompatibility):
-```shell
-docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
-```
+### Cache Configuration Method
+1. `REDIS_CONN_STRING`: Set Redis as the cache
+2. `MEMORY_CACHE_ENABLED`: Enable the in-memory cache (no need to set manually if Redis is set)
 
-## Channel Retry
-Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
-If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request.
+## API Documentation
 
-### Cache Configuration
-1. `REDIS_CONN_STRING`: Use Redis as cache
- + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
-2. 
`MEMORY_CACHE_ENABLED`: Enable memory cache, default `false` - + Example: `MEMORY_CACHE_ENABLED=true` +For detailed API documentation, please refer to [API Documentation](https://docs.newapi.pro/api): -### Why Some Errors Don't Retry -Error codes 400, 504, 524 won't retry -### To Enable Retry for 400 -In `Channel->Edit`, set `Status Code Override` to: -```json -{ - "400": "500" -} -``` - -## Integration Guides -- [Midjourney Integration](Midjourney.md) -- [Suno Integration](Suno.md) +- [Chat API](https://docs.newapi.pro/api/openai-chat) +- [Image API](https://docs.newapi.pro/api/openai-image) +- [Rerank API](https://docs.newapi.pro/api/jinaai-rerank) +- [Realtime API](https://docs.newapi.pro/api/openai-realtime) +- [Claude Chat API (messages)](https://docs.newapi.pro/api/anthropic-chat) ## Related Projects - [One API](https://github.com/songquanpeng/one-api): Original project - [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support -- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-gen AI B/C solution -- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota by key +- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-generation AI one-stop B/C-end solution +- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota with key + +Other projects based on New API: +- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon): High-performance optimized version of New API +- [VoAPI](https://github.com/VoAPI/VoAPI): Frontend beautified version based on New API + +## Help and Support + +If you have any questions, please refer to [Help and Support](https://docs.newapi.pro/support): +- [Community Interaction](https://docs.newapi.pro/support/community-interaction) +- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues) +- [FAQ](https://docs.newapi.pro/support/faq) ## 🌟 Star History -[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date) \ No newline at end of file +[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date) diff --git a/README.md b/README.md index 6ac8839b..67af9916 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do #### 使用宝塔面板Docker功能部署 安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。 -[图文教程](BT.md) +[图文教程](./docs/BT.md) #### 使用Docker Compose部署(推荐) ```shell diff --git a/BT.md b/docs/installation/BT.md similarity index 98% rename from BT.md rename to docs/installation/BT.md index e57cdab7..b4ea5b2f 100644 --- a/BT.md +++ b/docs/installation/BT.md @@ -1,3 +1,3 @@ -密钥为环境变量SESSION_SECRET - -![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0) +密钥为环境变量SESSION_SECRET + +![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0) diff --git a/Midjourney.md b/docs/models/Midjourney.md similarity index 100% rename from Midjourney.md rename to docs/models/Midjourney.md diff --git a/Rerank.md b/docs/models/Rerank.md similarity index 100% rename from Rerank.md rename to docs/models/Rerank.md diff --git a/Suno.md b/docs/models/Suno.md similarity index 100% rename from Suno.md rename to docs/models/Suno.md From bbab729619820b49706af49a48596e8cab105bde Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 
2025 23:48:15 +0800 Subject: [PATCH 018/105] fix: text --- web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index ae54b1ef..7003c279 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -179,7 +179,7 @@ export default function RequestRateLimit(props) {
                   <li>{t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}</li>
                   <li>{t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}</li>
-                  <li>{t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1')}</li>
+                  <li>{t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}</li>
                   <li>{t('分组速率配置优先级高于全局速率限制。')}</li>
                   <li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
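
The hint strings in the hunk above spell out the group rate-limit JSON shape: `{"group name": [max requests, max successful completions]}`, where max requests must be >= 0, max successful completions must be >= 1, and group-level settings take precedence over the global limit. A minimal Go sketch of parsing and validating that shape follows; `validateGroupRateLimits` and its error messages are illustrative names, not identifiers from this patch series.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// validateGroupRateLimits is a hypothetical helper, not code from this
// repository: it parses the {"group": [maxRequests, maxCompletions]}
// object that the settings page documents and enforces the two rules
// stated in the hint text.
func validateGroupRateLimits(raw string) (map[string][2]int, error) {
	var cfg map[string][2]int
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		return nil, fmt.Errorf("not a valid JSON object: %w", err)
	}
	for group, limits := range cfg {
		if limits[0] < 0 {
			return nil, fmt.Errorf("group %q: max requests must be >= 0", group)
		}
		if limits[1] < 1 {
			return nil, fmt.Errorf("group %q: max successful completions must be >= 1", group)
		}
	}
	return cfg, nil
}

func main() {
	cfg, err := validateGroupRateLimits(`{"default": [200, 100], "vip": [0, 1000]}`)
	fmt.Println(cfg, err) // map[default:[200 100] vip:[0 1000]] <nil>
}
```
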
From 87188cd7d458464c7e83e3502eb0a11126e6f94e Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:53:05 +0800 Subject: [PATCH 019/105] =?UTF-8?q?fix:=20=E7=BC=A9=E8=BF=9B=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E8=BF=98=E5=8E=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 2 +- model/option.go | 2 +- web/src/components/RateLimitSetting.js | 92 ++--- .../RateLimit/SettingsRequestRateLimit.js | 344 +++++++++--------- 4 files changed, 220 insertions(+), 220 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 1ca5ace6..03ef0ff3 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -194,4 +194,4 @@ func ModelRequestRateLimit() func(c *gin.Context) { memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } } -} \ No newline at end of file +} diff --git a/model/option.go b/model/option.go index e9c129e1..d98a9d38 100644 --- a/model/option.go +++ b/model/option.go @@ -402,4 +402,4 @@ func handleConfigUpdate(key, value string) bool { config.UpdateConfigFromMap(cfg, configMap) return true // 已处理 -} \ No newline at end of file +} diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 4671317f..a0953db7 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -9,62 +9,62 @@ import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimi const RateLimitSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: 0, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '', + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '', }); - - let [loading, setLoading] = useState(false); - - const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key === 'ModelRequestRateLimitGroup') { - item.value = JSON.stringify(JSON.parse(item.value), null, 2); - } - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true' ? true : false; - } else { - newInputs[item.key] = item.value; - } - }); - - setInputs(newInputs); - } else { - showError(message); - } + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRequestRateLimitGroup') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? 
true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } }; async function onRefresh() { - try { - setLoading(true); - await getOptions(); - // showSuccess('刷新成功'); - } catch (error) { - showError('刷新失败'); - } finally { - setLoading(false); - } + try { + setLoading(true); + await getOptions(); + // showSuccess('刷新成功'); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } } useEffect(() => { - onRefresh(); + onRefresh(); }, []); return ( - <> - - {/* AI请求速率限制 */} - - - - - + <> + + {/* AI请求速率限制 */} + + + + + ); }; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 7003c279..7c60bc47 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -15,190 +15,190 @@ export default function RequestRateLimit(props) { const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: -1, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '', + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); function onSubmit() { - const updateArray = compareObjects(inputs, inputsRow); - if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); - const requestQueue = updateArray.map((item) => { - let value = ''; - if (typeof inputs[item.key] === 'boolean') { - value = String(inputs[item.key]); - } else { - value = inputs[item.key]; - } - return API.put('/api/option/', { - key: item.key, - value, - }); - }); - setLoading(true); - Promise.all(requestQueue) - .then((res) => { - if (requestQueue.length === 1) { - if (res.includes(undefined)) return; - } else if (requestQueue.length > 1) { - if (res.includes(undefined)) - return showError(t('部分保存失败,请重试')); - } + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if (res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) + return showError(t('部分保存失败,请重试')); + } - for (let i = 0; i < res.length; i++) { - if (!res[i].data.success) { - return showError(res[i].data.message); - } - } + for (let i = 0; i < res.length; i++) { + if (!res[i].data.success) { + return showError(res[i].data.message); + } + } - showSuccess(t('保存成功')); - props.refresh(); - }) - .catch(() => { - showError(t('保存失败,请重试')); - }) - .finally(() => { - setLoading(false); - }); + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); } useEffect(() => { - const currentInputs = {}; - for (let key in props.options) { - if (Object.keys(inputs).includes(key)) { - currentInputs[key] = props.options[key]; 
- } - } - setInputs(currentInputs); - setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); + const currentInputs = {}; + for (let key in props.options) { + if (Object.keys(inputs).includes(key)) { + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + refForm.current.setValues(currentInputs); }, [props.options]); return ( - <> - -
(refForm.current = formAPI)} - style={{ marginBottom: 15 }} - > - - - - { - setInputs({ - ...inputs, - ModelRequestRateLimitEnabled: value, - }); - }} - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitDurationMinutes: String(value), - }) - } - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitCount: String(value), - }) - } - /> - - - - setInputs({ - ...inputs, - ModelRequestRateLimitSuccessCount: String(value), - }) - } - /> - - - - - verifyJSON(value), - message: t('不是合法的 JSON 字符串'), - }, - ]} - extraText={ -
-

{t('说明:')}

-
    -
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • -
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • -
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • -
  • {t('分组速率配置优先级高于全局速率限制。')}
  • -
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • -
-
- } - onChange={(value) => { - setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); - }} - /> - -
- - - -
-
-
- + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} + extraText={ +
+

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • +
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • +
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • +
  • {t('分组速率配置优先级高于全局速率限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
+
+ } + onChange={(value) => { + setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); + }} + /> + +
+ + + +
+
+
+ ); } \ No newline at end of file From 3d243c3ee2bc2a92d21d31f0155378ac5c188c39 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:56:15 +0800 Subject: [PATCH 020/105] =?UTF-8?q?fix:=20=E6=A0=B7=E5=BC=8F=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/RateLimitSetting.js | 16 ++++++++-------- .../RateLimit/SettingsRequestRateLimit.js | 10 +++++----- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index a0953db7..5f0200e1 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -34,7 +34,7 @@ const RateLimitSetting = () => { newInputs[item.key] = item.value; } }); - + setInputs(newInputs); } else { showError(message); @@ -44,28 +44,28 @@ const RateLimitSetting = () => { try { setLoading(true); await getOptions(); - // showSuccess('刷新成功'); + // showSuccess('刷新成功'); } catch (error) { showError('刷新失败'); } finally { setLoading(false); } } - + useEffect(() => { onRefresh(); }, []); - + return ( <> - {/* AI请求速率限制 */} + {/* AI请求速率限制 */} ); - }; - - export default RateLimitSetting; \ No newline at end of file +}; + +export default RateLimitSetting; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 7c60bc47..73626351 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -23,7 +23,7 @@ export default function RequestRateLimit(props) { }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); - + function onSubmit() { const updateArray = compareObjects(inputs, inputsRow); if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); @@ -65,7 +65,7 @@ export default function RequestRateLimit(props) { setLoading(false); }); } - + useEffect(() => { const currentInputs = {}; for (let key in props.options) { @@ -75,9 +75,9 @@ export default function RequestRateLimit(props) { } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); + refForm.current.setValues(currentInputs); }, [props.options]); - + return ( <> @@ -201,4 +201,4 @@ export default function RequestRateLimit(props) { ); - } \ No newline at end of file +} From 0cf4c59d227a90a8dd4b66927b7b563dc3cea72d Mon Sep 17 00:00:00 2001 From: skynono <6811626@qq.com> Date: Tue, 6 May 2025 14:18:15 +0800 Subject: [PATCH 021/105] feat: add original password verification when changing password --- controller/user.go | 26 +++++++++++++++++++++++++- model/user.go | 1 + web/src/components/PersonalSetting.js | 20 ++++++++++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/controller/user.go b/controller/user.go index e194f531..567c2aa7 100644 --- a/controller/user.go +++ b/controller/user.go @@ -592,7 +592,14 @@ func UpdateSelf(c *gin.Context) { user.Password = "" // rollback to what it should be cleanUser.Password = "" } - updatePassword := user.Password != "" + updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } if err := cleanUser.Update(updatePassword); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -608,6 +615,23 @@ func UpdateSelf(c *gin.Context) { return } +func 
checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) { + if newPassword == "" { + return + } + var currentUser *model.User + currentUser, err = model.GetUserById(userId, true) + if err != nil { + return + } + if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) { + err = fmt.Errorf("原密码错误") + return + } + updatePassword = true + return +} + func DeleteUser(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { diff --git a/model/user.go b/model/user.go index 0aea2ff5..1a3372aa 100644 --- a/model/user.go +++ b/model/user.go @@ -18,6 +18,7 @@ type User struct { Id int `json:"id"` Username string `json:"username" gorm:"unique;index" validate:"max=12"` Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` + OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database! DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` Role int `json:"role" gorm:"type:int;default:1"` // admin, common Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index d1e03db2..fbd74536 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -57,6 +57,7 @@ const PersonalSetting = () => { email_verification_code: '', email: '', self_account_deletion_confirmation: '', + original_password: '', set_new_password: '', set_new_password_confirmation: '', }); @@ -239,11 +240,20 @@ const PersonalSetting = () => { }; const changePassword = async () => { + if (inputs.original_password === '') { + showError(t('请输入原密码!')); + return; + } + if (inputs.original_password === inputs.set_new_password) { + showError(t('新密码需要和原密码不一致!')); + return; + } if (inputs.set_new_password !== inputs.set_new_password_confirmation) { showError(t('两次输入的密码不一致!')); return; } const res = await API.put(`/api/user/self`, { + original_password: inputs.original_password, password: inputs.set_new_password, }); const { success, message } = res.data; @@ -1118,6 +1128,16 @@ const PersonalSetting = () => { >
+ handleInputChange('original_password', value) + } + /> + Date: Tue, 6 May 2025 18:41:01 +0800 Subject: [PATCH 022/105] feat: add support for DeepSeek channel in streamSupportedChannels --- relay/common/relay_info.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 915474e1..0135283d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -103,6 +103,7 @@ var streamSupportedChannels = map[int]bool{ common.ChannelTypeVolcEngine: true, common.ChannelTypeOllama: true, common.ChannelTypeXai: true, + common.ChannelTypeDeepSeek: true, } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { From 459c277c941ac61b81189f22b06637eff71485bf Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 6 May 2025 21:58:01 +0800 Subject: [PATCH 023/105] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20built=20in?= =?UTF-8?q?=20tools=20=E8=AE=A1=E8=B4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 增加非流的工具调用次数统计 - 添加 web search 和 file search 计费 --- dto/openai_response.go | 44 +++++++------- relay/channel/openai/relay_responses.go | 7 ++- relay/relay-text.go | 77 +++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 23 deletions(-) diff --git a/dto/openai_response.go b/dto/openai_response.go index c8f61b9d..790d4df8 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -195,28 +195,28 @@ type OutputTokenDetails struct { } type OpenAIResponsesResponse struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int `json:"created_at"` - Status string `json:"status"` - Error *OpenAIError `json:"error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` - Instructions string `json:"instructions"` - MaxOutputTokens int `json:"max_output_tokens"` - Model string `json:"model"` - Output []ResponsesOutput `json:"output"` - ParallelToolCalls bool `json:"parallel_tool_calls"` - PreviousResponseID string `json:"previous_response_id"` - Reasoning *Reasoning `json:"reasoning"` - Store bool `json:"store"` - Temperature float64 `json:"temperature"` - ToolChoice string `json:"tool_choice"` - Tools []interface{} `json:"tools"` - TopP float64 `json:"top_p"` - Truncation string `json:"truncation"` - Usage *Usage `json:"usage"` - User json.RawMessage `json:"user"` - Metadata json.RawMessage `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + Status string `json:"status"` + Error *OpenAIError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + Instructions string `json:"instructions"` + MaxOutputTokens int `json:"max_output_tokens"` + Model string `json:"model"` + Output []ResponsesOutput `json:"output"` + ParallelToolCalls bool `json:"parallel_tool_calls"` + PreviousResponseID string `json:"previous_response_id"` + Reasoning *Reasoning `json:"reasoning"` + Store bool `json:"store"` + Temperature float64 `json:"temperature"` + ToolChoice string `json:"tool_choice"` + Tools []ResponsesToolsCall `json:"tools"` + TopP float64 `json:"top_p"` + Truncation string `json:"truncation"` + Usage *Usage `json:"usage"` + User json.RawMessage `json:"user"` + Metadata json.RawMessage `json:"metadata"` } type IncompleteDetails struct { diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 6af8c676..1d1e060e 100644 --- a/relay/channel/openai/relay_responses.go +++ 
b/relay/channel/openai/relay_responses.go
@@ -3,7 +3,6 @@ package openai
 import (
 	"bytes"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -12,6 +11,8 @@ import (
 	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -61,6 +62,10 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	usage.PromptTokens = responsesResponse.Usage.InputTokens
 	usage.CompletionTokens = responsesResponse.Usage.OutputTokens
 	usage.TotalTokens = responsesResponse.Usage.TotalTokens
+	// Parse built-in tool usage
+	for _, tool := range responsesResponse.Tools {
+		info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
+	}
 	return nil, &usage
 }
 
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 4fdd435d..a528ec52 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -358,6 +358,67 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 
 	ratio := dModelRatio.Mul(dGroupRatio)
 
+	// Billing for the OpenAI web search tool
+	var dWebSearchQuota decimal.Decimal
+	if relayInfo.ResponsesUsageInfo != nil {
+		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
+			// Determine the model tier
+			// https://platform.openai.com/docs/pricing web search is billed by model tier and search context size
+			// gpt-4.1, gpt-4o, or gpt-4o-search-preview cost more; gpt-4.1-mini, gpt-4o-mini, gpt-4o-mini-search-preview cost less
+			isHighTierModel := (strings.HasPrefix(modelName, "gpt-4.1") || strings.HasPrefix(modelName, "gpt-4o")) &&
+				!strings.Contains(modelName, "mini")
+
+			// Pick the price for the configured search context size
+			var priceWebSearchPerThousandCalls float64
+			switch webSearchTool.SearchContextSize {
+			case "low":
+				if isHighTierModel {
+					priceWebSearchPerThousandCalls = 30.0
+				} else {
+					priceWebSearchPerThousandCalls = 25.0
+				}
+			case "medium":
+				if isHighTierModel {
+					priceWebSearchPerThousandCalls = 35.0
+				} else {
+					priceWebSearchPerThousandCalls = 27.5
+				}
+			case "high":
+				if isHighTierModel {
+					priceWebSearchPerThousandCalls = 50.0
+				} else {
+					priceWebSearchPerThousandCalls = 30.0
+				}
+			default:
+				// search context size defaults to medium
+				if isHighTierModel {
+					priceWebSearchPerThousandCalls = 35.0
+				} else {
+					priceWebSearchPerThousandCalls = 27.5
+				}
+			}
+			// Compute the quota for web search calls (quota = price * call count / 1000)
+			dWebSearchQuota = decimal.NewFromFloat(priceWebSearchPerThousandCalls).
+				Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
+				Div(decimal.NewFromInt(1000))
+			extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 $%s",
+				webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
+		}
+	}
+	// Billing for the file search tool
+	var dFileSearchQuota decimal.Decimal
+	if relayInfo.ResponsesUsageInfo != nil {
+		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
+			// file search tool calls cost $2.50 per 1k calls
+			// Compute the quota for file search calls (quota = price * call count / 1000)
+			dFileSearchQuota = decimal.NewFromFloat(2.5).
+				Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
+				Div(decimal.NewFromInt(1000))
+			extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s",
+				fileSearchTool.CallCount, dFileSearchQuota.String())
+		}
+	}
+
 	var quotaCalculateDecimal decimal.Decimal
 
 	if !priceData.UsePrice {
 		nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
@@ -380,6 +441,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	} else {
 		quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
 	}
+	// Add the quota from responses built-in tool calls
+	quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
+	quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
 
 	quota := int(quotaCalculateDecimal.Round(0).IntPart())
 	totalTokens := promptTokens + completionTokens
@@ -430,6 +494,19 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		other["image_ratio"] = imageRatio
 		other["image_output"] = imageTokens
 	}
+	if !dWebSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
+		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+			other["web_search"] = true
+			other["web_search_call_count"] = webSearchTool.CallCount
+			other["web_search_context_size"] = webSearchTool.SearchContextSize
+		}
+	}
+	if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
+		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
+			other["file_search"] = true
+			other["file_search_call_count"] = fileSearchTool.CallCount
+		}
+	}
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
 }

From d859e3fa645672ca3e38e97654a2de30c6bbd577 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Tue, 6 May 2025 22:28:32 +0800
Subject: [PATCH 024/105] fix: success was reported even when no new password
 was entered
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 controller/user.go                    |  6 ++---
 web/src/components/PersonalSetting.js | 35 ++++++++++++++++++++-------
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/controller/user.go b/controller/user.go
index 567c2aa7..fd53e743 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -616,9 +616,6 @@ func UpdateSelf(c *gin.Context) {
 }
 
 func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) {
-	if newPassword == "" {
-		return
-	}
 	var currentUser *model.User
 	currentUser, err = model.GetUserById(userId, true)
 	if err != nil {
@@ -628,6 +625,9 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) {
 		err = fmt.Errorf("原密码错误")
 		return
 	}
+	if newPassword == "" {
+		return
+	}
 	updatePassword = true
 	return
 }

diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js
index fbd74536..0f52c319 100644
--- a/web/src/components/PersonalSetting.js
+++ b/web/src/components/PersonalSetting.js
@@ -244,6 +244,10 @@ const PersonalSetting = () => {
       showError(t('请输入原密码!'));
       return;
     }
+    if (inputs.set_new_password === '') {
+      showError(t('请输入新密码!'));
+      return;
+    }
     if (inputs.original_password === inputs.set_new_password) {
       showError(t('新密码需要和原密码不一致!'));
       return;
     }
@@ -826,8 +830,8 @@ const PersonalSetting = () => {
- - + +
{t('通知方式')}
@@ -1003,23 +1007,36 @@ const PersonalSetting = () => {
- +
- {t('接受未设置价格模型')} + + {t('接受未设置价格模型')} +
handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)} + checked={ + notificationSettings.acceptUnsetModelRatioModel + } + onChange={(e) => + handleNotificationSettingChange( + 'acceptUnsetModelRatioModel', + e.target.checked, + ) + } > {t('接受未设置价格模型')} - - {t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')} + + {t( + '当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用', + )}
-
@@ -799,7 +812,13 @@ const SystemSetting = () => { onChange={(value) => setEmailToAdd(value)} style={{ marginTop: 16 }} suffix={ - + } onEnterPress={handleAddEmail} /> From ec615342569d0a58c93b2ca852b3bb9f51db2aab Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 9 May 2025 13:57:00 +0800 Subject: [PATCH 049/105] feat: send SSE ping before get response --- relay/channel/api_request.go | 66 ++++++++++++++++++- relay/helper/common.go | 18 ++++-- relay/helper/stream_scanner.go | 115 +++++++++++++-------------------- relay/relay-text.go | 10 +-- 4 files changed, 122 insertions(+), 87 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 8b2ca889..db5d4f44 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -1,16 +1,23 @@ package channel import ( + "context" "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" "net/http" common2 "one-api/common" "one-api/relay/common" "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" + "one-api/setting/operation_setting" + "sync" + "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { @@ -105,7 +112,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http } else { client = service.GetHttpClient() } + // 流式请求 ping 保活 + var stopPinger func() + generalSettings := operation_setting.GetGeneralSetting() + pingEnabled := generalSettings.PingIntervalEnabled + var pingerWg sync.WaitGroup + if info.IsStream { + helper.SetEventStreamHeaders(c) + pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second + var pingerCtx context.Context + pingerCtx, stopPinger = context.WithCancel(c.Request.Context()) + + if pingEnabled { + pingerWg.Add(1) + gopool.Go(func() { + defer pingerWg.Done() + if pingInterval <= 0 { + pingInterval = helper.DefaultPingInterval + } + + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() + var pingMutex sync.Mutex + if common2.DebugEnabled { + println("SSE ping goroutine started") + } + + for { + select { + case <-ticker.C: + pingMutex.Lock() + err2 := helper.PingData(c) + pingMutex.Unlock() + if err2 != nil { + common2.LogError(c, "SSE ping error: "+err.Error()) + return + } + if common2.DebugEnabled { + println("SSE ping data sent.") + } + case <-pingerCtx.Done(): + if common2.DebugEnabled { + println("SSE ping goroutine stopped.") + } + return + } + } + }) + } + } + resp, err := client.Do(req) + // request结束后停止ping + if info.IsStream && pingEnabled { + stopPinger() + pingerWg.Wait() + } if err != nil { return nil, err } diff --git a/relay/helper/common.go b/relay/helper/common.go index 0a3aba1e..35d983f7 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -12,11 +12,19 @@ import ( ) func SetEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + // 检查是否已经设置过头部 + if _, exists := c.Get("event_stream_headers_set"); exists { + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + 
c.Writer.Header().Set("X-Accel-Buffering", "no") + + // 设置标志,表示头部已经设置过 + c.Set("event_stream_headers_set", true) } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index ce4d3a6d..c1bc0d6e 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -3,7 +3,6 @@ package helper import ( "bufio" "context" - "github.com/bytedance/gopkg/util/gopool" "io" "net/http" "one-api/common" @@ -14,6 +13,8 @@ import ( "sync" "time" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) @@ -23,76 +24,6 @@ const ( DefaultPingInterval = 10 * time.Second ) -type DoRequestFunc func(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) - -// Optional SSE Ping keep-alive mechanism -// -// Used to solve the problem of the connection with the client timing out due to no data being sent when the upstream -// channel response time is long (e.g., thinking model). -// When enabled, it will send ping data packets to the client via SSE at the specified interval to maintain the connection. -func DoStreamRequestWithPinger(doRequest DoRequestFunc, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - SetEventStreamHeaders(c) - - generalSettings := operation_setting.GetGeneralSetting() - pingEnabled := generalSettings.PingIntervalEnabled - pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second - - pingerCtx, stopPinger := context.WithCancel(c.Request.Context()) - var pingerWg sync.WaitGroup - var doRequestErr error - var resp any - - if pingEnabled { - pingerWg.Add(1) - - gopool.Go(func() { - defer pingerWg.Done() - - if pingInterval <= 0 { - pingInterval = DefaultPingInterval - } - - ticker := time.NewTicker(pingInterval) - defer ticker.Stop() - var pingMutex sync.Mutex - - if common.DebugEnabled { - println("SSE ping goroutine started.") - } - - for { - select { - case <-ticker.C: - pingMutex.Lock() - err := PingData(c) - pingMutex.Unlock() - if err != nil { - common.LogError(c, "SSE ping error: "+err.Error()) - return - } - if common.DebugEnabled { - println("SSE ping data sent.") - } - case <-pingerCtx.Done(): - if common.DebugEnabled { - println("SSE ping goroutine stopped.") - } - return - } - } - }) - } - - resp, doRequestErr = doRequest(c, info, requestBody) - - stopPinger() - if pingEnabled { - pingerWg.Wait() - } - - return resp, doRequestErr -} - func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) { if resp == nil || dataHandler == nil { @@ -111,11 +42,26 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon stopChan = make(chan bool, 2) scanner = bufio.NewScanner(resp.Body) ticker = time.NewTicker(streamingTimeout) + pingTicker *time.Ticker writeMutex sync.Mutex // Mutex to protect concurrent writes ) + generalSettings := operation_setting.GetGeneralSetting() + pingEnabled := generalSettings.PingIntervalEnabled + pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second + if pingInterval <= 0 { + pingInterval = DefaultPingInterval + } + + if pingEnabled { + pingTicker = time.NewTicker(pingInterval) + } + defer func() { ticker.Stop() + if pingTicker != nil { + pingTicker.Stop() + } close(stopChan) }() scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize) @@ -127,6 +73,33 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info 
*relaycommon ctx = context.WithValue(ctx, "stop_chan", stopChan) + // Handle ping data sending + if pingEnabled && pingTicker != nil { + gopool.Go(func() { + for { + select { + case <-pingTicker.C: + writeMutex.Lock() // Lock before writing + err := PingData(c) + writeMutex.Unlock() // Unlock after writing + if err != nil { + common.LogError(c, "ping data error: "+err.Error()) + common.SafeSendBool(stopChan, true) + return + } + if common.DebugEnabled { + println("ping data sent") + } + case <-ctx.Done(): + if common.DebugEnabled { + println("ping data goroutine stopped") + } + return + } + } + }) + } + common.RelayCtxGo(ctx, func() { for scanner.Scan() { ticker.Reset(streamingTimeout) diff --git a/relay/relay-text.go b/relay/relay-text.go index 69a48637..8d5cd384 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -193,15 +193,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } var httpResp *http.Response - var resp any - - if relayInfo.IsStream { - // Streaming requests can use SSE ping to keep alive and avoid connection timeout - // The judgment of whether ping is enabled will be made within the function - resp, err = helper.DoStreamRequestWithPinger(adaptor.DoRequest, c, relayInfo, requestBody) - } else { - resp, err = adaptor.DoRequest(c, relayInfo, requestBody) - } + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) From 40efa73a42e5d7cc943ca46a9f087ea11030f101 Mon Sep 17 00:00:00 2001 From: skynono <6811626@qq.com> Date: Fri, 9 May 2025 17:11:25 +0800 Subject: [PATCH 050/105] fix: correct formatting string in PriceData.ToSetting to handle ImageRatio as float instead of integer --- relay/helper/price.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/helper/price.go b/relay/helper/price.go index 899c72b9..89efa1da 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -23,7 +23,7 @@ type PriceData struct { } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { From 9ebfcaf6aa3ec55078121eba172a57530751dd81 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 9 May 2025 18:11:37 +0800 Subject: [PATCH 051/105] feat: change azure default api version to 2025-04-01-preview --- README.en.md | 2 +- README.md | 2 +- constant/env.go | 2 +- web/src/i18n/locales/en.json | 4 +- web/src/pages/Channel/EditChannel.js | 67 +++++++++++++++++----------- 5 files changed, 45 insertions(+), 32 deletions(-) diff --git a/README.en.md b/README.en.md index 23fdbe1f..4709bc5b 100644 --- a/README.en.md +++ b/README.en.md @@ -107,7 +107,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env - 
`GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16` - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20` - `CRYPTO_SECRET`: Encryption key used for encrypting database content -- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2024-12-01-preview` +- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview` - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2` diff --git a/README.md b/README.md index 67af9916..a807b07d 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16` - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20` - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容 -- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2024-12-01-preview` +- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview` - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2` diff --git a/constant/env.go b/constant/env.go index fae48625..612f3e8b 100644 --- a/constant/env.go +++ b/constant/env.go @@ -31,7 +31,7 @@ func InitEnv() { GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) - AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview") + AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview") GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index eedf1196..916329e7 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1086,7 +1086,7 @@ "没有账户?": "No account? 
", "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "Please enter AZURE_OPENAI_ENDPOINT, e.g.: https://docs-test-001.openai.azure.com", "默认 API 版本": "Default API Version", - "请输入默认 API 版本,例如:2024-12-01-preview": "Please enter default API version, e.g.: 2024-12-01-preview.", + "请输入默认 API 版本,例如:2025-04-01-preview": "Please enter default API version, e.g.: 2025-04-01-preview.", "请为渠道命名": "Please name the channel", "请选择可以使用该渠道的分组": "Please select groups that can use this channel", "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit Group ratios in system settings to add new groups:", @@ -1374,4 +1374,4 @@ "适用于展示系统功能的场景。": "Suitable for scenarios where the system functions are displayed.", "可在初始化后修改": "Can be modified after initialization", "初始化系统": "Initialize system" -} +} \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index cba787fc..fd96ffb6 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -24,7 +24,8 @@ import { TextArea, Checkbox, Banner, - Modal, ImagePreview + Modal, + ImagePreview, } from '@douyinfe/semi-ui'; import { getChannelModels, loadChannelModels } from '../../components/utils.js'; import { IconHelpCircle } from '@douyinfe/semi-icons'; @@ -306,7 +307,7 @@ const EditChannel = (props) => { fetchModels().then(); fetchGroups().then(); if (isEdit) { - loadChannel().then(() => { }); + loadChannel().then(() => {}); } else { setInputs(originInputs); let localModels = getChannelModels(inputs.type); @@ -477,7 +478,9 @@ const EditChannel = (props) => { type={'warning'} description={ <> - {t('2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的"."')} + {t( + '2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的"."', + )} {/*
*/} {/* { { handleInputChange('other', value); }} @@ -584,25 +587,35 @@ const EditChannel = (props) => { value={inputs.name} autoComplete='new-password' /> - {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && ( - <> -
- {t('API地址')}: -
- - { - handleInputChange('base_url', value); - }} - value={inputs.base_url} - autoComplete="new-password" - /> - - - )} + {inputs.type !== 3 && + inputs.type !== 8 && + inputs.type !== 22 && + inputs.type !== 36 && + inputs.type !== 45 && ( + <> +
+ {t('API地址')}: +
+ + { + handleInputChange('base_url', value); + }} + value={inputs.base_url} + autoComplete='new-password' + /> + + + )}
{t('密钥')}:
@@ -761,10 +774,10 @@ const EditChannel = (props) => { name='other' placeholder={t( '请输入部署地区,例如:us-central1\n支持使用模型映射格式\n' + - '{\n' + - ' "default": "us-central1",\n' + - ' "claude-3-5-sonnet-20240620": "europe-west1"\n' + - '}', + '{\n' + + ' "default": "us-central1",\n' + + ' "claude-3-5-sonnet-20240620": "europe-west1"\n' + + '}', )} autosize={{ minRows: 2 }} onChange={(value) => { From 0d929800cf40f483679684a48c430a163775cf48 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 9 May 2025 18:13:19 +0800 Subject: [PATCH 052/105] fix: GetRequestURL remove unnecessary case --- relay/channel/openai/adaptor.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index da92692b..f0cf073f 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -67,9 +67,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayFormat == relaycommon.RelayFormatClaude { return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } - if info.RelayMode == constant.RelayModeResponses { - return fmt.Sprintf("%s/v1/responses", info.BaseUrl), nil - } if info.RelayMode == constant.RelayModeRealtime { if strings.HasPrefix(info.BaseUrl, "https://") { baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") From 7b176015b82a3ec0276b24805f2498a49da48aa9 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 9 May 2025 18:57:06 +0800 Subject: [PATCH 053/105] feat: enhance OpenAI handler to support forced response formatting and add debug logging for request URLs --- relay/channel/api_request.go | 3 +++ relay/channel/openai/relay-openai.go | 39 ++++++++++++++++++---------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index db5d4f44..03eff9cf 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -62,6 +62,9 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) } + if common2.DebugEnabled { + println("fullRequestURL:", fullRequestURL) + } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, fmt.Errorf("new request failed: %w", err) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b9ed94e2..86c47a15 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -215,10 +215,35 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI StatusCode: resp.StatusCode, }, nil } + + forceFormat := false + if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { + forceFormat = forceFmt + } + + if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { + completionTokens := 0 + for _, choice := range simpleResponse.Choices { + ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + completionTokens += ctkm + } + simpleResponse.Usage = dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: completionTokens, + TotalTokens: info.PromptTokens + completionTokens, + } + } switch info.RelayFormat { case relaycommon.RelayFormatOpenAI: - break + if forceFormat { + responseBody, err = json.Marshal(simpleResponse) + if err != nil { + return 
service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + } else { + break + } case relaycommon.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeRespStr, err := json.Marshal(claudeResp) @@ -244,18 +269,6 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI common.SysError("error copying response body: " + err.Error()) } resp.Body.Close() - if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { - completionTokens := 0 - for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) - completionTokens += ctkm - } - simpleResponse.Usage = dto.Usage{ - PromptTokens: info.PromptTokens, - CompletionTokens: completionTokens, - TotalTokens: info.PromptTokens + completionTokens, - } - } return nil, &simpleResponse.Usage } From 28cdfc0a14e95602b8263a6eab7e6a0a90088fe3 Mon Sep 17 00:00:00 2001 From: a37836323 <37836323@qq.com> Date: Sat, 10 May 2025 04:33:49 +0800 Subject: [PATCH 054/105] =?UTF-8?q?=E6=B7=BB=E5=8A=A0DALL-E=E5=9B=BE?= =?UTF-8?q?=E5=83=8F=E7=94=9F=E6=88=90=E8=AF=B7=E6=B1=82=E4=B8=AD=E7=9A=84?= =?UTF-8?q?Background=E5=92=8CModeration=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/dalle.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dto/dalle.go b/dto/dalle.go index 562d5f1a..44104d33 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -12,6 +12,8 @@ type ImageRequest struct { Style string `json:"style,omitempty"` User string `json:"user,omitempty"` ExtraFields json.RawMessage `json:"extra_fields,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` } type ImageResponse struct { From 58dc7ad770dcd6f5595aeac1c91194761fccfbc2 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 10 May 2025 15:52:41 +0800 Subject: [PATCH 055/105] feat: add moderation and background fields to ImageRequest struct in dalle.go #1052 --- dto/dalle.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/dto/dalle.go b/dto/dalle.go index 562d5f1a..ab2c94e1 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -1,17 +1,16 @@ package dto -import "encoding/json" - type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` - ExtraFields json.RawMessage `json:"extra_fields,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` + Moderation string `json:"moderation,omitempty"` + Background string `json:"background,omitempty"` } type ImageResponse struct { From d985563516a10806284254824fe7cb4ca9676ec4 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Sun, 11 May 2025 17:00:33 +0800 Subject: [PATCH 056/105] feat: add support for socks5h --- 
 service/http_client.go | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/service/http_client.go b/service/http_client.go
index c3f8df7a..64a361cf 100644
--- a/service/http_client.go
+++ b/service/http_client.go
@@ -3,12 +3,13 @@ package service
 import (
 	"context"
 	"fmt"
-	"golang.org/x/net/proxy"
 	"net"
 	"net/http"
 	"net/url"
 	"one-api/common"
 	"time"
+
+	"golang.org/x/net/proxy"
 )
 
 var httpClient *http.Client
@@ -55,7 +56,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
 			},
 		}, nil
 
-	case "socks5":
+	case "socks5", "socks5h":
 		// Extract the auth credentials
 		var auth *proxy.Auth
 		if parsedURL.User != nil {
@@ -69,6 +70,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
 		}
 
 		// Create the SOCKS5 proxy dialer
+		// With the "tcp" network argument, proxy.SOCKS5 routes all TCP connections, including DNS lookups, through the proxy; this matches socks5h behavior
 		dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct)
 		if err != nil {
 			return nil, err

From b2cad229520ab533f1981daefe9a478502ddb31f Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Tue, 13 May 2025 12:52:22 +0800
Subject: [PATCH 057/105] add coze request

---
 common/constants.go              |   2 +
 relay/channel/coze/adaptor.go    | 125 +++++++++++++++++++++++++++++++
 relay/channel/coze/constants.go  |   8 ++
 relay/channel/coze/dto.go        |  81 ++++++++++++++++++++
 relay/channel/coze/relay-coze.go | 121 ++++++++++++++++++++++++++++++
 relay/constant/api_type.go       |   3 +
 relay/relay_adaptor.go           |   3 +
 7 files changed, 343 insertions(+)
 create mode 100644 relay/channel/coze/adaptor.go
 create mode 100644 relay/channel/coze/constants.go
 create mode 100644 relay/channel/coze/dto.go
 create mode 100644 relay/channel/coze/relay-coze.go

diff --git a/common/constants.go b/common/constants.go
index dd4f3b04..bee00506 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -240,6 +240,7 @@ const (
 	ChannelTypeBaiduV2    = 46
 	ChannelTypeXinference = 47
 	ChannelTypeXai        = 48
+	ChannelTypeCoze       = 49
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 )
@@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{
 	"https://qianfan.baidubce.com", //46
 	"",                             //47
 	"https://api.x.ai",             //48
+	"https://api.coze.cn",          //49
 }

diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go
new file mode 100644
index 00000000..b14239a6
--- /dev/null
+++ b/relay/channel/coze/adaptor.go
@@ -0,0 +1,125 @@
+package coze
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/common"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+// ConvertAudioRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertClaudeRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertEmbeddingRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertImageRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertOpenAIRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return convertCozeChatRequest(*request), nil
+}
+
+// ConvertOpenAIResponsesRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertRerankRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// DoRequest implements channel.Adaptor.
+func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
+	// First create the chat; only after that succeeds do we fetch the messages
+	// Send the chat-creation request
+	resp, err := channel.DoApiRequest(a, c, info, requestBody)
+	if err != nil {
+		return nil, err
+	}
+	// Parse resp
+	var cozeResponse CozeChatResponse
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	err = json.Unmarshal(respBody, &cozeResponse)
+	if cozeResponse.Code != 0 {
+		return nil, errors.New(cozeResponse.Msg)
+	}
+	c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
+	c.Set("coze_chat_id", cozeResponse.Data.Id)
+	// Poll until the chat is complete
+	for {
+		err, isComplete := checkIfChatComplete(a, c, info)
+		if err != nil {
+			return nil, err
+		} else {
+			if isComplete {
+				break
+			}
+		}
+		time.Sleep(time.Second * 1)
+	}
+	// Fetch the resulting messages
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse implements channel.Adaptor.
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	err, usage = cozeChatHandler(c, resp, info)
+	return
+}
+
+// GetChannelName implements channel.Adaptor.
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
+
+// GetModelList implements channel.Adaptor.
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+// GetRequestURL implements channel.Adaptor.
+func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl), nil
+}
+
+// Init implements channel.Adaptor.
+func (a *Adaptor) Init(info *common.RelayInfo) {
+
+}
+
+// SetupRequestHeader implements channel.Adaptor.
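+// The Coze API authenticates each call with the channel key sent as a
+// Bearer token.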
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}

diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go
new file mode 100644
index 00000000..da28cb83
--- /dev/null
+++ b/relay/channel/coze/constants.go
@@ -0,0 +1,8 @@
+package coze
+
+var ModelList = []string{
+	// TODO: full model list
+	"deepseek-v3",
+}
+
+var ChannelName = "coze"

diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go
new file mode 100644
index 00000000..fb92289a
--- /dev/null
+++ b/relay/channel/coze/dto.go
@@ -0,0 +1,81 @@
+package coze
+
+import "encoding/json"
+
+// type CozeResponse struct {
+// 	Code    int                  `json:"code"`
+// 	Message string               `json:"message"`
+// 	Data    CozeConversationData `json:"data"`
+// 	Detail  CozeConversationData `json:"detail"`
+// }
+
+// type CozeConversationData struct {
+// 	Id            string          `json:"id"`
+// 	CreatedAt     int64           `json:"created_at"`
+// 	MetaData      json.RawMessage `json:"meta_data"`
+// 	LastSectionId string          `json:"last_section_id"`
+// }
+
+// type CozeResponseDetail struct {
+// 	Logid string `json:"logid"`
+// }
+
+type CozeError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+}
+
+// type CozeErrorWithStatusCode struct {
+// 	Error      CozeError `json:"error"`
+// 	StatusCode int
+// 	LocalError bool
+// }
+
+type CozeRequest struct {
+	BotId    string             `json:"bot_id,omitempty"`
+	MetaData json.RawMessage    `json:"meta_data,omitempty"`
+	Messages []CozeEnterMessage `json:"messages,omitempty"`
+}
+
+type CozeEnterMessage struct {
+	Role        string          `json:"role"`
+	Type        string          `json:"type,omitempty"`
+	Content     json.RawMessage `json:"content,omitempty"`
+	MetaData    json.RawMessage `json:"meta_data,omitempty"`
+	ContentType string          `json:"content_type,omitempty"`
+}
+
+type CozeChatRequest struct {
+	BotId              string             `json:"bot_id"`
+	UserId             string             `json:"user_id"`
+	AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
+	Stream             bool               `json:"stream,omitempty"`
+	CustomVariables    json.RawMessage    `json:"custom_variables,omitempty"`
+	AutoSaveHistory    bool               `json:"auto_save_history,omitempty"`
+	MetaData           json.RawMessage    `json:"meta_data,omitempty"`
+	ExtraParams        json.RawMessage    `json:"extra_params,omitempty"`
+	ShortcutCommand    json.RawMessage    `json:"shortcut_command,omitempty"`
+	Parameters         json.RawMessage    `json:"parameters,omitempty"`
+}
+
+type CozeChatResponse struct {
+	Code int                  `json:"code"`
+	Msg  string               `json:"msg"`
+	Data CozeChatResponseData `json:"data"`
+}
+
+type CozeChatResponseData struct {
+	Id             string        `json:"id"`
+	ConversationId string        `json:"conversation_id"`
+	BotId          string        `json:"bot_id"`
+	CreatedAt      int64         `json:"created_at"`
+	LastError      CozeError     `json:"last_error"`
+	Status         string        `json:"status"`
+	Usage          CozeChatUsage `json:"usage"`
+}
+
+type CozeChatUsage struct {
+	TokenCount  int `json:"token_count"`
+	OutputCount int `json:"output_count"`
+	InputCount  int `json:"input_count"`
+}

diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
new file mode 100644
index 00000000..49a3ac15
--- /dev/null
+++ b/relay/channel/coze/relay-coze.go
@@ -0,0 +1,121 @@
+package coze
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/common"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"github.com/gin-gonic/gin"
+)
+
+func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest {
+	var messages []CozeEnterMessage
+	// Convert the user-role messages of the request into Coze messages
+	for _, message := range request.Messages {
+		if message.Role == "user" {
+			messages = append(messages, CozeEnterMessage{
+				Role:    "user",
+				Content: message.Content,
+				// TODO: support more content type
+				ContentType: "text",
+			})
+		}
+	}
+	cozeRequest := &CozeRequest{
+		// TODO: model to botid
+		BotId:    "1",
+		Messages: messages,
+	}
+	return cozeRequest
+}
+
+func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	// convert coze response to openai response
+	var response dto.TextResponse
+	var cozeResponse CozeChatResponse
+	err = json.Unmarshal(responseBody, &cozeResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	response.Model = info.UpstreamModelName
+	// TODO: process cozeResponse
+	return nil, nil
+}
+
+func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
+	requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
+
+	requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
+	// Send a GET request with the conversationId and chatId as parameters
+	req, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return err, false
+	}
+	err = a.SetupRequestHeader(c, &req.Header, info)
+	if err != nil {
+		return err, false
+	}
+
+	resp, err := doRequest(req, info) // call doRequest
+	if err != nil {
+		return err, false
+	}
+	if resp == nil { // guard against a nil resp from doRequest causing a panic below
+		return fmt.Errorf("resp is nil"), false
+	}
+	defer resp.Body.Close() // make sure the response body is closed
+
+	// Parse resp into CozeChatResponse
+	var cozeResponse CozeChatResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("read response body failed: %w", err), false
+	}
+	err = json.Unmarshal(responseBody, &cozeResponse)
+	if err != nil {
+		return fmt.Errorf("unmarshal response body failed: %w", err), false
+	}
+	if cozeResponse.Data.Status == "completed" {
+		// Store the usage in the context
+		c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
+		c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
+		c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
+		return nil, true
+	} else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
+		return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
+	} else {
+		return nil, false
+	}
+}
+
+func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+	var client *http.Client
+	var err error // declare err
+	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
+		client, err = service.NewProxyHttpClient(proxyURL.(string))
+		if err != nil {
+			return nil, fmt.Errorf("new proxy http client failed: %w", err)
+		}
+	} else {
+		client = service.GetHttpClient()
+	}
+	resp, err := client.Do(req)
+	if err != nil { // check the error returned by client.Do(req)
+		return nil, fmt.Errorf("client.Do failed: %w", err)
+	}
+	_ = resp.Body.Close()
+	return resp, nil
+}

diff --git
a/relay/constant/api_type.go b/relay/constant/api_type.go
index fef38f23..3f1ecd78 100644
--- a/relay/constant/api_type.go
+++ b/relay/constant/api_type.go
@@ -33,6 +33,7 @@ const (
 	APITypeOpenRouter
 	APITypeXinference
 	APITypeXai
+	APITypeCoze
 
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
@@ -95,6 +96,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeXinference
 	case common.ChannelTypeXai:
 		apiType = APITypeXai
+	case common.ChannelTypeCoze:
+		apiType = APITypeCoze
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false

diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 8b4afcb3..7bf0da9f 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -10,6 +10,7 @@ import (
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/cloudflare"
 	"one-api/relay/channel/cohere"
+	"one-api/relay/channel/coze"
 	"one-api/relay/channel/deepseek"
 	"one-api/relay/channel/dify"
 	"one-api/relay/channel/gemini"
@@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &openai.Adaptor{}
 	case constant.APITypeXai:
 		return &xai.Adaptor{}
+	case constant.APITypeCoze:
+		return &coze.Adaptor{}
 	}
 	return nil
 }

From f17f38e56906936ce1e000e6842371fd85520eff Mon Sep 17 00:00:00 2001
From: 王永振
Date: Tue, 13 May 2025 13:39:44 +0800
Subject: [PATCH 058/105] fix: ALI completions api path error

---
 relay/channel/ali/adaptor.go | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index 8e34fd80..ab632d22 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -33,6 +33,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
 	case constant.RelayModeImagesGenerations:
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
+	case constant.RelayModeCompletions:
+		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
 	default:
 		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
 	}

From b2499b0a7ed0d902ad7ae4653dd0d0ab7e81055a Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Tue, 13 May 2025 21:13:34 +0800
Subject: [PATCH 059/105] DoRequest

---
 relay/channel/coze/adaptor.go    |  6 +++---
 relay/channel/coze/dto.go        | 30 ------------------------------
 relay/channel/coze/relay-coze.go | 29 +++++++++++++++++++++++++----
 3 files changed, 28 insertions(+), 37 deletions(-)

diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go
index b14239a6..34931cc6 100644
--- a/relay/channel/coze/adaptor.go
+++ b/relay/channel/coze/adaptor.go
@@ -42,7 +42,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, r
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return convertCozeChatRequest(*request), nil
+	return convertCozeChatRequest(c, *request), nil
 }
 
 // ConvertOpenAIResponsesRequest implements channel.Adaptor.
@@ -88,7 +88,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
 		time.Sleep(time.Second * 1)
 	}
 	// Fetch the resulting messages
-	return channel.DoApiRequest(a, c, info, requestBody)
+	return getChatDetail(a, c, info)
 }
 
 // DoResponse implements channel.Adaptor.
@@ -109,7 +109,7 @@
 // GetRequestURL implements channel.Adaptor.
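 // Chat creation and message listing are now separate calls, so this URL is
 // only the chat-creation endpoint; getChatDetail fetches the transcript.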
 func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl), nil
+	return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
 }
 
 // Init implements channel.Adaptor.

diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go
index fb92289a..38fc2f16 100644
--- a/relay/channel/coze/dto.go
+++ b/relay/channel/coze/dto.go
@@ -2,41 +2,11 @@ package coze
 
 import "encoding/json"
 
-// type CozeResponse struct {
-// 	Code    int                  `json:"code"`
-// 	Message string               `json:"message"`
-// 	Data    CozeConversationData `json:"data"`
-// 	Detail  CozeConversationData `json:"detail"`
-// }
-
-// type CozeConversationData struct {
-// 	Id            string          `json:"id"`
-// 	CreatedAt     int64           `json:"created_at"`
-// 	MetaData      json.RawMessage `json:"meta_data"`
-// 	LastSectionId string          `json:"last_section_id"`
-// }
-
-// type CozeResponseDetail struct {
-// 	Logid string `json:"logid"`
-// }
-
 type CozeError struct {
 	Code    int    `json:"code"`
 	Message string `json:"message"`
 }
 
-// type CozeErrorWithStatusCode struct {
-// 	Error      CozeError `json:"error"`
-// 	StatusCode int
-// 	LocalError bool
-// }
-
-type CozeRequest struct {
-	BotId    string             `json:"bot_id,omitempty"`
-	MetaData json.RawMessage    `json:"meta_data,omitempty"`
-	Messages []CozeEnterMessage `json:"messages,omitempty"`
-}
-
 type CozeEnterMessage struct {
 	Role        string          `json:"role"`
 	Type        string          `json:"type,omitempty"`

diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index 49a3ac15..7c16763e 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -13,7 +13,7 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest {
+func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
 	var messages []CozeEnterMessage
 	// Convert the user-role messages of the request into Coze messages
 	for _, message := range request.Messages {
@@ -26,10 +26,12 @@ func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest {
 			})
 		}
 	}
-	cozeRequest := &CozeRequest{
+	cozeRequest := &CozeChatRequest{
 		// TODO: model to botid
-		BotId:    "1",
-		Messages: messages,
+		BotId:              "1",
+		UserId:             c.GetString("id"),
+		AdditionalMessages: messages,
+		Stream:             request.Stream,
 	}
 	return cozeRequest
 }
@@ -101,6 +103,25 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo
 	}
 }
 
+func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
+	requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
+
+	requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
+	req, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	err = a.SetupRequestHeader(c, &req.Header, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := doRequest(req, info)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}
+
 func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error // declare err

From 29c95c598e380dbe5ff80cd0690a1c4c3770f93d Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Tue, 13 May 2025 22:01:12 +0800
Subject: [PATCH 060/105] cozeChatHelper

---
 relay/channel/coze/dto.go        | 27 ++++++++++++++++++++
 relay/channel/coze/relay-coze.go | 43 +++++++++++++++++++++++++++++----
 2 files changed, 66 insertions(+), 4 deletions(-)

diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go
index 38fc2f16..4e9afa23 100644
--- a/relay/channel/coze/dto.go
+++ b/relay/channel/coze/dto.go
@@ -49,3 +49,30 @@ type CozeChatUsage struct {
 	OutputCount int `json:"output_count"`
 	InputCount  int `json:"input_count"`
 }
+
+type CozeChatDetailResponse struct {
+	Data   []CozeChatV3MessageDetail `json:"data"`
+	Code   int                       `json:"code"`
+	Msg    string                    `json:"msg"`
+	Detail CozeResponseDetail        `json:"detail"`
+}
+
+type CozeChatV3MessageDetail struct {
+	Id               string          `json:"id"`
+	Role             string          `json:"role"`
+	Type             string          `json:"type"`
+	BotId            string          `json:"bot_id"`
+	ChatId           string          `json:"chat_id"`
+	Content          json.RawMessage `json:"content"`
+	MetaData         json.RawMessage `json:"meta_data"`
+	CreatedAt        int64           `json:"created_at"`
+	SectionId        string          `json:"section_id"`
+	UpdatedAt        int64           `json:"updated_at"`
+	ContentType      string          `json:"content_type"`
+	ConversationId   string          `json:"conversation_id"`
+	ReasoningContent string          `json:"reasoning_content"`
+}
+
+type CozeResponseDetail struct {
+	Logid string `json:"logid"`
+}

diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index 7c16763e..fe630ef6 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -2,12 +2,14 @@ package coze
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 
 	"github.com/gin-gonic/gin"
@@ -47,14 +49,47 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	}
 	// convert coze response to openai response
 	var response dto.TextResponse
-	var cozeResponse CozeChatResponse
+	var cozeResponse CozeChatDetailResponse
+	response.Model = info.UpstreamModelName
 	err = json.Unmarshal(responseBody, &cozeResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	response.Model = info.UpstreamModelName
-	// TODO: process cozeResponse
-	return nil, nil
+	if cozeResponse.Code != 0 {
+		return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
+	}
+	// Read the usage from the context
+	var usage dto.Usage
+	usage.PromptTokens = c.GetInt("coze_input_count")
+	usage.CompletionTokens = c.GetInt("coze_output_count")
+	usage.TotalTokens = c.GetInt("coze_token_count")
+	response.Usage = usage
+	response.Id = helper.GetResponseID(c)
+
+	var responseContent json.RawMessage
+	for _, data := range cozeResponse.Data {
+		if data.Type == "answer" {
+			responseContent = data.Content
+			response.Created = data.CreatedAt
+		}
+	}
+	// Populate response.Choices
+	response.Choices = []dto.OpenAITextResponseChoice{
+		{
+			Index:        0,
+			Message:      dto.Message{Role: "assistant", Content: responseContent},
+			FinishReason: "stop",
+		},
+	}
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	return nil, &usage
 }
 
 func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {

From 108b67be6cc269778c17e24d38b5bc1971d11919 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Tue, 13 May 2025 22:23:38 +0800
Subject: [PATCH
061/105] use channel bot id --- middleware/distributor.go | 2 ++ relay/channel/coze/relay-coze.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 34882381..fdda8dda 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -240,5 +240,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("api_version", channel.Other) case common.ChannelTypeMokaAI: c.Set("api_version", channel.Other) + case common.ChannelTypeCoze: + c.Set("bot_id", channel.Other) } } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index fe630ef6..8e9b8e3e 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -30,7 +30,7 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C } cozeRequest := &CozeChatRequest{ // TODO: model to botid - BotId: "1", + BotId: c.GetString("bot_id"), UserId: c.GetString("id"), AdditionalMessages: messages, Stream: request.Stream, From ea04e6bcc53e38e0f8f2776d12daadf67e32de52 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 14 May 2025 17:01:50 +0800 Subject: [PATCH 062/105] fix: update model selection logic for image edits in distributor middleware --- middleware/distributor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 34882381..755a477d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -185,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { - modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "gpt-image-1") + modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { relayMode := relayconstant.RelayModeAudioSpeech From 4825404d375622dff567deefdd69dd7495fa8c35 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 15 May 2025 14:51:33 +0800 Subject: [PATCH 063/105] feat: enhance image decoding logic to handle base64 file types and improve error handling --- service/token_counter.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/service/token_counter.go b/service/token_counter.go index 21b882af..d63b54ad 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -120,11 +120,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m var config image.Config var err error var format string + var b64str string if strings.HasPrefix(imageUrl.Url, "http") { config, format, err = DecodeUrlImageData(imageUrl.Url) } else { common.SysLog(fmt.Sprintf("decoding image")) - config, format, _, err = DecodeBase64ImageData(imageUrl.Url) + config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url) } if err != nil { return 0, err @@ -132,7 +133,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m imageUrl.MimeType = format if config.Width == 0 || config.Height == 0 { - return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) + // not an image + if format != "" && b64str != "" { + // file type + return 3 * baseTokens, nil + } + return 0, errors.New(fmt.Sprintf("fail to decode base64 
config: %s", imageUrl.Url)) } shortSide := config.Width From 59aabb43119059bca2e26fd2059904294b6e0ce3 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Thu, 15 May 2025 20:00:59 +0800 Subject: [PATCH 064/105] add frontend display, more model --- relay/channel/coze/constants.go | 24 +++++++++++++++++++++++- relay/channel/coze/relay-coze.go | 9 ++++++--- web/src/constants/channel.constants.js | 9 +++++++-- web/src/pages/Channel/EditChannel.js | 16 ++++++++++++++++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go index da28cb83..873ffe24 100644 --- a/relay/channel/coze/constants.go +++ b/relay/channel/coze/constants.go @@ -1,8 +1,30 @@ package coze var ModelList = []string{ - // TODO: 完整列表 + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "Baichuan4", + "abab6.5s-chat-pro", + "glm-4-0520", + "qwen-max", + "deepseek-r1", "deepseek-v3", + "deepseek-r1-distill-qwen-32b", + "deepseek-r1-distill-qwen-7b", + "step-1v-8k", + "step-1.5v-mini", + "Doubao-pro-32k", + "Doubao-pro-256k", + "Doubao-lite-128k", + "Doubao-lite-32k", + "Doubao-vision-lite-32k", + "Doubao-vision-pro-32k", + "Doubao-1.5-pro-vision-32k", + "Doubao-1.5-lite-32k", + "Doubao-1.5-pro-32k", + "Doubao-1.5-thinking-pro", + "Doubao-1.5-pro-256k", } var ChannelName = "coze" diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 8e9b8e3e..1ebdb7c1 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -28,10 +28,13 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C }) } } + user := request.User + if user == "" { + user = helper.GetResponseID(c) + } cozeRequest := &CozeChatRequest{ - // TODO: model to botid BotId: c.GetString("bot_id"), - UserId: c.GetString("id"), + UserId: user, AdditionalMessages: messages, Stream: request.Stream, } @@ -172,6 +175,6 @@ func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error if err != nil { // 增加对 client.Do(req) 返回错误的检查 return nil, fmt.Errorf("client.Do failed: %w", err) } - _ = resp.Body.Close() + // _ = resp.Body.Close() return resp, nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index fa59bcce..054da535 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -118,6 +118,11 @@ export const CHANNEL_OPTIONS = [ { value: 48, color: 'blue', - label: 'xAI' - } + label: 'xAI', + }, + { + value: 49, + color: 'blue', + label: 'Coze', + }, ]; diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index fd96ffb6..f7fab057 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -838,6 +838,22 @@ const EditChannel = (props) => { /> )} + {inputs.type === 49 && ( + <> +
+ 智能体ID: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )}
{t('模型')}:
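Notes: taken together, patches 057–064 settle the Coze channel's non-streaming
flow into three upstream calls — create the chat, poll it until it finishes,
then list the messages. The sketch below restates that flow outside the adaptor
plumbing; the API key and bot id are placeholders, and unlike
checkIfChatComplete it does not handle the failed/canceled/requires_action
statuses or per-channel proxy clients:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)

// chatResp is a trimmed-down mirror of CozeChatResponse from dto.go.
type chatResp struct {
	Code int    `json:"code"`
	Msg  string `json:"msg"`
	Data struct {
		Id             string `json:"id"`
		ConversationId string `json:"conversation_id"`
		Status         string `json:"status"`
	} `json:"data"`
}

// call issues one authorized request against the Coze API.
func call(method, url string, body []byte) ([]byte, error) {
	req, err := http.NewRequest(method, url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer <api-key>") // placeholder key
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	return io.ReadAll(resp.Body)
}

func main() {
	base := "https://api.coze.cn" // default base URL from common/constants.go

	// 1. Create the chat (POST /v3/chat), as in Adaptor.GetRequestURL.
	payload := []byte(`{"bot_id":"<bot-id>","user_id":"u-1",` +
		`"additional_messages":[{"role":"user","content":"hi","content_type":"text"}]}`)
	raw, err := call("POST", base+"/v3/chat", payload)
	if err != nil {
		panic(err)
	}
	var created chatResp
	if err := json.Unmarshal(raw, &created); err != nil || created.Code != 0 {
		panic(fmt.Sprintf("create chat failed: %v %s", err, created.Msg))
	}

	// 2. Poll /v3/chat/retrieve once per second until the chat completes,
	//    mirroring checkIfChatComplete.
	for {
		raw, err = call("GET", fmt.Sprintf("%s/v3/chat/retrieve?conversation_id=%s&chat_id=%s",
			base, created.Data.ConversationId, created.Data.Id), nil)
		if err != nil {
			panic(err)
		}
		var st chatResp
		if err := json.Unmarshal(raw, &st); err != nil {
			panic(err)
		}
		if st.Data.Status == "completed" {
			break
		}
		time.Sleep(time.Second)
	}

	// 3. List the final messages (GET /v3/chat/message/list), as in getChatDetail.
	raw, _ = call("GET", fmt.Sprintf("%s/v3/chat/message/list?conversation_id=%s&chat_id=%s",
		base, created.Data.ConversationId, created.Data.Id), nil)
	fmt.Println(string(raw))
}
```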
From e379ee8f66c1d3f85c89a26994b88227564ffa10 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Fri, 16 May 2025 10:27:07 +0800
Subject: [PATCH 065/105] coze stream

---
 relay/channel/coze/adaptor.go    |   9 ++-
 relay/channel/coze/relay-coze.go | 124 ++++++++++++++++++++++++++++++--
 2 files changed, 130 insertions(+), 3 deletions(-)

diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go
index 34931cc6..80441a51 100644
--- a/relay/channel/coze/adaptor.go
+++ b/relay/channel/coze/adaptor.go
@@ -57,6 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 
 // DoRequest implements channel.Adaptor.
 func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
+	if info.IsStream {
+		return channel.DoApiRequest(a, c, info, requestBody)
+	}
 	// First create the chat; only after that succeeds do we fetch the messages
 	// Send the chat-creation request
 	resp, err := channel.DoApiRequest(a, c, info, requestBody)
@@ -93,7 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
 
 // DoResponse implements channel.Adaptor.
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
-	err, usage = cozeChatHandler(c, resp, info)
+	if info.IsStream {
+		err, usage = cozeChatStreamHandler(c, resp, info)
+	} else {
+		err, usage = cozeChatHandler(c, resp, info)
+	}
 	return
 }

diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index 1ebdb7c1..6db40213 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -1,16 +1,18 @@
 package coze
 
 import (
+	"bufio"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
-	"one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -95,6 +97,124 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	return nil, &usage
 }
 
+func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+	helper.SetEventStreamHeaders(c)
+	id := helper.GetResponseID(c)
+	var responseText string
+
+	var currentEvent string
+	var currentData string
+	var usage dto.Usage
+
+	for scanner.Scan() {
+		line := scanner.Text()
+
+		if line == "" {
+			if currentEvent != "" && currentData != "" {
+				// handle last event
+				handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+				currentEvent = ""
+				currentData = ""
+			}
+			continue
+		}
+
+		if strings.HasPrefix(line, "event:") {
+			currentEvent = strings.TrimSpace(line[6:])
+			continue
+		}
+
+		if strings.HasPrefix(line, "data:") {
+			currentData = strings.TrimSpace(line[5:])
+			continue
+		}
+	}
+
+	// Last event
+	if currentEvent != "" && currentData != "" {
+		handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+	}
+
+	if err := scanner.Err(); err != nil {
+		return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
+	}
+	helper.Done(c)
+
+	if usage.TotalTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens, _ = service.CountTextToken(responseText, "gpt-3.5-turbo")
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+
+	return nil, &usage
+}
+
+func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
+	switch event {
+	case "conversation.chat.completed":
+		// Parse data into CozeChatResponseData
+		var chatData CozeChatResponseData
+		err := json.Unmarshal([]byte(data), &chatData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		usage.PromptTokens = chatData.Usage.InputCount
+		usage.CompletionTokens = chatData.Usage.OutputCount
+		usage.TotalTokens = chatData.Usage.TokenCount
+
+		finishReason := "stop"
+		stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
+		helper.ObjectData(c, stopResponse)
+
+	case "conversation.message.delta":
+		// Parse data into CozeChatV3MessageDetail
+		var messageData CozeChatV3MessageDetail
+		err := json.Unmarshal([]byte(data), &messageData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		var content string
+		err = json.Unmarshal(messageData.Content, &content)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		*responseText += content
+
+		openaiResponse := dto.ChatCompletionsStreamResponse{
+			Id:      id,
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   info.UpstreamModelName,
+		}
+
+		choice := dto.ChatCompletionsStreamResponseChoice{
+			Index: 0,
+		}
+		choice.Delta.SetContentString(content)
+		openaiResponse.Choices = append(openaiResponse.Choices, choice)
+
+		helper.ObjectData(c, openaiResponse)
+
+	case "error":
+		var errorData CozeError
+		err := json.Unmarshal([]byte(data), &errorData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		common.SysError(fmt.Sprintf("stream event error: %d, %s", errorData.Code, errorData.Message))
+	}
+}
+
 func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
 	requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
 
@@ -160,7 +280,7 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
 	return resp, nil
 }
 
-func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error // declare err
 	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {

From 9927e5d191c619c7a1835de36c5b51e17414a184 Mon Sep 17 00:00:00 2001
From: skynono <6811626@qq.com>
Date: Fri, 16 May 2025 17:45:17 +0800
Subject: [PATCH 066/105] fix: proxy settings not applied when request MJ image
 url

---
 relay/relay-mj.go | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index a7018456..9d0a2077 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -32,7 +32,23 @@ func RelayMidjourneyImage(c *gin.Context) {
 		})
 		return
 	}
-	resp, err := http.Get(midjourneyTask.ImageUrl)
+	var httpClient *http.Client
+	if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
+		if proxy, ok := channel.GetSetting()["proxy"]; ok {
+			if proxyURL, ok := proxy.(string); ok && proxyURL != "" {
+				if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil {
+					c.JSON(400, gin.H{
+						"error": "proxy_url_invalid",
+					})
+					return
+				}
+			}
+		}
+	}
+	if httpClient == nil {
+		httpClient = service.GetHttpClient()
+	}
+	resp, err := httpClient.Get(midjourneyTask.ImageUrl)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{
 			"error": "http_get_image_failed",

From
9c12e02cb50ad696aab25170c2cace1ad3eb3691 Mon Sep 17 00:00:00 2001 From: skynono <6811626@qq.com> Date: Mon, 19 May 2025 14:28:29 +0800 Subject: [PATCH 067/105] fix: if default model is not exist, set the first one as default --- web/src/pages/Playground/Playground.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/web/src/pages/Playground/Playground.js b/web/src/pages/Playground/Playground.js index e8138c01..08eada17 100644 --- a/web/src/pages/Playground/Playground.js +++ b/web/src/pages/Playground/Playground.js @@ -64,8 +64,9 @@ const Playground = () => { }, ]; + const defaultModel = 'gpt-4o-mini'; const [inputs, setInputs] = useState({ - model: 'gpt-4o-mini', + model: defaultModel, group: '', max_tokens: 0, temperature: 0, @@ -108,6 +109,11 @@ const Playground = () => { value: model, })); setModels(localModelOptions); + // if default model is not in the list, set the first one as default + const hasDefault = localModelOptions.some(option => option.value === defaultModel); + if (!hasDefault && localModelOptions.length > 0) { + setInputs((inputs) => ({ ...inputs, model: localModelOptions[0].value })); + } } else { showError(t(message)); } From 498d73f67cf939330b17cf39b501fb10d6ba88fc Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 19 May 2025 20:26:30 +0800 Subject: [PATCH 068/105] refactor: update JSON field names in GeminiChatRequest for consistency --- relay/channel/gemini/dto.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 5d5c1287..93643a20 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -2,10 +2,10 @@ package gemini type GeminiChatRequest struct { Contents []GeminiChatContent `json:"contents"` - SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` - GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"` Tools []GeminiChatTool `json:"tools,omitempty"` - SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"` + SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` } type GeminiThinkingConfig struct { From 1f9fc09989624fbbbe74a987f9b4141a10af8054 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Tue, 20 May 2025 19:40:29 +0800 Subject: [PATCH 069/105] feat: add OutputFormat field to ImageRequest for enhanced image processing options --- dto/dalle.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dto/dalle.go b/dto/dalle.go index 44104d33..a1309b6c 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -14,6 +14,7 @@ type ImageRequest struct { ExtraFields json.RawMessage `json:"extra_fields,omitempty"` Background string `json:"background,omitempty"` Moderation string `json:"moderation,omitempty"` + OutputFormat string `json:"output_format,omitempty"` } type ImageResponse struct { From e1190f98e9de90f8f7fa090043bc1ae02cc5a7a7 Mon Sep 17 00:00:00 2001 From: skynono <6811626@qq.com> Date: Mon, 19 May 2025 15:42:36 +0800 Subject: [PATCH 070/105] fix: typo in oidc_enabled field (previously oidc) --- web/src/pages/Home/index.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/pages/Home/index.js b/web/src/pages/Home/index.js index 599c7930..84fabf6f 100644 --- a/web/src/pages/Home/index.js +++ b/web/src/pages/Home/index.js @@ -158,7 +158,7 @@ const 
Home = () => {

{t('OIDC 身份验证')}: - {statusState?.status?.oidc === true + {statusState?.status?.oidc_enabled === true ? t('已启用') : t('未启用')}

From 66bdfe180c4ba681cbd6fbba0fc271043983f6cf Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 22 May 2025 15:52:23 +0800 Subject: [PATCH 071/105] feat: add Thought field to GeminiPart and update response handling in streamResponseGeminiChat2OpenAI --- relay/channel/gemini/dto.go | 1 + relay/channel/gemini/relay-gemini.go | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 93643a20..a0e38cb4 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -54,6 +54,7 @@ type GeminiFileData struct { type GeminiPart struct { Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` InlineData *GeminiInlineData `json:"inlineData,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"` FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index ae9a3b7b..80d55f6d 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -596,6 +596,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C } var texts []string isTools := false + isThought := false if candidate.FinishReason != nil { // p := GeminiConvertFinishReason(*candidate.FinishReason) switch *candidate.FinishReason { @@ -620,6 +621,9 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C call.SetIndex(len(choice.Delta.ToolCalls)) choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call) } + } else if part.Thought { + isThought = true + texts = append(texts, part.Text) } else { if part.ExecutableCode != nil { texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n") @@ -632,7 +636,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C } } } - choice.Delta.SetContentString(strings.Join(texts, "\n")) + if isThought { + choice.Delta.SetReasoningContent(strings.Join(texts, "\n")) + } else { + choice.Delta.SetContentString(strings.Join(texts, "\n")) + } if isTools { choice.FinishReason = &constant.FinishReasonToolCalls } From e18001299b8d366c1147d71050f8d75cb2e90891 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 22 May 2025 16:11:50 +0800 Subject: [PATCH 072/105] feat: enhance Gemini response handling by adding reasoning content and updating JSON decoding method --- relay/channel/gemini/relay-gemini.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 80d55f6d..da0bc5fc 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -539,6 +539,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp if call := getResponseToolCall(&part); call != nil { toolCalls = append(toolCalls, *call) } + } else if part.Thought { + choice.Message.ReasoningContent = part.Text } else { if part.ExecutableCode != nil { texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```") @@ -556,7 +558,6 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp choice.Message.SetToolCalls(toolCalls) isToolCall = true } - choice.Message.SetStringContent(strings.Join(texts, "\n")) } @@ -724,8 +725,11 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re if err != nil { 
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + if common.DebugEnabled { + println(string(responseBody)) + } var geminiResponse GeminiChatResponse - err = json.Unmarshal(responseBody, &geminiResponse) + err = common.DecodeJson(responseBody, &geminiResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } From 9a59da16a540c41ac199f1126028560bb8e0d67c Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 22 May 2025 16:54:55 +0800 Subject: [PATCH 073/105] feat: implement search functionality in ChannelsTable for improved channel filtering --- web/src/components/ChannelsTable.js | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 3425beea..9b1dd602 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -871,7 +871,16 @@ const ChannelsTable = () => { }; const refresh = async () => { - await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); + if (searchKeyword === '' && searchGroup === '' && searchModel === '') { + await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); + } else { + await searchChannels( + searchKeyword, + searchGroup, + searchModel, + enableTagMode, + ); + } }; useEffect(() => { @@ -979,8 +988,8 @@ const ChannelsTable = () => { enableTagMode, ) => { if (searchKeyword === '' && searchGroup === '' && searchModel === '') { - await loadChannels(0, pageSize, idSort, enableTagMode); - setActivePage(1); + await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); + // setActivePage(1); return; } setSearching(true); From c53a48cde51a144eab74469d7699e51345c8cafe Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 23 May 2025 01:26:52 +0800 Subject: [PATCH 074/105] feat: add panic recovery and retry mechanism for InitChannelCache; improve batch deletion of abilities in FixAbility --- main.go | 19 ++++++++++++++++--- model/ability.go | 47 ++++++++++++++++++++++++++++++++++++----------- model/cache.go | 7 +++++-- model/channel.go | 11 +++++++++++ 4 files changed, 68 insertions(+), 16 deletions(-) diff --git a/main.go b/main.go index 95c6820d..c286650f 100644 --- a/main.go +++ b/main.go @@ -89,9 +89,22 @@ func main() { if common.MemoryCacheEnabled { common.SysLog("memory cache enabled") common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) - model.InitChannelCache() - } - if common.MemoryCacheEnabled { + + // Add panic recovery and retry for InitChannelCache + func() { + defer func() { + if r := recover(); r != nil { + common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) + // Retry once + _, fixErr := model.FixAbility() + if fixErr != nil { + common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) + } + } + }() + model.InitChannelCache() + }() + go model.SyncOptions(common.SyncFrequency) go model.SyncChannelCache(common.SyncFrequency) } diff --git a/model/ability.go b/model/ability.go index 52720307..38b0bd73 100644 --- a/model/ability.go +++ b/model/ability.go @@ -50,7 +50,7 @@ func getPriority(group string, model string, retry int) (int, error) { err := DB.Model(&Ability{}). Select("DISTINCT(priority)"). Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model). - Order("priority DESC"). // 按优先级降序排序 + Order("priority DESC"). 
 		Pluck("priority", &priorities).Error // Pluck scans the query results directly into a slice
 
 	if err != nil {
@@ -261,12 +261,28 @@ func FixAbility() (int, error) {
 		common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
 		return 0, err
 	}
-	// Delete abilities of channels that are not in channel table
-	err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
-	if err != nil {
-		common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
-		return 0, err
+
+	// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
+	if len(channelIds) > 0 {
+		// Process deletion in chunks to avoid "too many placeholders" error
+		for _, chunk := range lo.Chunk(channelIds, 100) {
+			err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
+			if err != nil {
+				common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
+				return 0, err
+			}
+		}
+	} else {
+		// If no channels exist, delete all abilities
+		err = DB.Delete(&Ability{}).Error
+		if err != nil {
+			common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
+			return 0, err
+		}
+		common.SysLog("Delete all abilities successfully")
+		return 0, nil
 	}
+
 	common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
 	count += len(channelIds)
 
@@ -275,17 +291,26 @@ func FixAbility() (int, error) {
 	err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
 	if err != nil {
 		common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
-		return 0, err
+		return count, err
 	}
+
 	var channels []Channel
 	if len(abilityChannelIds) == 0 {
 		err = DB.Find(&channels).Error
 	} else {
-		err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
-	}
-	if err != nil {
-		return 0, err
+		// Process query in chunks to avoid "too many placeholders" error
+		err = nil
+		for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
+			var channelsChunk []Channel
+			err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
+			if err != nil {
+				common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
+				return count, err
+			}
+			channels = append(channels, channelsChunk...)
+ } } + for _, channel := range channels { err := channel.UpdateAbilities(nil) if err != nil { diff --git a/model/cache.go b/model/cache.go index 2d1c36bf..e2f83e22 100644 --- a/model/cache.go +++ b/model/cache.go @@ -16,6 +16,9 @@ var channelsIDM map[int]*Channel var channelSyncLock sync.RWMutex func InitChannelCache() { + if !common.MemoryCacheEnabled { + return + } newChannelId2channel := make(map[int]*Channel) var channels []*Channel DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) @@ -84,11 +87,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha if !common.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model, retry) } - + channelSyncLock.RLock() channels := group2model2channels[group][model] channelSyncLock.RUnlock() - + if len(channels) == 0 { return nil, errors.New("channel not found") } diff --git a/model/channel.go b/model/channel.go index 41e5e371..ed7a0a7e 100644 --- a/model/channel.go +++ b/model/channel.go @@ -46,6 +46,17 @@ func (channel *Channel) GetModels() []string { return strings.Split(strings.Trim(channel.Models, ","), ",") } +func (channel *Channel) GetGroups() []string { + if channel.Group == "" { + return []string{} + } + groups := strings.Split(strings.Trim(channel.Group, ","), ",") + for i, group := range groups { + groups[i] = strings.TrimSpace(group) + } + return groups +} + func (channel *Channel) GetOtherInfo() map[string]interface{} { otherInfo := make(map[string]interface{}) if channel.OtherInfo != "" { From f796c3b216cf3a748d58abea65d2d0acaf8d68e6 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 23 May 2025 01:34:53 +0800 Subject: [PATCH 075/105] fix: update Init method to correctly set RequestMode based on upstream model name prefixes --- relay/channel/claude/adaptor.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 4b071712..8389b9f1 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -38,10 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { - if strings.HasPrefix(info.UpstreamModelName, "claude-3") { - a.RequestMode = RequestModeMessage - } else { + if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") { a.RequestMode = RequestModeCompletion + } else { + a.RequestMode = RequestModeMessage } } From 66a8612d127d68d350bac4183c703a1502774c94 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 23 May 2025 02:02:21 +0800 Subject: [PATCH 076/105] feat: add new model ratios for Claude Sonnet 4 and Claude Opus 4; update ratio retrieval logic for improved handling of model names --- setting/operation_setting/model-ratio.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/setting/operation_setting/model-ratio.go b/setting/operation_setting/model-ratio.go index fdc1c950..700a7c4e 100644 --- a/setting/operation_setting/model-ratio.go +++ b/setting/operation_setting/model-ratio.go @@ -114,7 +114,9 @@ var defaultModelRatio = map[string]float64{ "claude-3-5-sonnet-20241022": 1.5, "claude-3-7-sonnet-20250219": 1.5, "claude-3-7-sonnet-20250219-thinking": 1.5, + "claude-sonnet-4-20250514": 1.5, "claude-3-opus-20240229": 7.5, // $15 / 1M tokens + "claude-opus-4-20250514": 7.5, "ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-3.5-8K": 0.012 * RMB, "ERNIE-3.5-8K-0205": 
0.024 * RMB, @@ -440,13 +442,15 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { if name == "chatgpt-4o-latest" { return 3, true } - if strings.Contains(name, "claude-instant-1") { - return 3, true - } else if strings.Contains(name, "claude-2") { - return 3, true - } else if strings.Contains(name, "claude-3") { + + if strings.Contains(name, "claude-3") { return 5, true + } else if strings.Contains(name, "claude-sonnet-4") || strings.Contains(name, "claude-opus-4") { + return 5, true + } else if strings.Contains(name, "claude-instant-1") || strings.Contains(name, "claude-2") { + return 3, true } + if strings.HasPrefix(name, "gpt-3.5") { if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { // https://openai.com/blog/new-embedding-models-and-api-updates From 1644b7b15d6917a7aa6b35ea14ad1413f246db0f Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 23 May 2025 15:20:16 +0800 Subject: [PATCH 077/105] feat: add new model entries for Claude Sonnet 4 and Claude Opus 4 across multiple components, including constants and cache settings --- relay/channel/aws/constants.go | 2 ++ relay/channel/claude/constants.go | 4 ++++ relay/channel/vertex/adaptor.go | 2 ++ setting/operation_setting/cache_ratio.go | 8 ++++++++ 4 files changed, 16 insertions(+) diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 37196fd8..9285482a 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -11,6 +11,8 @@ var awsModelIDMap = map[string]string{ "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0", + "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0", } var awsModelCanCrossRegionMap = map[string]map[string]bool{ diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go index d7e0c8e3..e0e3c421 100644 --- a/relay/channel/claude/constants.go +++ b/relay/channel/claude/constants.go @@ -13,6 +13,10 @@ var ModelList = []string{ "claude-3-5-sonnet-20241022", "claude-3-7-sonnet-20250219", "claude-3-7-sonnet-20250219-thinking", + "claude-sonnet-4-20250514", + "claude-sonnet-4-20250514-thinking", + "claude-opus-4-20250514", + "claude-opus-4-20250514-thinking", } var ChannelName = "claude" diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 7daf9a61..d21a3e08 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -31,6 +31,8 @@ var claudeModelMap = map[string]string{ "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219", + "claude-sonnet-4-20250514": "claude-sonnet-4@20250514", + "claude-opus-4-20250514": "claude-opus-4@20250514", } const anthropicVersion = "vertex-2023-10-16" diff --git a/setting/operation_setting/cache_ratio.go b/setting/operation_setting/cache_ratio.go index dd29eac2..ec0c766d 100644 --- a/setting/operation_setting/cache_ratio.go +++ b/setting/operation_setting/cache_ratio.go @@ -36,6 +36,10 @@ var defaultCacheRatio = map[string]float64{ "claude-3-5-sonnet-20241022": 0.1, "claude-3-7-sonnet-20250219": 0.1, "claude-3-7-sonnet-20250219-thinking": 0.1, + "claude-sonnet-4-20250514": 0.1, + "claude-sonnet-4-20250514-thinking": 0.1, + 
"claude-opus-4-20250514": 0.1, + "claude-opus-4-20250514-thinking": 0.1, } var defaultCreateCacheRatio = map[string]float64{ @@ -47,6 +51,10 @@ var defaultCreateCacheRatio = map[string]float64{ "claude-3-5-sonnet-20241022": 1.25, "claude-3-7-sonnet-20250219": 1.25, "claude-3-7-sonnet-20250219-thinking": 1.25, + "claude-sonnet-4-20250514": 1.25, + "claude-sonnet-4-20250514-thinking": 1.25, + "claude-opus-4-20250514": 1.25, + "claude-opus-4-20250514-thinking": 1.25, } //var defaultCreateCacheRatio = map[string]float64{} From 2cc2d4f6526ad809e35cb405e8e6597691e3a5e1 Mon Sep 17 00:00:00 2001 From: skynono <6811626@qq.com> Date: Fri, 23 May 2025 13:51:11 +0800 Subject: [PATCH 078/105] fix: keep BatchDelete and TagMode enabled status --- web/src/components/ChannelsTable.js | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 9b1dd602..f490e14a 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -888,9 +888,13 @@ const ChannelsTable = () => { const localIdSort = localStorage.getItem('id-sort') === 'true'; const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + const localEnableTagMode = localStorage.getItem('enable-tag-mode') === 'true'; + const localEnableBatchDelete = localStorage.getItem('enable-batch-delete') === 'true'; setIdSort(localIdSort); setPageSize(localPageSize); - loadChannels(0, localPageSize, localIdSort, enableTagMode) + setEnableTagMode(localEnableTagMode); + setEnableBatchDelete(localEnableBatchDelete); + loadChannels(0, localPageSize, localIdSort, localEnableTagMode) .then() .catch((reason) => { showError(reason); @@ -1486,10 +1490,12 @@ const ChannelsTable = () => { {t('开启批量操作')} { + localStorage.setItem('enable-batch-delete', v + ''); setEnableBatchDelete(v); }} /> @@ -1553,6 +1559,7 @@ const ChannelsTable = () => { uncheckedText={t('关')} aria-label={t('是否启用标签聚合')} onChange={(v) => { + localStorage.setItem('enable-tag-mode', v + ''); setEnableTagMode(v); loadChannels(0, pageSize, idSort, v); }} From d95c2436d75720652fef4fd23d34aed2dbcb8d80 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 23 May 2025 21:11:00 +0800 Subject: [PATCH 079/105] feat: add support for new regions in Claude Sonnet 4 and Claude Opus 4 models in AWS constants --- relay/channel/aws/constants.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 9285482a..078155f6 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -43,6 +43,16 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{ }, "anthropic.claude-3-7-sonnet-20250219-v1:0": { "us": true, + "ap": true, + "eu": true, + }, + "apac.anthropic.claude-sonnet-4-20250514-v1:0": { + "us": true, + "ap": true, + "eu": true, + }, + "anthropic.claude-opus-4-20250514-v1:0": { + "us": true, }, } From 0595636cebed9ccf43a9bcd85420a333d9bd8225 Mon Sep 17 00:00:00 2001 From: daggeryu <997411652@qq.com> Date: Sat, 24 May 2025 01:21:14 +0800 Subject: [PATCH 080/105] fix aws claude-sonnet-4-20250514 --- relay/channel/aws/constants.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 078155f6..64c7b747 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -46,7 +46,7 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{ "ap": true, "eu": 
From fbdad581b51f7e7007d4070c9afe82906effaccd Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sat, 24 May 2025 15:26:55 +0800
Subject: [PATCH 081/105] fix: improve input validation and error handling in
 ModelSetting and SettingGeminiModel components

---
 web/src/components/ModelSetting.js            |  5 +-
 .../pages/Setting/Model/SettingGeminiModel.js | 70 ++++++++++++-------
 2 files changed, 47 insertions(+), 28 deletions(-)

diff --git a/web/src/components/ModelSetting.js b/web/src/components/ModelSetting.js
index 2a566d6b..9c60a390 100644
--- a/web/src/components/ModelSetting.js
+++ b/web/src/components/ModelSetting.js
@@ -39,7 +39,9 @@ const ModelSetting = () => {
           item.key === 'claude.default_max_tokens' ||
           item.key === 'gemini.supported_imagine_models'
         ) {
-          item.value = JSON.stringify(JSON.parse(item.value), null, 2);
+          if (item.value !== '') {
+            item.value = JSON.stringify(JSON.parse(item.value), null, 2);
+          }
         }
         if (item.key.endsWith('Enabled') || item.key.endsWith('enabled')) {
           newInputs[item.key] = item.value === 'true' ? true : false;
@@ -60,6 +62,7 @@ const ModelSetting = () => {
       // showSuccess('刷新成功');
     } catch (error) {
       showError('刷新失败');
+      console.error(error);
     } finally {
       setLoading(false);
     }
diff --git a/web/src/pages/Setting/Model/SettingGeminiModel.js b/web/src/pages/Setting/Model/SettingGeminiModel.js
index 6f6da279..b802af1a 100644
--- a/web/src/pages/Setting/Model/SettingGeminiModel.js
+++ b/web/src/pages/Setting/Model/SettingGeminiModel.js
@@ -27,40 +27,48 @@ export default function SettingGeminiModel(props) {
   const [inputs, setInputs] = useState({
     'gemini.safety_settings': '',
     'gemini.version_settings': '',
-    'gemini.supported_imagine_models': [],
+    'gemini.supported_imagine_models': '',
     'gemini.thinking_adapter_enabled': false,
     'gemini.thinking_adapter_budget_tokens_percentage': 0.6,
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
 
-  function onSubmit() {
-    const updateArray = compareObjects(inputs, inputsRow);
-    if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
-    const requestQueue = updateArray.map((item) => {
-      let value = String(inputs[item.key]);
-      return API.put('/api/option/', {
-        key: item.key,
-        value,
-      });
-    });
-    setLoading(true);
-    Promise.all(requestQueue)
-      .then((res) => {
-        if (requestQueue.length === 1) {
-          if (res.includes(undefined)) return;
-        } else if (requestQueue.length > 1) {
-          if (res.includes(undefined))
-            return showError(t('部分保存失败,请重试'));
-        }
-        showSuccess(t('保存成功'));
-        props.refresh();
+  async function onSubmit() {
+    await refForm.current
+      .validate()
+      .then(() => {
+        const updateArray = compareObjects(inputs, inputsRow);
+        if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
+        const requestQueue = updateArray.map((item) => {
+          let value = String(inputs[item.key]);
+          return API.put('/api/option/', {
+            key: item.key,
+            value,
+          });
+        });
+        setLoading(true);
+        Promise.all(requestQueue)
+          .then((res) => {
+            if (requestQueue.length === 1) {
+              if (res.includes(undefined)) return;
+            } else if (requestQueue.length > 1) {
+              if (res.includes(undefined))
+                return showError(t('部分保存失败,请重试'));
+            }
+            showSuccess(t('保存成功'));
+            props.refresh();
+          })
+          .catch(() => {
+            showError(t('保存失败,请重试'));
+          })
+          .finally(() => {
+            setLoading(false);
+          });
       })
-      .catch(() => {
-        showError(t('保存失败,请重试'));
-      })
-      .finally(() => {
-        setLoading(false);
+      .catch((error) => {
+        console.error('Validation failed:', error);
+        showError(t('请检查输入'));
       });
   }
@@ -146,6 +154,14 @@ export default function SettingGeminiModel(props) {
               label={t('支持的图像模型')}
               placeholder={t('例如:') + '\n' + JSON.stringify(['gemini-2.0-flash-exp-image-generation'], null, 2)}
               onChange={(value) => setInputs({ ...inputs, 'gemini.supported_imagine_models': value })}
+              trigger='blur'
+              stopValidateWithError
+              rules={[
+                {
+                  validator: (rule, value) => verifyJSON(value),
+                  message: t('不是合法的 JSON 字符串'),
+                },
+              ]}
             />

From 738a9a455853163f6e1e705fe1fb4c2ca5aa094f Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Mon, 26 May 2025 13:34:41 +0800
Subject: [PATCH 082/105] gemini text generation

---
 controller/relay.go                         |   2 +
 dto/gemini.go                               |  69 ++++++++++
 middleware/auth.go                          |  12 +-
 middleware/distributor.go                   |  36 +++++
 relay/channel/gemini/adaptor.go             |   6 +
 relay/channel/gemini/relay-gemini-native.go |  77 +++++++++++
 relay/constant/relay_mode.go                |   4 +
 relay/relay-gemini.go                       | 141 ++++++++++++++++++++
 router/relay-router.go                      |   8 ++
 9 files changed, 353 insertions(+), 2 deletions(-)
 create mode 100644 dto/gemini.go
 create mode 100644 relay/channel/gemini/relay-gemini-native.go
 create mode 100644 relay/relay-gemini.go

diff --git a/controller/relay.go b/controller/relay.go
index 41cb22a5..1a875dbc 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 		err = relay.EmbeddingHelper(c)
 	case relayconstant.RelayModeResponses:
 		err = relay.ResponsesHelper(c)
+	case relayconstant.RelayModeGemini:
+		err = relay.GeminiHelper(c)
 	default:
 		err = relay.TextHelper(c)
 	}
diff --git a/dto/gemini.go b/dto/gemini.go
new file mode 100644
index 00000000..898c966f
--- /dev/null
+++ b/dto/gemini.go
@@ -0,0 +1,69 @@
+package dto
+
+import "encoding/json"
+
+type GeminiPart struct {
+	Text string `json:"text"`
+}
+
+type GeminiContent struct {
+	Parts []GeminiPart `json:"parts"`
+	Role  string       `json:"role"`
+}
+
+type GeminiCandidate struct {
+	Content      GeminiContent `json:"content"`
+	FinishReason string        `json:"finishReason"`
+	AvgLogprobs  float64       `json:"avgLogprobs"`
+}
+
+type GeminiTokenDetails struct {
+	Modality   string `json:"modality"`
+	TokenCount int    `json:"tokenCount"`
+}
+
+type GeminiUsageMetadata struct {
+	PromptTokenCount        int                  `json:"promptTokenCount"`
+	CandidatesTokenCount    int                  `json:"candidatesTokenCount"`
+	TotalTokenCount         int                  `json:"totalTokenCount"`
+	PromptTokensDetails     []GeminiTokenDetails `json:"promptTokensDetails"`
+	CandidatesTokensDetails []GeminiTokenDetails `json:"candidatesTokensDetails"`
+}
+
+type GeminiTextGenerationResponse struct {
+	Candidates    []GeminiCandidate   `json:"candidates"`
+	UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
+	ModelVersion  string              `json:"modelVersion"`
+	ResponseID    string              `json:"responseId"`
+}
+
+type GeminiGenerationConfig struct {
+	StopSequences              []string         `json:"stopSequences,omitempty"`
+	ResponseMimeType           string           `json:"responseMimeType,omitempty"`
+	ResponseSchema             *json.RawMessage `json:"responseSchema,omitempty"`
+	ResponseModalities         *json.RawMessage `json:"responseModalities,omitempty"`
+	CandidateCount             int              `json:"candidateCount,omitempty"`
+	MaxOutputTokens            int              `json:"maxOutputTokens,omitempty"`
+	Temperature                float64          `json:"temperature,omitempty"`
+	TopP                       float64          `json:"topP,omitempty"`
+	TopK                       int              `json:"topK,omitempty"`
+	Seed                       int              `json:"seed,omitempty"`
+	PresencePenalty            float64          `json:"presencePenalty,omitempty"`
+	FrequencyPenalty           float64          `json:"frequencyPenalty,omitempty"`
+	ResponseLogprobs           bool             `json:"responseLogprobs,omitempty"`
+	LogProbs                   int              `json:"logProbs,omitempty"`
+	EnableEnhancedCivicAnswers bool             `json:"enableEnhancedCivicAnswers,omitempty"`
+	SpeechConfig               *json.RawMessage `json:"speechConfig,omitempty"`
+	ThinkingConfig             *json.RawMessage `json:"thinkingConfig,omitempty"`
+	MediaResolution            *json.RawMessage `json:"mediaResolution,omitempty"`
+}
+
+type GeminiTextGenerationRequest struct {
+	Contents          []GeminiContent        `json:"contents"`
+	Tools             *json.RawMessage       `json:"tools,omitempty"`
+	ToolConfig        *json.RawMessage       `json:"toolConfig,omitempty"`
+	SafetySettings    *json.RawMessage       `json:"safetySettings,omitempty"`
+	SystemInstruction *json.RawMessage       `json:"systemInstruction,omitempty"`
+	GenerationConfig  GeminiGenerationConfig `json:"generationConfig,omitempty"`
+	CachedContent     *json.RawMessage       `json:"cachedContent,omitempty"`
+}
diff --git a/middleware/auth.go b/middleware/auth.go
index fece4553..ce86bb36 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -1,13 +1,14 @@
 package middleware
 
 import (
-	"github.com/gin-contrib/sessions"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strconv"
 	"strings"
+
+	"github.com/gin-contrib/sessions"
+	"github.com/gin-gonic/gin"
 )
 
 func validUserInfo(username string, role int) bool {
@@ -182,6 +183,13 @@ func TokenAuth() func(c *gin.Context) {
 			c.Request.Header.Set("Authorization", "Bearer "+key)
 		}
 	}
+	// gemini api 从query中获取key
+	if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+		skKey := c.Query("key")
+		if skKey != "" {
+			c.Request.Header.Set("Authorization", "Bearer "+skKey)
+		}
+	}
 	key := c.Request.Header.Get("Authorization")
 	parts := make([]string, 0)
 	key = strings.TrimPrefix(key, "Bearer ")
diff --git a/middleware/distributor.go b/middleware/distributor.go
index e7db6d77..1bfe1821 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -162,6 +162,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+		// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
+		relayMode := relayconstant.RelayModeGemini
+		modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
+		if modelName != "" {
+			modelRequest.Model = modelName
+		}
+		c.Set("relay_mode", relayMode)
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") &&
 		!strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
 		err = common.UnmarshalBodyReusable(c, &modelRequest)
 	}
@@ -244,3 +252,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		c.Set("bot_id", channel.Other)
 	}
 }
+
+// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
+// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent
+// 输出: gemini-2.0-flash
+func extractModelNameFromGeminiPath(path string) string {
+	// 查找 "/models/" 的位置
+	modelsPrefix := "/models/"
+	modelsIndex := strings.Index(path, modelsPrefix)
+	if modelsIndex == -1 {
+		return ""
+	}
+
+	// 从 "/models/" 之后开始提取
+	startIndex := modelsIndex + len(modelsPrefix)
+	if startIndex >= len(path) {
+		return ""
+	}
+
+	// 查找 ":" 的位置,模型名在 ":" 之前
+	colonIndex := strings.Index(path[startIndex:], ":")
+	if colonIndex == -1 {
+		// 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分
+		return path[startIndex:]
+	}
+
+	// 返回模型名部分
+	return path[startIndex : startIndex+colonIndex]
+}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index c3c7b49d..12833736 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -10,6 +10,7 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/service"
 	"one-api/setting/model_setting"
 	"strings"
@@ -165,6 +166,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.RelayMode == constant.RelayModeGemini {
+		err, usage = GeminiTextGenerationHandler(c, resp, info)
+		return usage, err
+	}
+
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
 		return GeminiImageHandler(c, resp, info)
 	}
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
new file mode 100644
index 00000000..16374ea4
--- /dev/null
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -0,0 +1,77 @@
+package gemini
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	// 读取响应体
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if common.DebugEnabled {
+		println(string(responseBody))
+	}
+
+	// 解析为 Gemini 原生响应格式
+	var geminiResponse dto.GeminiTextGenerationResponse
+	err = common.DecodeJson(responseBody, &geminiResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	// 检查是否有候选响应
+	if len(geminiResponse.Candidates) == 0 {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: "No candidates returned",
+				Type:    "server_error",
+				Param:   "",
+				Code:    500,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	// 计算使用量(基于 UsageMetadata)
+	usage := dto.Usage{
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
+	}
+
+	// 设置模型版本
+	if geminiResponse.ModelVersion == "" {
+		geminiResponse.ModelVersion = info.UpstreamModelName
+	}
+
+	// 直接返回 Gemini 原生格式的 JSON 响应
+	jsonResponse, err := json.Marshal(geminiResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	// 设置响应头并写入响应
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
+	}
+
+	return nil, &usage
+}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 4454e815..f22a20bd 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -43,6 +43,8 @@ const (
 	RelayModeResponses
 
 	RelayModeRealtime
+
+	RelayModeGemini
 )
 
 func Path2RelayMode(path string) int {
@@ -75,6 +77,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeRerank
 	} else if strings.HasPrefix(path, "/v1/realtime") {
 		relayMode = RelayModeRealtime
+	} else if strings.HasPrefix(path, "/v1beta/models") {
+		relayMode = RelayModeGemini
 	}
 	return relayMode
 }
diff --git a/relay/relay-gemini.go b/relay/relay-gemini.go
new file mode 100644
index 00000000..9aa072e1
--- /dev/null
+++ b/relay/relay-gemini.go
@@ -0,0 +1,141 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+	"one-api/setting"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) {
+	request := &dto.GeminiTextGenerationRequest{}
+	err := common.UnmarshalBodyReusable(c, request)
+	if err != nil {
+		return nil, err
+	}
+	if len(request.Contents) == 0 {
+		return nil, errors.New("contents is required")
+	}
+	return request, nil
+}
+
+func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) {
+	var inputTexts []string
+	for _, content := range textRequest.Contents {
+		for _, part := range content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+		}
+	}
+	if len(inputTexts) == 0 {
+		return nil, nil
+	}
+
+	sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
+	return sensitiveWords, err
+}
+
+func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) {
+	// 计算输入 token 数量
+	var inputTexts []string
+	for _, content := range req.Contents {
+		for _, part := range content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+		}
+	}
+
+	inputText := strings.Join(inputTexts, "\n")
+	inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
+	info.PromptTokens = inputTokens
+	return inputTokens, err
+}
+
+func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
+	req, err := getAndValidateGeminiRequest(c)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
+		return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
+	}
+
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	if setting.ShouldCheckPromptSensitive() {
+		sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo)
+		if err != nil {
+			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
+			return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
+		}
+	}
+
+	// model mapped 模型映射
+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
+	}
+
+	if value, exists := c.Get("prompt_tokens"); exists {
+		promptTokens := value.(int)
+		relayInfo.SetPromptTokens(promptTokens)
+	} else {
+		promptTokens, err := getGeminiInputTokens(req, relayInfo)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
+		}
+		c.Set("prompt_tokens", promptTokens)
+	}
+
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
+	}
+
+	// pre consume quota
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+
+	adaptor.Init(relayInfo)
+
+	requestBody, err := json.Marshal(req)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
+	}
+
+	resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
+	if err != nil {
+		common.LogError(c, "Do gemini request failed: "+err.Error())
+		return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	return nil
+}
diff --git a/router/relay-router.go b/router/relay-router.go
index 4cd84b41..1115a491 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -79,6 +79,14 @@ func SetRelayRouter(router *gin.Engine) {
 		relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
 	}
+
+	relayGeminiRouter := router.Group("/v1beta")
+	relayGeminiRouter.Use(middleware.TokenAuth())
+	relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
+	relayGeminiRouter.Use(middleware.Distribute())
+	{
+		// Gemini API 路径格式: /v1beta/models/{model_name}:{action}
+		relayGeminiRouter.POST("/models/*path", controller.Relay)
+	}
 }
 
 func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
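The path parsing added in middleware/distributor.go behaves as follows — a small illustrative check with expected outputs as comments, not a committed test:

    func main() {
        fmt.Println(extractModelNameFromGeminiPath("/v1beta/models/gemini-2.0-flash:generateContent")) // gemini-2.0-flash
        fmt.Println(extractModelNameFromGeminiPath("/v1beta/models/gemini-2.0-flash"))                 // gemini-2.0-flash (no ":" action suffix)
        fmt.Println(extractModelNameFromGeminiPath("/v1beta/chat"))                                    // "" (no "/models/" segment)
    }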
From d90e4bef63ac262bc3190002bab90180f69acdef Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Mon, 26 May 2025 14:50:50 +0800
Subject: [PATCH 083/105] gemini stream

---
 dto/gemini.go                               | 69 ------------------
 relay/channel/gemini/adaptor.go             |  7 +-
 relay/channel/gemini/relay-gemini-native.go | 81 +++++++++++++++++----
 relay/relay-gemini.go                       | 28 +++++--
 4 files changed, 93 insertions(+), 92 deletions(-)
 delete mode 100644 dto/gemini.go

diff --git a/dto/gemini.go b/dto/gemini.go
deleted file mode 100644
index 898c966f..00000000
--- a/dto/gemini.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package dto
-
-import "encoding/json"
-
-type GeminiPart struct {
-	Text string `json:"text"`
-}
-
-type GeminiContent struct {
-	Parts []GeminiPart `json:"parts"`
-	Role  string       `json:"role"`
-}
-
-type GeminiCandidate struct {
-	Content      GeminiContent `json:"content"`
-	FinishReason string        `json:"finishReason"`
-	AvgLogprobs  float64       `json:"avgLogprobs"`
-}
-
-type GeminiTokenDetails struct {
-	Modality   string `json:"modality"`
-	TokenCount int    `json:"tokenCount"`
-}
-
-type GeminiUsageMetadata struct {
-	PromptTokenCount        int                  `json:"promptTokenCount"`
-	CandidatesTokenCount    int                  `json:"candidatesTokenCount"`
-	TotalTokenCount         int                  `json:"totalTokenCount"`
-	PromptTokensDetails     []GeminiTokenDetails `json:"promptTokensDetails"`
-	CandidatesTokensDetails []GeminiTokenDetails `json:"candidatesTokensDetails"`
-}
-
-type GeminiTextGenerationResponse struct {
-	Candidates    []GeminiCandidate   `json:"candidates"`
-	UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
-	ModelVersion  string              `json:"modelVersion"`
-	ResponseID    string              `json:"responseId"`
-}
-
-type GeminiGenerationConfig struct {
-	StopSequences              []string         `json:"stopSequences,omitempty"`
-	ResponseMimeType           string           `json:"responseMimeType,omitempty"`
-	ResponseSchema             *json.RawMessage `json:"responseSchema,omitempty"`
-	ResponseModalities         *json.RawMessage `json:"responseModalities,omitempty"`
-	CandidateCount             int              `json:"candidateCount,omitempty"`
-	MaxOutputTokens            int              `json:"maxOutputTokens,omitempty"`
-	Temperature                float64          `json:"temperature,omitempty"`
-	TopP                       float64          `json:"topP,omitempty"`
-	TopK                       int              `json:"topK,omitempty"`
-	Seed                       int              `json:"seed,omitempty"`
-	PresencePenalty            float64          `json:"presencePenalty,omitempty"`
-	FrequencyPenalty           float64          `json:"frequencyPenalty,omitempty"`
-	ResponseLogprobs           bool             `json:"responseLogprobs,omitempty"`
-	LogProbs                   int              `json:"logProbs,omitempty"`
-	EnableEnhancedCivicAnswers bool             `json:"enableEnhancedCivicAnswers,omitempty"`
-	SpeechConfig               *json.RawMessage `json:"speechConfig,omitempty"`
-	ThinkingConfig             *json.RawMessage `json:"thinkingConfig,omitempty"`
-	MediaResolution            *json.RawMessage `json:"mediaResolution,omitempty"`
-}
-
-type GeminiTextGenerationRequest struct {
-	Contents          []GeminiContent        `json:"contents"`
-	Tools             *json.RawMessage       `json:"tools,omitempty"`
-	ToolConfig        *json.RawMessage       `json:"toolConfig,omitempty"`
-	SafetySettings    *json.RawMessage       `json:"safetySettings,omitempty"`
-	SystemInstruction *json.RawMessage       `json:"systemInstruction,omitempty"`
-	GenerationConfig  GeminiGenerationConfig `json:"generationConfig,omitempty"`
-	CachedContent     *json.RawMessage       `json:"cachedContent,omitempty"`
-}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 12833736..e6f66d5f 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -167,8 +167,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeGemini {
-		err, usage = GeminiTextGenerationHandler(c, resp, info)
-		return usage, err
+		if info.IsStream {
+			return GeminiTextGenerationStreamHandler(c, resp, info)
+		} else {
+			return GeminiTextGenerationHandler(c, resp, info)
+		}
 	}
 
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index 16374ea4..c055e299 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -7,20 +7,21 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 
 	"github.com/gin-gonic/gin"
 )
 
-func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
 	// 读取响应体
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
 
 	if common.DebugEnabled {
@@ -28,15 +29,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	// 解析为 Gemini 原生响应格式
-	var geminiResponse dto.GeminiTextGenerationResponse
+	var geminiResponse GeminiChatResponse
 	err = common.DecodeJson(responseBody, &geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
 
 	// 检查是否有候选响应
 	if len(geminiResponse.Candidates) == 0 {
-		return &dto.OpenAIErrorWithStatusCode{
+		return nil, &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
 				Message: "No candidates returned",
 				Type:    "server_error",
@@ -44,7 +45,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 				Code:    500,
 			},
 			StatusCode: resp.StatusCode,
-		}, nil
+		}
 	}
 
 	// 计算使用量(基于 UsageMetadata)
@@ -54,15 +55,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
 	}
 
-	// 设置模型版本
-	if geminiResponse.ModelVersion == "" {
-		geminiResponse.ModelVersion = info.UpstreamModelName
-	}
-
 	// 直接返回 Gemini 原生格式的 JSON 响应
 	jsonResponse, err := json.Marshal(geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 	}
 
 	// 设置响应头并写入响应
@@ -70,8 +66,63 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = c.Writer.Write(jsonResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
 	}
 
-	return nil, &usage
+	return &usage, nil
+}
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	var usage = &dto.Usage{}
+	var imageCount int
+
+	helper.SetEventStreamHeaders(c)
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var geminiResponse GeminiChatResponse
+		err := common.DecodeJsonStr(data, &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			return false
+		}
+
+		// 统计图片数量
+		for _, candidate := range geminiResponse.Candidates {
+			for _, part := range candidate.Content.Parts {
+				if part.InlineData != nil && part.InlineData.MimeType != "" {
+					imageCount++
+				}
+			}
+		}
+
+		// 更新使用量统计
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+		}
+
+		// 直接发送 GeminiChatResponse 响应
+		err = helper.ObjectData(c, geminiResponse)
+		if err != nil {
+			common.LogError(c, err.Error())
+		}
+
+		return true
+	})
+
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
+	// 计算最终使用量
+	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+
+	// 结束流式响应
+	helper.Done(c)
+
+	return usage, nil
+}
diff --git a/relay/relay-gemini.go b/relay/relay-gemini.go
index 9aa072e1..93a2b7aa 100644
--- a/relay/relay-gemini.go
+++ b/relay/relay-gemini.go
@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/relay/channel/gemini"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
@@ -17,8 +18,8 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) {
-	request := &dto.GeminiTextGenerationRequest{}
+func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
+	request := &gemini.GeminiChatRequest{}
 	err := common.UnmarshalBodyReusable(c, request)
 	if err != nil {
 		return nil, err
@@ -29,7 +30,19 @@ func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationReque
 	return request, nil
 }
 
-func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) {
+// 流模式
+// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
+func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+	if c.Query("alt") == "sse" {
+		relayInfo.IsStream = true
+	}
+
+	// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
+	// 	relayInfo.IsStream = true
+	// }
+}
+
+func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
 	var inputTexts []string
 	for _, content := range textRequest.Contents {
 		for _, part := range content.Parts {
@@ -46,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, inf
 	return sensitiveWords, err
 }
 
-func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) {
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
 	// 计算输入 token 数量
 	var inputTexts []string
 	for _, content := range req.Contents {
@@ -72,8 +85,11 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 	relayInfo := relaycommon.GenRelayInfo(c)
 
+	// 检查 Gemini 流式模式
+	checkGeminiStreamMode(c, relayInfo)
+
 	if setting.ShouldCheckPromptSensitive() {
-		sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo)
+		sensitiveWords, err := checkGeminiInputSensitive(req)
 		if err != nil {
 			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
 			return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
@@ -97,7 +113,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		c.Set("prompt_tokens", promptTokens)
 	}
 
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens)
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
 	}

From 156ad5c3fdc1c547a7b5905a4325bfb8a19cc869 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Mon, 26 May 2025 15:02:20 +0800
Subject: [PATCH 084/105] vertex

---
 relay/channel/vertex/adaptor.go | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index d21a3e08..e58ea762 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -12,6 +12,7 @@ import (
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/setting/model_setting"
 	"strings"
 
@@ -192,7 +193,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case RequestModeClaude:
 			err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
 		case RequestModeGemini:
-			err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+			if info.RelayMode == constant.RelayModeGemini {
+				usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
+			} else {
+				err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+			}
 		case RequestModeLlama:
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 		}
@@ -201,7 +206,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case RequestModeClaude:
 			err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
 		case RequestModeGemini:
-			err, usage = gemini.GeminiChatHandler(c, resp, info)
+			if info.RelayMode == constant.RelayModeGemini {
+				usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
+			} else {
+				err, usage = gemini.GeminiChatHandler(c, resp, info)
+			}
 		case RequestModeLlama:
 			err, usage = openai.OpenaiHandler(c, resp, info)
 		}

From 368fd75c86aa775a52fcfd6270144213b0c02f2c Mon Sep 17 00:00:00 2001
From: skynono <6811626@qq.com>
Date: Mon, 26 May 2025 17:11:45 +0800
Subject: [PATCH 085/105] fix: ali parameter.enable_thinking must be set to
 false for non-streaming calls

---
 relay/channel/ali/adaptor.go | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index ab632d22..31e926d6 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -57,6 +57,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+
+	// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
+	if !info.IsStream {
+		request.EnableThinking = false
+	}
+
 	switch info.RelayMode {
 	default:
 		aliReq := requestOpenAI2Ali(*request)

From 30d5a11f466d0ec95d86028bb6455b51e39f7be4 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Mon, 26 May 2025 18:53:41 +0800
Subject: [PATCH 086/105] fix: search-preview model web search billing

---
 dto/openai_request.go |  6 ++++++
 relay/relay-text.go   | 43 +++++++++++++++++++++++++++++++++++++++----
 2 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/dto/openai_request.go b/dto/openai_request.go
index e8833b3d..e491812a 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -53,6 +53,7 @@ type GeneralOpenAIRequest struct {
 	Audio            any               `json:"audio,omitempty"`
 	EnableThinking   any               `json:"enable_thinking,omitempty"` // ali
 	ExtraBody        any               `json:"extra_body,omitempty"`
+	WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
 }
 
 type ToolCallRequest struct {
@@ -371,6 +372,11 @@ func (m *Message) ParseContent() []MediaContent {
 	return contentList
 }
 
+type WebSearchOptions struct {
+	SearchContextSize string          `json:"search_context_size,omitempty"`
+	UserLocation      json.RawMessage `json:"user_location,omitempty"`
+}
+
 type OpenAIResponsesRequest struct {
 	Model string          `json:"model"`
 	Input json.RawMessage `json:"input,omitempty"`
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 8d5cd384..f1105907 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -47,6 +47,20 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
 	if textRequest.Model == "" {
 		return nil, errors.New("model is required")
 	}
+	if textRequest.WebSearchOptions != nil {
+		if textRequest.WebSearchOptions.SearchContextSize != "" {
+			validSizes := map[string]bool{
+				"high":   true,
+				"medium": true,
+				"low":    true,
+			}
+			if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
+				return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
+			}
+		} else {
+			textRequest.WebSearchOptions.SearchContextSize = "medium"
+		}
+	}
 	switch relayInfo.RelayMode {
 	case relayconstant.RelayModeCompletions:
 		if textRequest.Prompt == "" {
@@ -76,6 +90,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 	// get & validate textRequest 获取并验证文本请求
 	textRequest, err := getAndValidateTextRequest(c, relayInfo)
+	if textRequest.WebSearchOptions != nil {
+		c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
+	}
+
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
 		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
@@ -370,9 +388,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 			dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
 				Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
 				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-			extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 $%s",
+			extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
 				webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
 		}
+	} else if strings.HasSuffix(modelName, "search-preview") {
+		// search-preview 模型不支持 response api
+		searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
+		if searchContextSize == "" {
+			searchContextSize = "medium"
+		}
+		webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
+		dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
+			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+		extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
+			searchContextSize, dWebSearchQuota.String())
 	}
 	// file search tool 计费
 	var dFileSearchQuota decimal.Decimal
@@ -463,10 +492,16 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		other["image_ratio"] = imageRatio
 		other["image_output"] = imageTokens
 	}
-	if !dWebSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
-		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+	if !dWebSearchQuota.IsZero() {
+		if relayInfo.ResponsesUsageInfo != nil {
+			if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+				other["web_search"] = true
+				other["web_search_call_count"] = webSearchTool.CallCount
+				other["web_search_price"] = webSearchPrice
+			}
+		} else if strings.HasSuffix(modelName, "search-preview") {
 			other["web_search"] = true
-			other["web_search_call_count"] = webSearchTool.CallCount
+			other["web_search_call_count"] = 1
 			other["web_search_price"] = webSearchPrice
 		}
 	}
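Worked example of the new search-preview surcharge, with assumed inputs (a $35-per-thousand-calls price, group ratio 1, and 500000 quota per unit — none of these figures are committed defaults):

    price := decimal.NewFromFloat(35.0)          // per 1000 calls (assumed)
    quota := price.Div(decimal.NewFromInt(1000)). // one call
        Mul(decimal.NewFromInt(1)).               // group ratio (assumed)
        Mul(decimal.NewFromInt(500000))           // quota per unit (assumed)
    fmt.Println(quota.String())                   // "17500" quota for a single search call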
From 76824a033769552c10ad352cb240f49e78dd87bb Mon Sep 17 00:00:00 2001
From: "wang.rong"
Date: Tue, 27 May 2025 09:32:20 +0800
Subject: =?UTF-8?q?chat/completion=E9=80=8F=E4=BC=A0parall?=
 =?UTF-8?q?el=5Ftool=5Fcalls=E5=8F=82=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 dto/openai_request.go | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dto/openai_request.go b/dto/openai_request.go
index e491812a..bda1bb17 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -43,6 +43,7 @@ type GeneralOpenAIRequest struct {
 	ResponseFormat   *ResponseFormat   `json:"response_format,omitempty"`
 	EncodingFormat   any               `json:"encoding_format,omitempty"`
 	Seed             float64           `json:"seed,omitempty"`
+	ParallelTooCalls *bool             `json:"parallel_tool_calls,omitempty"`
 	Tools            []ToolCallRequest `json:"tools,omitempty"`
 	ToolChoice       any               `json:"tool_choice,omitempty"`
 	User             string            `json:"user,omitempty"`

From 96ab4177ca671506432b0dca9331280e078d0f18 Mon Sep 17 00:00:00 2001
From: skynono <6811626@qq.com>
Date: Mon, 26 May 2025 17:22:13 +0800
Subject: [PATCH 088/105] fix: ali FetchUpstreamModels url

---
 controller/channel.go | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/controller/channel.go b/controller/channel.go
index ad85fe24..a31e1f47 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -119,8 +119,11 @@ func FetchUpstreamModels(c *gin.Context) {
 		baseURL = channel.GetBaseURL()
 	}
 	url := fmt.Sprintf("%s/v1/models", baseURL)
-	if channel.Type == common.ChannelTypeGemini {
+	switch channel.Type {
+	case common.ChannelTypeGemini:
 		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
+	case common.ChannelTypeAli:
+		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
 	}
 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	if err != nil {

From 6e72dcd0ba00781bd14c4830c06993c524dad319 Mon Sep 17 00:00:00 2001
From: tbphp
Date: Tue, 27 May 2025 21:50:53 +0800
Subject: [PATCH 089/105] fix: Vertex channel global region format

---
 relay/channel/vertex/adaptor.go | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index d21a3e08..b75136bf 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -95,14 +95,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		} else {
 			suffix = "generateContent"
 		}
-		return fmt.Sprintf(
-			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
-			region,
-			adc.ProjectID,
-			region,
-			info.UpstreamModelName,
-			suffix,
-		), nil
+		if region == "global" {
+			return fmt.Sprintf(
+				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
+				adc.ProjectID,
+				info.UpstreamModelName,
+				suffix,
+			), nil
+		} else {
+			return fmt.Sprintf(
+				"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+				region,
+				adc.ProjectID,
+				region,
+				info.UpstreamModelName,
+				suffix,
+			), nil
+		}
 	} else if a.RequestMode == RequestModeClaude {
 		if info.IsStream {
 			suffix = "streamRawPredict?alt=sse"
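For reference, the two endpoint shapes this fix produces (project and model names are placeholders):

    https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-2.5-pro:generateContent
    https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.5-pro:generateContent

The global endpoint carries no region prefix on the host, which is why a single format string could not cover both cases.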
From e3d7b31a49743cfa99a17cfafc7b2de540d2eb3f Mon Sep 17 00:00:00 2001
From: IcedTangerine
Date: Wed, 28 May 2025 14:25:24 +0800
Subject: [PATCH 090/105] Update openai_request.go

---
 dto/openai_request.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dto/openai_request.go b/dto/openai_request.go
index bda1bb17..78706f9c 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -43,7 +43,7 @@ type GeneralOpenAIRequest struct {
 	ResponseFormat   *ResponseFormat   `json:"response_format,omitempty"`
 	EncodingFormat   any               `json:"encoding_format,omitempty"`
 	Seed             float64           `json:"seed,omitempty"`
-	ParallelTooCalls *bool             `json:"parallel_tool_calls,omitempty"`
+	ParallelTooCalls bool              `json:"parallel_tool_calls,omitempty"`
 	Tools            []ToolCallRequest `json:"tools,omitempty"`
 	ToolChoice       any               `json:"tool_choice,omitempty"`
 	User             string            `json:"user,omitempty"`

From f613a79f3e234a23af552fc49980e66b91edc7ff Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 28 May 2025 20:18:37 +0800
Subject: [PATCH 091/105] feat: Enhance image request validation in
 relay-image.go: set default model and size, improve error handling for size
 format, and ensure prompt and N parameters are validated correctly.

---
 relay/relay-image.go | 96 ++++++++++++++------------------------------
 1 file changed, 31 insertions(+), 65 deletions(-)

diff --git a/relay/relay-image.go b/relay/relay-image.go
index daed3d80..36b4b9f8 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -46,11 +46,23 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 	if err != nil {
 		return nil, err
 	}
+
+	if imageRequest.Model == "" {
+		imageRequest.Model = "dall-e-3"
+	}
+
+	if strings.Contains(imageRequest.Size, "×") {
+		return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
+	}
+
 	// Not "256x256", "512x512", or "1024x1024"
 	if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
 		if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
 			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
 		}
+		if imageRequest.Size == "" {
+			imageRequest.Size = "1024x1024"
+		}
 	} else if imageRequest.Model == "dall-e-3" {
 		if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
 			return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
@@ -58,74 +70,24 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 		if imageRequest.Quality == "" {
 			imageRequest.Quality = "standard"
 		}
-		// N should between 1 and 10
-		//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
-		//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
-		//}
+		if imageRequest.Size == "" {
+			imageRequest.Size = "1024x1024"
+		}
+	} else if imageRequest.Model == "gpt-image-1" {
+		if imageRequest.Quality == "" {
+			imageRequest.Quality = "auto"
+		}
+	}
+
+	if imageRequest.Prompt == "" {
+		return nil, errors.New("prompt is required")
+	}
+
+	if imageRequest.N == 0 {
+		imageRequest.N = 1
+	}
 	}
-	if imageRequest.Prompt == "" {
-		return nil, errors.New("prompt is required")
-	}
-
-	if imageRequest.Model == "" {
-		imageRequest.Model = "dall-e-2"
-	}
-	if strings.Contains(imageRequest.Size, "×") {
-		return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
-	}
-	if imageRequest.N == 0 {
-		imageRequest.N = 1
-	}
-	if imageRequest.Size == "" {
-		imageRequest.Size = "1024x1024"
-	}
-
-	err := common.UnmarshalBodyReusable(c, imageRequest)
-	if err != nil {
-		return nil, err
-	}
-	if imageRequest.Prompt == "" {
-		return nil, errors.New("prompt is required")
-	}
-	if strings.Contains(imageRequest.Size, "×") {
-		return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
-	}
-	if imageRequest.N == 0 {
-		imageRequest.N = 1
-	}
-	if imageRequest.Size == "" {
-		imageRequest.Size = "1024x1024"
-	}
-	if imageRequest.Model == "" {
-		imageRequest.Model = "dall-e-2"
-	}
-	// x.ai grok-2-image not support size, quality or style
-	if imageRequest.Size == "empty" {
-		imageRequest.Size = ""
-	}
-
-	// Not "256x256", "512x512", or "1024x1024"
-	if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
-		if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
-			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
-		}
-	} else if imageRequest.Model == "dall-e-3" {
-		if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
-			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
-		}
-		if imageRequest.Quality == "" {
-			imageRequest.Quality = "standard"
-		}
-		//if imageRequest.N != 1 {
-		//	return nil, errors.New("n must be 1")
-		//}
-	}
-	// N should between 1 and 10
-	//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
-	//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
-	//}
 	if setting.ShouldCheckPromptSensitive() {
 		words, err := service.CheckSensitiveInput(imageRequest.Prompt)
 		if err != nil {
@@ -229,6 +191,10 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		requestBody = bytes.NewBuffer(jsonData)
 	}
 
+	if common.DebugEnabled {
+		println(fmt.Sprintf("image request body: %s", requestBody))
+	}
+
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)

From e01b5178439a6920293b98d15b71e24d9789afc1 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Wed, 28 May 2025 21:12:55 +0800
Subject: [PATCH 092/105] fix: Change ParallelTooCalls from bool to *bool in
 GeneralOpenAIRequest for optional handling

---
 dto/openai_request.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dto/openai_request.go b/dto/openai_request.go
index 78706f9c..bda1bb17 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -43,7 +43,7 @@ type GeneralOpenAIRequest struct {
 	ResponseFormat   *ResponseFormat   `json:"response_format,omitempty"`
 	EncodingFormat   any               `json:"encoding_format,omitempty"`
 	Seed             float64           `json:"seed,omitempty"`
-	ParallelTooCalls bool              `json:"parallel_tool_calls,omitempty"`
+	ParallelTooCalls *bool             `json:"parallel_tool_calls,omitempty"`
 	Tools            []ToolCallRequest `json:"tools,omitempty"`
 	ToolChoice       any               `json:"tool_choice,omitempty"`
 	User             string            `json:"user,omitempty"`

From 361b0abec9f37b7019991536478fd1dcf9d075d9 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Wed, 28 May 2025 21:34:45 +0800
Subject: =?UTF-8?q?fix:=20pingerCtx=20=E6=B3=84=E6=BC=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 relay/channel/api_request.go | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index 03eff9cf..da8d4e14 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -122,11 +122,13 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	var pingerWg sync.WaitGroup
 	if info.IsStream {
 		helper.SetEventStreamHeaders(c)
-		pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
-		var pingerCtx context.Context
-		pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
 
 		if pingEnabled {
+			pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+			var pingerCtx context.Context
+			pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
+			// 退出时清理 pingerCtx 防止泄露
+			defer stopPinger()
 			pingerWg.Add(1)
 			gopool.Go(func() {
 				defer pingerWg.Done()
@@ -166,9 +168,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	}
 
 	resp, err := client.Do(req)
-	// request结束后停止ping
+	// request结束后等待 ping goroutine 完成
 	if info.IsStream && pingEnabled {
-		stopPinger()
 		pingerWg.Wait()
 	}
 	if err != nil {

From d608a6f12398f2a52617951685c6990877740def Mon Sep 17 00:00:00 2001
From: Akkuman
Date: Thu, 29 May 2025 10:56:01 +0800
Subject: [PATCH 094/105] feat: streaming response for tts

---
 relay/channel/openai/relay-openai.go | 37 ++++++++++------------------
 1 file changed, 13 insertions(+), 24 deletions(-)

diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 86c47a15..2e3d8df1 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -273,36 +273,25 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 }
 
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	// Reset response body
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the httpClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
+	// the status code has been judged before, if there is a body reading failure,
+	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
+	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
+	// if the upstream returns a specific status code, once the upstream has already written the header,
+	// the subsequent failure of the response body should be regarded as a non-recoverable error,
+	// and can be terminated directly.
+	defer resp.Body.Close()
+	usage := &dto.Usage{}
+	usage.PromptTokens = info.PromptTokens
+	usage.TotalTokens = info.PromptTokens
 	for k, v := range resp.Header {
 		c.Writer.Header().Set(k, v[0])
 	}
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
+	c.Writer.WriteHeaderNow()
+	_, err := io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+		common.LogError(c, err.Error())
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-
-	usage := &dto.Usage{}
-	usage.PromptTokens = info.PromptTokens
-	usage.TotalTokens = info.PromptTokens
 	return nil, usage
 }
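The rewritten handler streams audio bytes as they arrive instead of buffering the whole body. If per-chunk flushing were ever needed (io.Copy alone leaves flushing to the writer), a loop of roughly this shape would do it — an illustration only, not the committed behaviour:

    buf := make([]byte, 32*1024)
    for {
        n, rerr := resp.Body.Read(buf)
        if n > 0 {
            if _, werr := c.Writer.Write(buf[:n]); werr != nil {
                break
            }
            c.Writer.Flush() // gin.ResponseWriter implements http.Flusher
        }
        if rerr != nil { // io.EOF or a real error; either way the stream is done
            break
        }
    }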
From 1b64db55215bf3fb6e10d69b7da30126ed9d1f5a Mon Sep 17 00:00:00 2001
From: RedwindA <128586631+RedwindA@users.noreply.github.com>
Date: Thu, 29 May 2025 12:33:27 +0800
Subject: [PATCH 095/105] Add `ERROR_LOG_ENABLED` description

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index a807b07d..5d0014f9 100644
--- a/README.md
+++ b/README.md
@@ -110,6 +110,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
 - `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview`
 - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
+- `ERROR_LOG_ENABLED=true`: 是否记录并显示错误日志,默认`false`
 
 ## 部署

From f907c25b21137e8d7a94caa9a8450913e980b941 Mon Sep 17 00:00:00 2001
From: RedwindA <128586631+RedwindA@users.noreply.github.com>
Date: Thu, 29 May 2025 12:35:13 +0800
Subject: [PATCH 096/105] Add `ERROR_LOG_ENABLED` description

---
 README.en.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.en.md b/README.en.md
index 4709bc5b..ad11f386 100644
--- a/README.en.md
+++ b/README.en.md
@@ -110,6 +110,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env
 - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
 - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
 - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
+- `ERROR_LOG_ENABLED=true`: Whether to record and display error logs, default is `false`
 
 ## Deployment

From 1c4d7fd84b55519235cd88e48cf14cd383275281 Mon Sep 17 00:00:00 2001
From: xqx121 <78908927+xqx121@users.noreply.github.com>
Date: Sat, 31 May 2025 17:50:00 +0800
Subject: [PATCH 097/105] Fix: Gemini2.5pro ThinkingConfig

---
 relay/channel/gemini/relay-gemini.go | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index da0bc5fc..9ab167b1 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -39,15 +39,22 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 	}
 
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
-		if strings.HasSuffix(info.OriginModelName, "-thinking") {
-			budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
-			if budgetTokens == 0 || budgetTokens > 24576 {
-				budgetTokens = 24576
-			}
-			geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
-				ThinkingBudget:  common.GetPointer(int(budgetTokens)),
-				IncludeThoughts: true,
-			}
+		if strings.HasSuffix(info.OriginModelName, "-thinking") {
+			// 如果模型名以 gemini-2.5-pro 开头,不设置 ThinkingBudget
+			if strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") {
+				geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+					IncludeThoughts: true,
+				}
+			} else {
+				budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
+				if budgetTokens == 0 || budgetTokens > 24576 {
+					budgetTokens = 24576
+				}
+				geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+					ThinkingBudget:  common.GetPointer(int(budgetTokens)),
+					IncludeThoughts: true,
+				}
+			}
 		} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
 			geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
 				ThinkingBudget:  common.GetPointer(0),
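Worked example of the committed budget rule for a model other than gemini-2.5-pro, assuming the 0.6 percentage shown in the default settings earlier in this series and an illustrative MaxOutputTokens of 8192:

    budgetTokens := 0.6 * float64(8192) // 4915.2 -> ThinkingBudget of 4915 tokens
    if budgetTokens == 0 || budgetTokens > 24576 {
        budgetTokens = 24576 // also the fallback when MaxOutputTokens is unset (0)
    }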
da0bc5fc..9ab167b1 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -39,15 +39,22 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - if strings.HasSuffix(info.OriginModelName, "-thinking") { - budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) - if budgetTokens == 0 || budgetTokens > 24576 { - budgetTokens = 24576 - } - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ - ThinkingBudget: common.GetPointer(int(budgetTokens)), - IncludeThoughts: true, - } + if strings.HasSuffix(info.OriginModelName, "-thinking") { + // 如果模型名以 gemini-2.5-pro 开头,不设置 ThinkingBudget + if strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") { + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + IncludeThoughts: true, + } + } else { + budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) + if budgetTokens == 0 || budgetTokens > 24576 { + budgetTokens = 24576 + } + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + ThinkingBudget: common.GetPointer(int(budgetTokens)), + IncludeThoughts: true, + } + } } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ ThinkingBudget: common.GetPointer(0), From c51a30b862525aa4af9bfdd510cdd59ba301b9b5 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Sat, 31 May 2025 22:13:17 +0800 Subject: [PATCH 098/105] =?UTF-8?q?fix:=20=E6=B5=81=E5=BC=8F=E8=AF=B7?= =?UTF-8?q?=E6=B1=82ping?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/api_request.go | 114 ++++++++++++++++++++--------------- 1 file changed, 66 insertions(+), 48 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index da8d4e14..1d733bd4 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -104,6 +104,65 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody return targetConn, nil } +func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc { + pingerCtx, stopPinger := context.WithCancel(context.Background()) + + gopool.Go(func() { + defer func() { + if common2.DebugEnabled { + println("SSE ping goroutine stopped.") + } + }() + + if pingInterval <= 0 { + pingInterval = helper.DefaultPingInterval + } + + ticker := time.NewTicker(pingInterval) + // 退出时清理 ticker + defer ticker.Stop() + + var pingMutex sync.Mutex + if common2.DebugEnabled { + println("SSE ping goroutine started") + } + + for { + select { + // 发送 ping 数据 + case <-ticker.C: + if err := sendPingData(c, &pingMutex); err != nil { + return + } + // 收到退出信号 + case <-pingerCtx.Done(): + return + // request 结束 + case <-c.Request.Context().Done(): + return + } + } + }) + + return stopPinger +} + +func sendPingData(c *gin.Context, mutex *sync.Mutex) error { + mutex.Lock() + defer mutex.Unlock() + + err := helper.PingData(c) + if err != nil { + common2.LogError(c, "SSE ping error: "+err.Error()) + return err + } + + if common2.DebugEnabled { + println("SSE ping data sent.") + } + return nil +} + func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { var client *http.Client var 
From c51a30b862525aa4af9bfdd510cdd59ba301b9b5 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Sat, 31 May 2025 22:13:17 +0800
Subject: [PATCH 098/105] fix: streaming request ping
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 relay/channel/api_request.go | 114 ++++++++++++++++++++---------
 1 file changed, 66 insertions(+), 48 deletions(-)

diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index da8d4e14..1d733bd4 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -104,6 +104,65 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	return targetConn, nil
 }
 
+func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
+	pingerCtx, stopPinger := context.WithCancel(context.Background())
+
+	gopool.Go(func() {
+		defer func() {
+			if common2.DebugEnabled {
+				println("SSE ping goroutine stopped.")
+			}
+		}()
+
+		if pingInterval <= 0 {
+			pingInterval = helper.DefaultPingInterval
+		}
+
+		ticker := time.NewTicker(pingInterval)
+		// stop the ticker on exit
+		defer ticker.Stop()
+
+		var pingMutex sync.Mutex
+		if common2.DebugEnabled {
+			println("SSE ping goroutine started")
+		}
+
+		for {
+			select {
+			// send ping data
+			case <-ticker.C:
+				if err := sendPingData(c, &pingMutex); err != nil {
+					return
+				}
+			// stop signal received
+			case <-pingerCtx.Done():
+				return
+			// request finished
+			case <-c.Request.Context().Done():
+				return
+			}
+		}
+	})
+
+	return stopPinger
+}
+
+func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
+	mutex.Lock()
+	defer mutex.Unlock()
+
+	err := helper.PingData(c)
+	if err != nil {
+		common2.LogError(c, "SSE ping error: "+err.Error())
+		return err
+	}
+
+	if common2.DebugEnabled {
+		println("SSE ping data sent.")
+	}
+	return nil
+}
+
 func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error
@@ -115,69 +174,28 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	} else {
 		client = service.GetHttpClient()
 	}
-	// ping keep-alive for streaming requests
-	var stopPinger func()
-	generalSettings := operation_setting.GetGeneralSetting()
-	pingEnabled := generalSettings.PingIntervalEnabled
-	var pingerWg sync.WaitGroup
+
 	if info.IsStream {
 		helper.SetEventStreamHeaders(c)
-		if pingEnabled {
+		// handle ping keep-alive for streaming requests
+		generalSettings := operation_setting.GetGeneralSetting()
+		if generalSettings.PingIntervalEnabled {
 			pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
-			var pingerCtx context.Context
-			pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
-			// cancel pingerCtx on exit to avoid a leak
+			stopPinger := startPingKeepAlive(c, pingInterval)
 			defer stopPinger()
-			pingerWg.Add(1)
-			gopool.Go(func() {
-				defer pingerWg.Done()
-				if pingInterval <= 0 {
-					pingInterval = helper.DefaultPingInterval
-				}
-
-				ticker := time.NewTicker(pingInterval)
-				defer ticker.Stop()
-				var pingMutex sync.Mutex
-				if common2.DebugEnabled {
-					println("SSE ping goroutine started")
-				}
-
-				for {
-					select {
-					case <-ticker.C:
-						pingMutex.Lock()
-						err2 := helper.PingData(c)
-						pingMutex.Unlock()
-						if err2 != nil {
-							common2.LogError(c, "SSE ping error: "+err.Error())
-							return
-						}
-						if common2.DebugEnabled {
-							println("SSE ping data sent.")
-						}
-					case <-pingerCtx.Done():
-						if common2.DebugEnabled {
-							println("SSE ping goroutine stopped.")
-						}
-						return
-					}
-				}
-			})
 		}
 	}
 	resp, err := client.Do(req)
-	// wait for the ping goroutine to finish after the request ends
-	if info.IsStream && pingEnabled {
-		pingerWg.Wait()
-	}
+
 	if err != nil {
 		return nil, err
 	}
 	if resp == nil {
 		return nil, errors.New("resp is nil")
 	}
+
 	_ = req.Body.Close()
 	_ = c.Request.Body.Close()
 	return resp, nil
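helper.PingData itself is not shown in this patch. A plausible minimal implementation, an assumption rather than the project's actual helper, writes an SSE comment line and flushes it; clients ignore comment lines, so the connection stays warm without emitting a visible event, and because the write touches the shared response stream the goroutine above serializes it with a mutex:

package main

import (
	"fmt"
	"net/http"
)

// pingData is a hypothetical stand-in for helper.PingData: SSE clients skip
// lines starting with ':', so writing one keeps the connection alive.
func pingData(w http.ResponseWriter) error {
	if _, err := fmt.Fprint(w, ": PING\n\n"); err != nil {
		return err
	}
	if flusher, ok := w.(http.Flusher); ok {
		flusher.Flush() // push the comment to the client immediately
	}
	return nil
}

func main() {
	http.HandleFunc("/stream", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		_ = pingData(w)
	})
	// http.ListenAndServe(":8080", nil) // left commented: sketch only
}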
From 611d77e1a9f94a5ceacf8380d4f3513dac0fcaaf Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sun, 1 Jun 2025 01:10:10 +0800
Subject: [PATCH 099/105] feat: add ToMap method and enhance OpenAI request handling

---
 dto/openai_request.go             | 12 ++++++++++--
 relay/channel/baidu_v2/adaptor.go | 13 +++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/dto/openai_request.go b/dto/openai_request.go
index bda1bb17..16cdf3a2 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -2,6 +2,7 @@ package dto
 
 import (
 	"encoding/json"
+	"one-api/common"
 	"strings"
 )
 
@@ -57,6 +58,13 @@ type GeneralOpenAIRequest struct {
 	WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
 }
 
+func (r *GeneralOpenAIRequest) ToMap() map[string]any {
+	result := make(map[string]any)
+	data, _ := common.EncodeJson(r)
+	_ = common.DecodeJson(data, &result)
+	return result
+}
+
 type ToolCallRequest struct {
 	ID   string `json:"id,omitempty"`
 	Type string `json:"type"`
@@ -74,11 +82,11 @@ type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
 }
 
-func (r GeneralOpenAIRequest) GetMaxTokens() int {
+func (r *GeneralOpenAIRequest) GetMaxTokens() int {
 	return int(r.MaxTokens)
 }
 
-func (r GeneralOpenAIRequest) ParseInput() []string {
+func (r *GeneralOpenAIRequest) ParseInput() []string {
 	if r.Input == nil {
 		return nil
 	}
diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go
index 77afe2dd..2b8a52a2 100644
--- a/relay/channel/baidu_v2/adaptor.go
+++ b/relay/channel/baidu_v2/adaptor.go
@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -49,6 +50,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	if strings.HasSuffix(info.UpstreamModelName, "-search") {
+		info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
+		request.Model = info.UpstreamModelName
+		toMap := request.ToMap()
+		toMap["web_search"] = map[string]any{
+			"enable":          true,
+			"enable_citation": true,
+			"enable_trace":    true,
+			"enable_status":   false,
+		}
+		return toMap, nil
+	}
 	return request, nil
 }
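The ToMap round-trip above (encode the struct to JSON, decode into a map[string]any) is what lets the Baidu v2 adaptor attach fields that GeneralOpenAIRequest has no struct member for. A self-contained sketch of the same technique using only the standard library; the struct here is a simplified stand-in, not the project's dto:

package main

import (
	"encoding/json"
	"fmt"
)

// request is a simplified stand-in for dto.GeneralOpenAIRequest.
type request struct {
	Model  string `json:"model"`
	Stream bool   `json:"stream,omitempty"`
}

// toMap converts a struct to a generic map via a JSON round-trip, so extra
// vendor-specific keys can be injected before serialization.
func toMap(r *request) map[string]any {
	result := make(map[string]any)
	data, _ := json.Marshal(r)
	_ = json.Unmarshal(data, &result)
	return result
}

func main() {
	r := &request{Model: "ernie-4.0", Stream: true}
	m := toMap(r)
	// Inject a field that exists only in the upstream vendor's API.
	m["web_search"] = map[string]any{"enable": true}
	out, _ := json.Marshal(m)
	fmt.Println(string(out)) // {"model":"ernie-4.0","stream":true,"web_search":{"enable":true}}
}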
From f1ee9a301d04018861f38c21f9922dcb8e4eaefb Mon Sep 17 00:00:00 2001
From: RedwindA
Date: Fri, 23 May 2025 20:02:50 +0800
Subject: [PATCH 100/105] refactor: enhance cleanFunctionParameters for
 improved handling of JSON schema, including support for $defs and
 conditional keywords

---
 relay/channel/gemini/relay-gemini.go | 169 +++++++++++++++------------
 1 file changed, 93 insertions(+), 76 deletions(-)

diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 9ab167b1..c75745ad 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -297,94 +297,111 @@ func cleanFunctionParameters(params interface{}) interface{} {
 		return nil
 	}
 
-	paramMap, ok := params.(map[string]interface{})
-	if !ok {
-		// Not a map, return as is (e.g., could be an array or primitive)
-		return params
-	}
+	switch v := params.(type) {
+	case map[string]interface{}:
+		// Create a copy to avoid modifying the original
+		cleanedMap := make(map[string]interface{})
+		for k, val := range v {
+			cleanedMap[k] = val
+		}
 
-	// Create a copy to avoid modifying the original
-	cleanedMap := make(map[string]interface{})
-	for k, v := range paramMap {
-		cleanedMap[k] = v
-	}
+		// Remove unsupported root-level fields
+		delete(cleanedMap, "default")
+		delete(cleanedMap, "exclusiveMaximum")
+		delete(cleanedMap, "exclusiveMinimum")
+		delete(cleanedMap, "$schema")
+		delete(cleanedMap, "additionalProperties")
 
-	// Remove unsupported root-level fields
-	delete(cleanedMap, "default")
-	delete(cleanedMap, "exclusiveMaximum")
-	delete(cleanedMap, "exclusiveMinimum")
-	delete(cleanedMap, "$schema")
-	delete(cleanedMap, "additionalProperties")
-
-	// Clean properties
-	if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
-		cleanedProps := make(map[string]interface{})
-		for propName, propValue := range props {
-			propMap, ok := propValue.(map[string]interface{})
-			if !ok {
-				cleanedProps[propName] = propValue // Keep non-map properties
-				continue
-			}
-
-			// Create a copy of the property map
-			cleanedPropMap := make(map[string]interface{})
-			for k, v := range propMap {
-				cleanedPropMap[k] = v
-			}
-
-			// Remove unsupported fields
-			delete(cleanedPropMap, "default")
-			delete(cleanedPropMap, "exclusiveMaximum")
-			delete(cleanedPropMap, "exclusiveMinimum")
-			delete(cleanedPropMap, "$schema")
-			delete(cleanedPropMap, "additionalProperties")
-
-			// Check and clean 'format' for string types
-			if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
-				if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
-					if formatValue != "enum" && formatValue != "date-time" {
-						delete(cleanedPropMap, "format")
-					}
+		// Check and clean 'format' for string types
+		if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
+			if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
+				if formatValue != "enum" && formatValue != "date-time" {
+					delete(cleanedMap, "format")
 				}
 			}
+		}
 
-			// Recursively clean nested properties within this property if it's an object/array
-			// Check the type before recursing
-			if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
-				cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
-			} else {
-				cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
+		// Clean properties
+		if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
+			cleanedProps := make(map[string]interface{})
+			for propName, propValue := range props {
+				cleanedProps[propName] = cleanFunctionParameters(propValue)
 			}
-
+			cleanedMap["properties"] = cleanedProps
 		}
-		cleanedMap["properties"] = cleanedProps
-	}
 
-	// Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
-	if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
-		cleanedMap["items"] = cleanFunctionParameters(items)
-	}
-	// Also handle items if it's an array of schemas
-	if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
-		cleanedItemsArray := make([]interface{}, len(itemsArray))
-		for i, item := range itemsArray {
-			cleanedItemsArray[i] = cleanFunctionParameters(item)
+		// Recursively clean items in arrays
+		if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
+			cleanedMap["items"] = cleanFunctionParameters(items)
 		}
-		cleanedMap["items"] = cleanedItemsArray
-	}
-
-	// Recursively clean other schema composition keywords if necessary
-	for _, field := range []string{"allOf", "anyOf", "oneOf"} {
-		if nested, ok := cleanedMap[field].([]interface{}); ok {
-			cleanedNested := make([]interface{}, len(nested))
-			for i, item := range nested {
-				cleanedNested[i] = cleanFunctionParameters(item)
+		// Also handle items if it's an array of schemas
+		if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
+			cleanedItemsArray := make([]interface{}, len(itemsArray))
+			for i, item := range itemsArray {
+				cleanedItemsArray[i] = cleanFunctionParameters(item)
 			}
-			cleanedMap[field] = cleanedNested
+			cleanedMap["items"] = cleanedItemsArray
 		}
-	}
 
-	return cleanedMap
+		// Recursively clean other schema composition keywords
+		for _, field := range []string{"allOf", "anyOf", "oneOf"} {
+			if nested, ok := cleanedMap[field].([]interface{}); ok {
+				cleanedNested := make([]interface{}, len(nested))
+				for i, item := range nested {
+					cleanedNested[i] = cleanFunctionParameters(item)
+				}
+				cleanedMap[field] = cleanedNested
+			}
+		}
+
+		// Recursively clean patternProperties
+		if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
+			cleanedPatternProps := make(map[string]interface{})
+			for pattern, schema := range patternProps {
+				cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
+			}
+			cleanedMap["patternProperties"] = cleanedPatternProps
+		}
+
+		// Recursively clean definitions
+		if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
+			cleanedDefinitions := make(map[string]interface{})
+			for defName, defSchema := range definitions {
+				cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
+			}
+			cleanedMap["definitions"] = cleanedDefinitions
+		}
+
+		// Recursively clean $defs (newer JSON Schema draft)
+		if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
+			cleanedDefs := make(map[string]interface{})
+			for defName, defSchema := range defs {
+				cleanedDefs[defName] = cleanFunctionParameters(defSchema)
+			}
+			cleanedMap["$defs"] = cleanedDefs
+		}
+
+		// Clean conditional keywords
+		for _, field := range []string{"if", "then", "else", "not"} {
+			if nested, ok := cleanedMap[field]; ok {
+				cleanedMap[field] = cleanFunctionParameters(nested)
+			}
+		}
+
+		return cleanedMap
+
+	case []interface{}:
+		// Handle arrays of schemas
+		cleanedArray := make([]interface{}, len(v))
+		for i, item := range v {
+			cleanedArray[i] = cleanFunctionParameters(item)
+		}
+		return cleanedArray
+
+	default:
+		// Not a map or array, return as is (e.g., could be a primitive)
+		return params
+	}
 }
 
 func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
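To see what the refactored cleaner is expected to do, here is a self-contained, deliberately simplified version of the same recursive pattern, stripping only a few of the keys the full function removes; it is an illustration of the technique, not the project's code:

package main

import (
	"encoding/json"
	"fmt"
)

// stripUnsupported walks maps and slices, dropping keys the upstream API
// rejects; everything else is returned unchanged.
func stripUnsupported(v any) any {
	switch t := v.(type) {
	case map[string]any:
		cleaned := make(map[string]any, len(t))
		for k, val := range t {
			if k == "$schema" || k == "default" || k == "additionalProperties" {
				continue // unsupported key: drop it
			}
			cleaned[k] = stripUnsupported(val)
		}
		return cleaned
	case []any:
		cleaned := make([]any, len(t))
		for i, item := range t {
			cleaned[i] = stripUnsupported(item)
		}
		return cleaned
	default:
		return t
	}
}

func main() {
	raw := `{"$schema":"http://json-schema.org/draft-07/schema#","type":"object","properties":{"city":{"type":"string","default":"Berlin"}}}`
	var schema any
	_ = json.Unmarshal([]byte(raw), &schema)
	out, _ := json.Marshal(stripUnsupported(schema))
	fmt.Println(string(out)) // {"properties":{"city":{"type":"string"}},"type":"object"}
}

The type switch in the patch serves the same purpose as the one here: arrays of subschemas recurse element by element, and primitives pass through untouched.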
From 148c9749123d6c35924ab012d89565c84c856f86 Mon Sep 17 00:00:00 2001
From: RedwindA
Date: Mon, 2 Jun 2025 19:00:55 +0800
Subject: [PATCH 101/105] feat: add validation for Gemini MIME types
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 relay/channel/gemini/relay-gemini.go | 39 +++++++++++++++++++++++++---
 1 file changed, 36 insertions(+), 3 deletions(-)

diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 9ab167b1..5dff8ab6 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -18,6 +18,24 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+var geminiSupportedMimeTypes = map[string]bool{
+	"application/pdf": true,
+	"audio/mpeg":      true,
+	"audio/mp3":       true,
+	"audio/wav":       true,
+	"image/png":       true,
+	"image/jpeg":      true,
+	"text/plain":      true,
+	"video/mov":       true,
+	"video/mpeg":      true,
+	"video/mp4":       true,
+	"video/mpg":       true,
+	"video/avi":       true,
+	"video/wmv":       true,
+	"video/mpegps":    true,
+	"video/flv":       true,
+}
+
 // Setting safety to the lowest possible values since Gemini is already powerless enough
 func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
 
@@ -215,14 +233,20 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 				}
 				// check whether it is a URL
 				if strings.HasPrefix(part.GetImageMedia().Url, "http") {
-					// it is a URL; fetch the image type and the base64-encoded data
+					// it is a URL; fetch the file type and the base64-encoded data
 					fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
 					if err != nil {
-						return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
+						return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
 					}
+
+					// verify that the MIME type is on Gemini's supported whitelist
+					if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
+						return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
+					}
+
 					parts = append(parts, GeminiPart{
 						InlineData: &GeminiInlineData{
-							MimeType: fileData.MimeType,
+							MimeType: fileData.MimeType, // keep the original MIME type: its casing may matter to the API
 							Data:     fileData.Base64Data,
 						},
 					})
@@ -291,6 +315,15 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 	return &geminiRequest, nil
 }
 
+// Helper function to get a list of supported MIME types for error messages
+func getSupportedMimeTypesList() []string {
+	keys := make([]string, 0, len(geminiSupportedMimeTypes))
+	for k := range geminiSupportedMimeTypes {
+		keys = append(keys, k)
+	}
+	return keys
+}
+
 // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
 func cleanFunctionParameters(params interface{}) interface{} {
 	if params == nil {
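Note the asymmetry in patch 101: the whitelist lookup lowercases the detected MIME type, while the original casing is forwarded upstream. A minimal sketch of that check in isolation; the trimmed whitelist below is illustrative, not the full table from the patch:

package main

import (
	"fmt"
	"strings"
)

// A trimmed copy of the whitelist above, for illustration only.
var supported = map[string]bool{
	"image/png":  true,
	"image/jpeg": true,
	"video/mp4":  true,
}

func main() {
	for _, mime := range []string{"Image/PNG", "image/webp"} {
		// Lowercase for the lookup; the original casing would still be
		// what gets sent upstream, mirroring the patch.
		if supported[strings.ToLower(mime)] {
			fmt.Println(mime, "accepted")
		} else {
			fmt.Println(mime, "rejected before the upstream call")
		}
	}
}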
From 37caafc722676ea91d6d74658104141fce63fca2 Mon Sep 17 00:00:00 2001
From: xqx121 <78908927+xqx121@users.noreply.github.com>
Date: Mon, 2 Jun 2025 22:11:11 +0800
Subject: [PATCH 102/105] Fix: the image edit endpoint was not billed under
 usage-based pricing

---
 relay/relay-image.go | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/relay/relay-image.go b/relay/relay-image.go
index 36b4b9f8..dc63cce8 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -41,6 +41,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 				imageRequest.Quality = "standard"
 			}
 		}
+		if imageRequest.N == 0 {
+			imageRequest.N = 1
+		}
 	default:
 		err := common.UnmarshalBodyReusable(c, imageRequest)
 		if err != nil {
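The fix makes sense if, as the commit message suggests, billing multiplies a per-image price by N: with N left at its zero value, the computed quota comes out as zero and the edit request is never charged. A sketch of that arithmetic, with an illustrative helper and price rather than the project's actual pricing code:

package main

import "fmt"

// imageQuota is an illustrative model of usage-based image billing:
// price per image times the number of images requested.
func imageQuota(n int, perImagePrice float64) float64 {
	return float64(n) * perImagePrice
}

func main() {
	fmt.Println(imageQuota(0, 0.04)) // 0: an unset N billed nothing
	fmt.Println(imageQuota(1, 0.04)) // 0.04: defaulting N to 1 restores billing
}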
From 0af047b18c8ecbe500c87cd6f4fcda89947ab93a Mon Sep 17 00:00:00 2001
From: RedwindA
Date: Thu, 5 Jun 2025 02:09:21 +0800
Subject: [PATCH 103/105] Add DeepWiki Badge in README

---
 README.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/README.md b/README.md
index 5d0014f9..e9d1c154 100644
--- a/README.md
+++ b/README.md
@@ -44,6 +44,9 @@
 For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
 
+You can also visit the AI-generated DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)
+
 ## ✨ Key Features
 
 New API offers a rich set of features; see [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:

From a8f4ae2a734310131d40459cd31d168cd5952204 Mon Sep 17 00:00:00 2001
From: "Apple\\Apple"
Date: Thu, 5 Jun 2025 11:27:00 +0800
Subject: [PATCH 104/105] 📕docs: Add DeepWiki Badge in `README.en.md`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.en.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/README.en.md b/README.en.md
index ad11f386..10a3cdb0 100644
--- a/README.en.md
+++ b/README.en.md
@@ -44,6 +44,9 @@
 For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
 
+You can also access the AI-generated DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)
+
 ## ✨ Key Features
 
 New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:

From 3665ad672ef43b86b234aa4dc7a9c052e00f3dde Mon Sep 17 00:00:00 2001
From: neotf
Date: Thu, 5 Jun 2025 17:35:48 +0800
Subject: [PATCH 105/105] feat: support claude cache and thinking for upstream
 [OpenRouter] (#983)

* feat: support claude cache for upstream [OpenRouter]
* feat: support claude thinking for upstream [OpenRouter]
* feat: reasoning is common params for OpenRouter

---
 dto/claude.go                   | 23 +++++++--------
 dto/openai_request.go           |  5 +++-
 relay/channel/openrouter/dto.go |  9 ++++++
 service/convert.go              | 50 ++++++++++++++++++++++++++-------
 4 files changed, 65 insertions(+), 22 deletions(-)
 create mode 100644 relay/channel/openrouter/dto.go

diff --git a/dto/claude.go b/dto/claude.go
index 8068feb8..36dfc02e 100644
--- a/dto/claude.go
+++ b/dto/claude.go
@@ -7,17 +7,18 @@ type ClaudeMetadata struct {
 }
 
 type ClaudeMediaMessage struct {
-	Type        string               `json:"type,omitempty"`
-	Text        *string              `json:"text,omitempty"`
-	Model       string               `json:"model,omitempty"`
-	Source      *ClaudeMessageSource `json:"source,omitempty"`
-	Usage       *ClaudeUsage         `json:"usage,omitempty"`
-	StopReason  *string              `json:"stop_reason,omitempty"`
-	PartialJson *string              `json:"partial_json,omitempty"`
-	Role        string               `json:"role,omitempty"`
-	Thinking    string               `json:"thinking,omitempty"`
-	Signature   string               `json:"signature,omitempty"`
-	Delta       string               `json:"delta,omitempty"`
+	Type         string               `json:"type,omitempty"`
+	Text         *string              `json:"text,omitempty"`
+	Model        string               `json:"model,omitempty"`
+	Source       *ClaudeMessageSource `json:"source,omitempty"`
+	Usage        *ClaudeUsage         `json:"usage,omitempty"`
+	StopReason   *string              `json:"stop_reason,omitempty"`
+	PartialJson  *string              `json:"partial_json,omitempty"`
+	Role         string               `json:"role,omitempty"`
+	Thinking     string               `json:"thinking,omitempty"`
+	Signature    string               `json:"signature,omitempty"`
+	Delta        string               `json:"delta,omitempty"`
+	CacheControl json.RawMessage      `json:"cache_control,omitempty"`
 	// tool_calls
 	Id   string `json:"id,omitempty"`
 	Name string `json:"name,omitempty"`
diff --git a/dto/openai_request.go b/dto/openai_request.go
index 16cdf3a2..a7325fe8 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -29,7 +29,6 @@ type GeneralOpenAIRequest struct {
 	MaxTokens           uint   `json:"max_tokens,omitempty"`
 	MaxCompletionTokens uint   `json:"max_completion_tokens,omitempty"`
 	ReasoningEffort     string `json:"reasoning_effort,omitempty"`
-	//Reasoning json.RawMessage `json:"reasoning,omitempty"`
 	Temperature *float64 `json:"temperature,omitempty"`
 	TopP        float64  `json:"top_p,omitempty"`
 	TopK        int      `json:"top_k,omitempty"`
@@ -56,6 +55,8 @@ type GeneralOpenAIRequest struct {
 	EnableThinking   any               `json:"enable_thinking,omitempty"` // ali
 	ExtraBody        any               `json:"extra_body,omitempty"`
 	WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
+	// OpenRouter Params
+	Reasoning json.RawMessage `json:"reasoning,omitempty"`
 }
 
 func (r *GeneralOpenAIRequest) ToMap() map[string]any {
@@ -125,6 +126,8 @@ type MediaContent struct {
 	InputAudio any `json:"input_audio,omitempty"`
 	File       any `json:"file,omitempty"`
 	VideoUrl   any `json:"video_url,omitempty"`
+	// OpenRouter Params
+	CacheControl json.RawMessage `json:"cache_control,omitempty"`
 }
 
 func (m *MediaContent) GetImageMedia() *MessageImageUrl {
diff --git a/relay/channel/openrouter/dto.go b/relay/channel/openrouter/dto.go
new file mode 100644
index 00000000..607f495b
--- /dev/null
+++ b/relay/channel/openrouter/dto.go
@@ -0,0 +1,9 @@
+package openrouter
+
+type RequestReasoning struct {
+	// One of the following (not both):
+	Effort    string `json:"effort,omitempty"`     // Can be "high", "medium", or "low" (OpenAI-style)
+	MaxTokens int    `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style)
+	// Optional: Default is false. All models support this.
+	Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
+}
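Before convert.go (below) wires it up, the new RequestReasoning type can be exercised on its own: an Anthropic-style thinking budget is forwarded as max_tokens in the OpenRouter reasoning object. A sketch, with the struct copied from the new dto.go above and the budget value illustrative:

package main

import (
	"encoding/json"
	"fmt"
)

// RequestReasoning mirrors the struct added in relay/channel/openrouter/dto.go.
type RequestReasoning struct {
	Effort    string `json:"effort,omitempty"`
	MaxTokens int    `json:"max_tokens,omitempty"`
	Exclude   bool   `json:"exclude,omitempty"`
}

func main() {
	// Anthropic-style: forward the Claude thinking budget as max_tokens.
	payload, _ := json.Marshal(RequestReasoning{MaxTokens: 2048})
	fmt.Println(string(payload)) // {"max_tokens":2048}
}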
diff --git a/service/convert.go b/service/convert.go
index cc462b40..cb964a46 100644
--- a/service/convert.go
+++ b/service/convert.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/relay/channel/openrouter"
 	relaycommon "one-api/relay/common"
 	"strings"
 )
@@ -18,10 +19,24 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
 		Stream: claudeRequest.Stream,
 	}
 
+	isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
+
 	if claudeRequest.Thinking != nil {
-		if strings.HasSuffix(info.OriginModelName, "-thinking") &&
-			!strings.HasSuffix(claudeRequest.Model, "-thinking") {
-			openAIRequest.Model = openAIRequest.Model + "-thinking"
+		if isOpenRouter {
+			reasoning := openrouter.RequestReasoning{
+				MaxTokens: claudeRequest.Thinking.BudgetTokens,
+			}
+			reasoningJSON, err := json.Marshal(reasoning)
+			if err != nil {
+				return nil, fmt.Errorf("failed to marshal reasoning: %w", err)
+			}
+			openAIRequest.Reasoning = reasoningJSON
+		} else {
+			thinkingSuffix := "-thinking"
+			if strings.HasSuffix(info.OriginModelName, thinkingSuffix) &&
+				!strings.HasSuffix(openAIRequest.Model, thinkingSuffix) {
+				openAIRequest.Model = openAIRequest.Model + thinkingSuffix
+			}
 		}
 	}
 
@@ -62,16 +77,30 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
 	} else {
 		systems := claudeRequest.ParseSystem()
 		if len(systems) > 0 {
-			systemStr := ""
 			openAIMessage := dto.Message{
 				Role: "system",
 			}
-			for _, system := range systems {
-				if system.Text != nil {
-					systemStr += *system.Text
+			isOpenRouterClaude := isOpenRouter && strings.HasPrefix(info.UpstreamModelName, "anthropic/claude")
+			if isOpenRouterClaude {
+				systemMediaMessages := make([]dto.MediaContent, 0, len(systems))
+				for _, system := range systems {
+					message := dto.MediaContent{
+						Type:         "text",
+						Text:         system.GetText(),
+						CacheControl: system.CacheControl,
+					}
+					systemMediaMessages = append(systemMediaMessages, message)
 				}
+				openAIMessage.SetMediaContent(systemMediaMessages)
+			} else {
+				systemStr := ""
+				for _, system := range systems {
+					if system.Text != nil {
+						systemStr += *system.Text
+					}
+				}
+				openAIMessage.SetStringContent(systemStr)
 			}
-			openAIMessage.SetStringContent(systemStr)
 			openAIMessages = append(openAIMessages, openAIMessage)
 		}
 	}
@@ -97,8 +126,9 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
 			switch mediaMsg.Type {
 			case "text":
 				message := dto.MediaContent{
-					Type: "text",
-					Text: mediaMsg.GetText(),
+					Type:         "text",
+					Text:         mediaMsg.GetText(),
+					CacheControl: mediaMsg.CacheControl,
 				}
 				mediaMessages = append(mediaMessages, message)
 			case "image":