From 18b3300ff1d431a46a7cadf9a48db5c3c6ac0519 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 5 May 2025 00:40:16 +0800 Subject: [PATCH 01/57] feat: implement OpenAI responses handling and streaming support with built-in tool tracking --- dto/openai_response.go | 11 ++- relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/helper.go | 7 ++ relay/channel/openai/relay-openai.go | 99 -------------------- relay/channel/openai/relay_responses.go | 114 ++++++++++++++++++++++++ relay/channel/vertex/adaptor.go | 2 +- relay/common/relay_info.go | 37 ++++++++ relay/relay-responses.go | 10 +-- 8 files changed, 173 insertions(+), 109 deletions(-) create mode 100644 relay/channel/openai/relay_responses.go diff --git a/dto/openai_response.go b/dto/openai_response.go index 1508d1f6..c8f61b9d 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -238,8 +238,12 @@ type ResponsesOutputContent struct { } const ( - BuildInTools_WebSearch = "web_search_preview" - BuildInTools_FileSearch = "file_search" + BuildInToolWebSearchPreview = "web_search_preview" + BuildInToolFileSearch = "file_search" +) + +const ( + BuildInCallWebSearchCall = "web_search_call" ) const ( @@ -250,6 +254,7 @@ const ( // ResponsesStreamResponse 用于处理 /v1/responses 流式响应 type ResponsesStreamResponse struct { Type string `json:"type"` - Response *OpenAIResponsesResponse `json:"response"` + Response *OpenAIResponsesResponse `json:"response,omitempty"` Delta string `json:"delta,omitempty"` + Item *ResponsesOutput `json:"item,omitempty"` } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 7740c498..eb12a22a 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -429,7 +429,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = OaiResponsesStreamHandler(c, resp, info) } else { - err, usage = OpenaiResponsesHandler(c, resp, info) + err, usage = OaiResponsesHandler(c, resp, info) } default: if info.IsStream { diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index e7ba2e7b..a068c544 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -187,3 +187,10 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } } } + +func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { + if data == "" { + return + } + helper.ResponseChunkData(c, streamResponse, data) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index ef660564..b9ed94e2 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -644,102 +644,3 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm } return nil, &usageResp.Usage } - -func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - // read response body - var responsesResponse dto.OpenAIResponsesResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = common.DecodeJson(responseBody, &responsesResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if responsesResponse.Error != nil { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: responsesResponse.Error.Message, - Type: "openai_error", - Code: responsesResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - - // reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - // copy response body - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - common.SysError("error copying response body: " + err.Error()) - } - resp.Body.Close() - // compute usage - usage := dto.Usage{} - usage.PromptTokens = responsesResponse.Usage.InputTokens - usage.CompletionTokens = responsesResponse.Usage.OutputTokens - usage.TotalTokens = responsesResponse.Usage.TotalTokens - return nil, &usage -} - -func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") - return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil - } - - var usage = &dto.Usage{} - var responseTextBuilder strings.Builder - - helper.StreamScannerHandler(c, resp, info, func(data string) bool { - - // 检查当前数据是否包含 completed 状态和 usage 信息 - var streamResponse dto.ResponsesStreamResponse - if err := common.DecodeJsonStr(data, &streamResponse); err == nil { - sendResponsesStreamData(c, streamResponse, data) - switch streamResponse.Type { - case "response.completed": - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens - case "response.output_text.delta": - // 处理输出文本 - responseTextBuilder.WriteString(streamResponse.Delta) - - } - } - return true - }) - - if usage.CompletionTokens == 0 { - // 计算输出文本的 token 数量 - tempStr := responseTextBuilder.String() - if len(tempStr) > 0 { - // 非正常结束,使用输出文本的 token 数量 - completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) - usage.CompletionTokens = completionTokens - } - } - - return nil, usage -} - -func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { - if data == "" { - return - } - helper.ResponseChunkData(c, streamResponse, data) -} diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go new file mode 100644 index 00000000..6af8c676 --- /dev/null +++ b/relay/channel/openai/relay_responses.go @@ -0,0 +1,114 @@ +package openai + +import ( + "bytes" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "strings" +) + +func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + // read response body + var responsesResponse dto.OpenAIResponsesResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = common.DecodeJson(responseBody, &responsesResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if responsesResponse.Error != nil { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: responsesResponse.Error.Message, + Type: "openai_error", + Code: responsesResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + // reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + // copy response body + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + common.SysError("error copying response body: " + err.Error()) + } + resp.Body.Close() + // compute usage + usage := dto.Usage{} + usage.PromptTokens = responsesResponse.Usage.InputTokens + usage.CompletionTokens = responsesResponse.Usage.OutputTokens + usage.TotalTokens = responsesResponse.Usage.TotalTokens + return nil, &usage +} + +func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + if resp == nil || resp.Body == nil { + common.LogError(c, "invalid response or response body") + return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil + } + + var usage = &dto.Usage{} + var responseTextBuilder strings.Builder + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + + // 检查当前数据是否包含 completed 状态和 usage 信息 + var streamResponse dto.ResponsesStreamResponse + if err := common.DecodeJsonStr(data, &streamResponse); err == nil { + sendResponsesStreamData(c, streamResponse, data) + switch streamResponse.Type { + case "response.completed": + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + case "response.output_text.delta": + // 处理输出文本 + responseTextBuilder.WriteString(streamResponse.Delta) + case dto.ResponsesOutputTypeItemDone: + // 函数调用处理 + if streamResponse.Item != nil { + switch streamResponse.Item.Type { + case dto.BuildInCallWebSearchCall: + info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++ + } + } + } + } + return true + }) + + if usage.CompletionTokens == 0 { + // 计算输出文本的 token 数量 + tempStr := responseTextBuilder.String() + if len(tempStr) > 0 { + // 非正常结束,使用输出文本的 token 数量 + completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) + usage.CompletionTokens = completionTokens + } + } + + return nil, usage +} diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index c1b64f11..7daf9a61 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -11,8 +11,8 @@ import ( "one-api/relay/channel/claude" "one-api/relay/channel/gemini" "one-api/relay/channel/openai" - "one-api/setting/model_setting" relaycommon "one-api/relay/common" + "one-api/setting/model_setting" "strings" "github.com/gin-gonic/gin" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 915474e1..99c6d12b 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -36,6 +36,7 @@ type ClaudeConvertInfo struct { const ( RelayFormatOpenAI = "openai" RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" ) type RerankerInfo struct { @@ -43,6 +44,16 @@ type RerankerInfo struct { ReturnDocuments bool } +type BuildInToolInfo struct { + ToolName string + CallCount int + SearchContextSize string +} + +type ResponsesUsageInfo struct { + BuiltInTools map[string]*BuildInToolInfo +} + type RelayInfo struct { ChannelType int ChannelId int @@ -90,6 +101,7 @@ type RelayInfo struct { ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo + *ResponsesUsageInfo } // 定义支持流式选项的通道类型 @@ -134,6 +146,31 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { return info } +func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { + info := GenRelayInfo(c) + info.RelayMode = relayconstant.RelayModeResponses + info.ResponsesUsageInfo = &ResponsesUsageInfo{ + BuiltInTools: make(map[string]*BuildInToolInfo), + } + if len(req.Tools) > 0 { + for _, tool := range req.Tools { + info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{ + ToolName: tool.Type, + CallCount: 0, + } + switch tool.Type { + case dto.BuildInToolWebSearchPreview: + if tool.SearchContextSize == "" { + tool.SearchContextSize = "medium" + } + info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize + } + } + } + info.IsStream = req.Stream + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") diff --git a/relay/relay-responses.go b/relay/relay-responses.go index cdb37ae7..fd3ddb5a 100644 --- a/relay/relay-responses.go +++ b/relay/relay-responses.go @@ -19,7 +19,7 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.OpenAIResponsesRequest, error) { +func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { request := &dto.OpenAIResponsesRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { @@ -31,13 +31,11 @@ func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.Relay if len(request.Input) == 0 { return nil, errors.New("input is required") } - relayInfo.IsStream = request.Stream return request, nil } func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { - sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) return sensitiveWords, err } @@ -49,12 +47,14 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo } func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) - req, err := getAndValidateResponsesRequest(c, relayInfo) + req, err := getAndValidateResponsesRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest) } + + relayInfo := relaycommon.GenRelayInfoResponses(c, req) + if setting.ShouldCheckPromptSensitive() { sensitiveWords, err := checkInputSensitive(req, relayInfo) if err != nil { From 6c3fb7777ec3fe4874b249251120e68b5e22642f Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 07:31:54 +0800 Subject: [PATCH 02/57] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E5=88=86?= =?UTF-8?q?=E7=BB=84=E9=80=9F=E7=8E=87=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 37 +++++++-- model/option.go | 47 +++++++++++- setting/rate_limit.go | 70 +++++++++++++++++ web/src/components/RateLimitSetting.js | 1 + .../RateLimit/SettingsRequestRateLimit.js | 76 +++++++++++++++++-- 5 files changed, 214 insertions(+), 17 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 581dc451..d4199ece 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) { return } - // 计算限流参数 + // 计算通用限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - totalMaxCount := setting.ModelRequestRateLimitCount - successMaxCount := setting.ModelRequestRateLimitSuccessCount - // 根据存储类型选择并执行限流处理器 - if common.RedisEnabled { - redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + // 获取用户组 + group := c.GetString("token_group") + if group == "" { + group = c.GetString("group") + } + if group == "" { + group = "default" // 默认组 + } + + // 尝试获取用户组特定的限制 + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) + + // 确定最终的限制值 + finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制 + finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制 + + if found { + // 如果找到用户组特定限制,则使用它们 + finalTotalCount = groupTotalCount + finalSuccessCount = groupSuccessCount + common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } else { - memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) + } + + // 根据存储类型选择并执行限流处理器,传入最终确定的限制值 + if common.RedisEnabled { + redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + } else { + memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } } } diff --git a/model/option.go b/model/option.go index d575742f..1f5fb3aa 100644 --- a/model/option.go +++ b/model/option.go @@ -1,6 +1,8 @@ package model import ( + "encoding/json" + "fmt" "one-api/common" "one-api/setting" "one-api/setting/config" @@ -96,6 +98,7 @@ func InitOptionMap() { common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值 common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -150,7 +153,32 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - // Save to database first + originalValue := value // 保存原始值以备后用 + + // Validate and format specific keys before saving + if key == setting.ModelRequestRateLimitGroupKey { + var cfg map[string][2]int + // Validate the JSON structure first using the original value + err := json.Unmarshal([]byte(originalValue), &cfg) + if err != nil { + // 提供更具体的错误信息 + return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) + } + // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。 + // if !isValidModelRequestRateLimitGroupConfig(cfg) { + // return fmt.Errorf("无效的配置值 for %s", key) + // } + + // If valid, format the JSON before saving + formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") + if marshalErr != nil { + // This should ideally not happen if validation passed, but handle defensively + return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) + } + value = string(formattedValueBytes) // Use formatted JSON for saving and memory update + } + + // Save to database option := Option{ Key: key, } @@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error { // Save is a combination function. // If save value does not contain primary key, it will execute Create, // otherwise it will execute Update (with all fields). - DB.Save(&option) - // Update OptionMap + if err := DB.Save(&option).Error; err != nil { + return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文 + } + + // Update OptionMap in memory using the potentially formatted value + // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新 return updateOptionMap(key, value) } @@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + case setting.ModelRequestRateLimitGroupKey: + // Use the (potentially formatted) value passed from UpdateOption + // to update the actual configuration in memory. + // This is the single point where the memory state for this specific setting is updated. + err = setting.UpdateModelRequestRateLimitGroupConfig(value) + if err != nil { + // 添加错误上下文 + err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err) + } } return err } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 4b216948..c83885a6 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -1,6 +1,76 @@ package setting +import ( + "encoding/json" + "fmt" + "one-api/common" + "sync" +) + var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 + +// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键 +const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup" + +// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置 +// map[groupName][2]int{totalCount, successCount} +var ModelRequestRateLimitGroupConfig map[string][2]int +var ModelRequestRateLimitGroupMutex sync.RWMutex + +// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置 +func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { + ModelRequestRateLimitGroupMutex.Lock() + defer ModelRequestRateLimitGroupMutex.Unlock() + + var newConfig map[string][2]int + if jsonStr == "" || jsonStr == "{}" { + // 如果配置为空或空JSON对象,则清空内存配置 + ModelRequestRateLimitGroupConfig = make(map[string][2]int) + common.SysLog("Model request rate limit group config cleared") + return nil + } + + err := json.Unmarshal([]byte(jsonStr), &newConfig) + if err != nil { + return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) + } + + // 校验配置值 + for group, limits := range newConfig { + if len(limits) != 2 { + return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group) + } + if limits[1] <= 0 { // successCount must be greater than 0 + return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group) + } + if limits[0] < 0 { // totalCount can be 0 (no limit) or positive + return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group) + } + if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount + return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group) + } + } + + ModelRequestRateLimitGroupConfig = newConfig + common.SysLog("Model request rate limit group config updated") + return nil +} + +// GetGroupRateLimit 安全地获取指定用户组的速率限制值 +func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { + ModelRequestRateLimitGroupMutex.RLock() + defer ModelRequestRateLimitGroupMutex.RUnlock() + + if ModelRequestRateLimitGroupConfig == nil { + return 0, 0, false // 配置尚未初始化 + } + + limits, found := ModelRequestRateLimitGroupConfig[group] + if !found { + return 0, 0, false + } + return limits[0], limits[1], true +} diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index e06038d6..ad6b53da 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -13,6 +13,7 @@ const RateLimitSetting = () => { ModelRequestRateLimitCount: 0, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: {}, }); let [loading, setLoading] = useState(false); diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 800e9636..ec1c2158 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -18,6 +18,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值 }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -32,25 +33,49 @@ export default function RequestRateLimit(props) { } else { value = inputs[item.key]; } + // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串 + if (item.key === 'ModelRequestRateLimitGroup') { + try { + JSON.parse(value); + } catch (e) { + showError(t('用户组速率限制配置不是有效的 JSON 格式!')); + // 阻止请求发送 + return Promise.reject('Invalid JSON format'); + } + } return API.put('/api/option/', { key: item.key, value, }); }); + + // 过滤掉无效的请求(例如,无效的 JSON) + const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); + + if (validRequests.length === 0 && requestQueue.length > 0) { + // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行 + return; + } + setLoading(true); - Promise.all(requestQueue) + Promise.all(validRequests) .then((res) => { - if (requestQueue.length === 1) { + if (validRequests.length === 1) { if (res.includes(undefined)) return; - } else if (requestQueue.length > 1) { + } else if (validRequests.length > 1) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } showSuccess(t('保存成功')); props.refresh(); + // 更新 inputsRow 以反映保存后的状态 + setInputsRow(structuredClone(inputs)); }) - .catch(() => { - showError(t('保存失败,请重试')); + .catch((error) => { + // 检查是否是由于无效 JSON 导致的错误 + if (error !== 'Invalid JSON format') { + showError(t('保存失败,请重试')); + } }) .finally(() => { setLoading(false); @@ -66,8 +91,11 @@ export default function RequestRateLimit(props) { } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); - }, [props.options]); + // 检查 refForm.current 是否存在 + if (refForm.current) { + refForm.current.setValues(currentInputs); + } + }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定 return ( <> @@ -147,7 +175,41 @@ export default function RequestRateLimit(props) { /> + {/* 用户组速率限制配置项 */} + + +

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • +
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • +
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • +
  • {t('此配置将优先于上方的全局限制设置。')}
  • +
  • {t('未在此处配置的用户组将使用全局限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
  • {t('输入无效的 JSON 将无法保存。')}
  • +
+ + } + autosize={{ minRows: 5, maxRows: 15 }} + style={{ fontFamily: 'monospace' }} + onChange={(value) => { + setInputs({ + ...inputs, + ModelRequestRateLimitGroup: value, // 直接更新字符串值 + }); + }} + /> + +
+ From 7e7d6112ca460be5c30a6c89fb4165346a6d5651 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 11:34:57 +0800 Subject: [PATCH 03/57] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E5=8E=BB=E9=99=A4=E5=A4=9A=E4=BD=99=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E5=92=8C=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .old/option.go | 402 ++++++++++++++++++ middleware/model-rate-limit.go | 43 +- model/option.go | 40 +- setting/rate_limit.go | 39 +- web/src/components/RateLimitSetting.js | 92 ++-- .../RateLimit/SettingsRequestRateLimit.js | 388 +++++++++-------- 6 files changed, 663 insertions(+), 341 deletions(-) create mode 100644 .old/option.go diff --git a/.old/option.go b/.old/option.go new file mode 100644 index 00000000..f80f5cb3 --- /dev/null +++ b/.old/option.go @@ -0,0 +1,402 @@ +package model + +import ( + "one-api/common" + "one-api/setting" + "one-api/setting/config" + "one-api/setting/operation_setting" + "strconv" + "strings" + "time" +) + +type Option struct { + Key string `json:"key" gorm:"primaryKey"` + Value string `json:"value"` +} + +func AllOption() ([]*Option, error) { + var options []*Option + var err error + err = DB.Find(&options).Error + return options, err +} + +func InitOptionMap() { + common.OptionMapRWMutex.Lock() + common.OptionMap = make(map[string]string) + + // 添加原有的系统配置 + common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) + common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) + common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) + common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) + common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) + common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) + common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) + common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) + common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled) + common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) + common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) + common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) + common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) + common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) + common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) + common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) + common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) + common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) + common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) + common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) + common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) + common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) + common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled) + common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") + common.OptionMap["SMTPServer"] = "" + common.OptionMap["SMTPFrom"] = "" + common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) + common.OptionMap["SMTPAccount"] = "" + common.OptionMap["SMTPToken"] = "" + common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled) + common.OptionMap["Notice"] = "" + common.OptionMap["About"] = "" + common.OptionMap["HomePageContent"] = "" + common.OptionMap["Footer"] = common.Footer + common.OptionMap["SystemName"] = common.SystemName + common.OptionMap["Logo"] = common.Logo + common.OptionMap["ServerAddress"] = "" + common.OptionMap["WorkerUrl"] = setting.WorkerUrl + common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey + common.OptionMap["PayAddress"] = "" + common.OptionMap["CustomCallbackAddress"] = "" + common.OptionMap["EpayId"] = "" + common.OptionMap["EpayKey"] = "" + common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) + common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) + common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() + common.OptionMap["Chats"] = setting.Chats2JsonString() + common.OptionMap["GitHubClientId"] = "" + common.OptionMap["GitHubClientSecret"] = "" + common.OptionMap["TelegramBotToken"] = "" + common.OptionMap["TelegramBotName"] = "" + common.OptionMap["WeChatServerAddress"] = "" + common.OptionMap["WeChatServerToken"] = "" + common.OptionMap["WeChatAccountQRCodeImageURL"] = "" + common.OptionMap["TurnstileSiteKey"] = "" + common.OptionMap["TurnstileSecretKey"] = "" + common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) + common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) + common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) + common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) + common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) + common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) + common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() + common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() + common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() + common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() + common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() + common.OptionMap["TopUpLink"] = common.TopUpLink + //common.OptionMap["ChatLink"] = common.ChatLink + //common.OptionMap["ChatLink2"] = common.ChatLink2 + common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) + common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) + common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) + common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime + common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) + common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled) + common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled) + common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled) + common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) + common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) + common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) + common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled) + common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled) + common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) + common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) + common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) + common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() + common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) + common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + + // 自动添加所有注册的模型配置 + modelConfigs := config.GlobalConfig.ExportAllConfigs() + for k, v := range modelConfigs { + common.OptionMap[k] = v + } + + common.OptionMapRWMutex.Unlock() + loadOptionsFromDatabase() +} + +func loadOptionsFromDatabase() { + options, _ := AllOption() + for _, option := range options { + err := updateOptionMap(option.Key, option.Value) + if err != nil { + common.SysError("failed to update option map: " + err.Error()) + } + } +} + +func SyncOptions(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + common.SysLog("syncing options from database") + loadOptionsFromDatabase() + } +} + +func UpdateOption(key string, value string) error { + // Save to database first + option := Option{ + Key: key, + } + // https://gorm.io/docs/update.html#Save-All-Fields + DB.FirstOrCreate(&option, Option{Key: key}) + option.Value = value + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap + return updateOptionMap(key, value) +} + +func updateOptionMap(key string, value string) (err error) { + common.OptionMapRWMutex.Lock() + defer common.OptionMapRWMutex.Unlock() + common.OptionMap[key] = value + + // 检查是否是模型配置 - 使用更规范的方式处理 + if handleConfigUpdate(key, value) { + return nil // 已由配置系统处理 + } + + // 处理传统配置项... + if strings.HasSuffix(key, "Permission") { + intValue, _ := strconv.Atoi(value) + switch key { + case "FileUploadPermission": + common.FileUploadPermission = intValue + case "FileDownloadPermission": + common.FileDownloadPermission = intValue + case "ImageUploadPermission": + common.ImageUploadPermission = intValue + case "ImageDownloadPermission": + common.ImageDownloadPermission = intValue + } + } + if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { + boolValue := value == "true" + switch key { + case "PasswordRegisterEnabled": + common.PasswordRegisterEnabled = boolValue + case "PasswordLoginEnabled": + common.PasswordLoginEnabled = boolValue + case "EmailVerificationEnabled": + common.EmailVerificationEnabled = boolValue + case "GitHubOAuthEnabled": + common.GitHubOAuthEnabled = boolValue + case "LinuxDOOAuthEnabled": + common.LinuxDOOAuthEnabled = boolValue + case "WeChatAuthEnabled": + common.WeChatAuthEnabled = boolValue + case "TelegramOAuthEnabled": + common.TelegramOAuthEnabled = boolValue + case "TurnstileCheckEnabled": + common.TurnstileCheckEnabled = boolValue + case "RegisterEnabled": + common.RegisterEnabled = boolValue + case "EmailDomainRestrictionEnabled": + common.EmailDomainRestrictionEnabled = boolValue + case "EmailAliasRestrictionEnabled": + common.EmailAliasRestrictionEnabled = boolValue + case "AutomaticDisableChannelEnabled": + common.AutomaticDisableChannelEnabled = boolValue + case "AutomaticEnableChannelEnabled": + common.AutomaticEnableChannelEnabled = boolValue + case "LogConsumeEnabled": + common.LogConsumeEnabled = boolValue + case "DisplayInCurrencyEnabled": + common.DisplayInCurrencyEnabled = boolValue + case "DisplayTokenStatEnabled": + common.DisplayTokenStatEnabled = boolValue + case "DrawingEnabled": + common.DrawingEnabled = boolValue + case "TaskEnabled": + common.TaskEnabled = boolValue + case "DataExportEnabled": + common.DataExportEnabled = boolValue + case "DefaultCollapseSidebar": + common.DefaultCollapseSidebar = boolValue + case "MjNotifyEnabled": + setting.MjNotifyEnabled = boolValue + case "MjAccountFilterEnabled": + setting.MjAccountFilterEnabled = boolValue + case "MjModeClearEnabled": + setting.MjModeClearEnabled = boolValue + case "MjForwardUrlEnabled": + setting.MjForwardUrlEnabled = boolValue + case "MjActionCheckSuccessEnabled": + setting.MjActionCheckSuccessEnabled = boolValue + case "CheckSensitiveEnabled": + setting.CheckSensitiveEnabled = boolValue + case "DemoSiteEnabled": + operation_setting.DemoSiteEnabled = boolValue + case "SelfUseModeEnabled": + operation_setting.SelfUseModeEnabled = boolValue + case "CheckSensitiveOnPromptEnabled": + setting.CheckSensitiveOnPromptEnabled = boolValue + case "ModelRequestRateLimitEnabled": + setting.ModelRequestRateLimitEnabled = boolValue + case "StopOnSensitiveEnabled": + setting.StopOnSensitiveEnabled = boolValue + case "SMTPSSLEnabled": + common.SMTPSSLEnabled = boolValue + } + } + switch key { + case "EmailDomainWhitelist": + common.EmailDomainWhitelist = strings.Split(value, ",") + case "SMTPServer": + common.SMTPServer = value + case "SMTPPort": + intValue, _ := strconv.Atoi(value) + common.SMTPPort = intValue + case "SMTPAccount": + common.SMTPAccount = value + case "SMTPFrom": + common.SMTPFrom = value + case "SMTPToken": + common.SMTPToken = value + case "ServerAddress": + setting.ServerAddress = value + case "WorkerUrl": + setting.WorkerUrl = value + case "WorkerValidKey": + setting.WorkerValidKey = value + case "PayAddress": + setting.PayAddress = value + case "Chats": + err = setting.UpdateChatsByJsonString(value) + case "CustomCallbackAddress": + setting.CustomCallbackAddress = value + case "EpayId": + setting.EpayId = value + case "EpayKey": + setting.EpayKey = value + case "Price": + setting.Price, _ = strconv.ParseFloat(value, 64) + case "MinTopUp": + setting.MinTopUp, _ = strconv.Atoi(value) + case "TopupGroupRatio": + err = common.UpdateTopupGroupRatioByJSONString(value) + case "GitHubClientId": + common.GitHubClientId = value + case "GitHubClientSecret": + common.GitHubClientSecret = value + case "LinuxDOClientId": + common.LinuxDOClientId = value + case "LinuxDOClientSecret": + common.LinuxDOClientSecret = value + case "Footer": + common.Footer = value + case "SystemName": + common.SystemName = value + case "Logo": + common.Logo = value + case "WeChatServerAddress": + common.WeChatServerAddress = value + case "WeChatServerToken": + common.WeChatServerToken = value + case "WeChatAccountQRCodeImageURL": + common.WeChatAccountQRCodeImageURL = value + case "TelegramBotToken": + common.TelegramBotToken = value + case "TelegramBotName": + common.TelegramBotName = value + case "TurnstileSiteKey": + common.TurnstileSiteKey = value + case "TurnstileSecretKey": + common.TurnstileSecretKey = value + case "QuotaForNewUser": + common.QuotaForNewUser, _ = strconv.Atoi(value) + case "QuotaForInviter": + common.QuotaForInviter, _ = strconv.Atoi(value) + case "QuotaForInvitee": + common.QuotaForInvitee, _ = strconv.Atoi(value) + case "QuotaRemindThreshold": + common.QuotaRemindThreshold, _ = strconv.Atoi(value) + case "PreConsumedQuota": + common.PreConsumedQuota, _ = strconv.Atoi(value) + case "ModelRequestRateLimitCount": + setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitDurationMinutes": + setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) + case "ModelRequestRateLimitSuccessCount": + setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "RetryTimes": + common.RetryTimes, _ = strconv.Atoi(value) + case "DataExportInterval": + common.DataExportInterval, _ = strconv.Atoi(value) + case "DataExportDefaultTime": + common.DataExportDefaultTime = value + case "ModelRatio": + err = operation_setting.UpdateModelRatioByJSONString(value) + case "GroupRatio": + err = setting.UpdateGroupRatioByJSONString(value) + case "UserUsableGroups": + err = setting.UpdateUserUsableGroupsByJSONString(value) + case "CompletionRatio": + err = operation_setting.UpdateCompletionRatioByJSONString(value) + case "ModelPrice": + err = operation_setting.UpdateModelPriceByJSONString(value) + case "CacheRatio": + err = operation_setting.UpdateCacheRatioByJSONString(value) + case "TopUpLink": + common.TopUpLink = value + //case "ChatLink": + // common.ChatLink = value + //case "ChatLink2": + // common.ChatLink2 = value + case "ChannelDisableThreshold": + common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + case "QuotaPerUnit": + common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + case "SensitiveWords": + setting.SensitiveWordsFromString(value) + case "AutomaticDisableKeywords": + operation_setting.AutomaticDisableKeywordsFromString(value) + case "StreamCacheQueueLength": + setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + } + return err +} + +// handleConfigUpdate 处理分层配置更新,返回是否已处理 +func handleConfigUpdate(key, value string) bool { + parts := strings.SplitN(key, ".", 2) + if len(parts) != 2 { + return false // 不是分层配置 + } + + configName := parts[0] + configKey := parts[1] + + // 获取配置对象 + cfg := config.GlobalConfig.Get(configName) + if cfg == nil { + return false // 未注册的配置 + } + + // 更新配置 + configMap := map[string]string{ + configKey: value, + } + config.UpdateConfigFromMap(cfg, configMap) + + return true // 已处理 +} \ No newline at end of file diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index d4199ece..b0047b70 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -19,25 +19,20 @@ const ( ModelRequestRateLimitSuccessCountMark = "MRRLS" ) -// 检查Redis中的请求限制 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { - // 如果maxCount为0,表示不限制 if maxCount == 0 { return true, nil } - // 获取当前计数 length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } - // 如果未达到限制,允许请求 if length < int64(maxCount) { return true, nil } - // 检查时间窗口 oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { @@ -49,7 +44,6 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max if err != nil { return false, err } - // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) @@ -59,9 +53,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max return true, nil } -// 记录Redis请求 func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { - // 如果maxCount为0,不记录请求 if maxCount == 0 { return } @@ -72,14 +64,12 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } -// Redis限流处理器 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB - // 1. 检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { @@ -92,9 +82,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } - //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) - // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, @@ -114,17 +102,14 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } - // 4. 处理请求 c.Next() - // 5. 如果请求成功,记录成功请求 if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } -// 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) @@ -133,15 +118,12 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId - // 1. 检查总请求数限制(当totalMaxCount为0时跳过) if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } - // 2. 检查成功请求数限制 - // 使用一个临时key来检查限制,这样可以避免实际记录 checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) @@ -149,54 +131,47 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) return } - // 3. 处理请求 c.Next() - // 4. 如果请求成功,记录到实际的成功请求计数中 if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } -// ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { - // 在每个请求时检查是否启用限流 if !setting.ModelRequestRateLimitEnabled { c.Next() return } - // 计算通用限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - // 获取用户组 group := c.GetString("token_group") if group == "" { group = c.GetString("group") } if group == "" { - group = "default" // 默认组 + group = "default" } - // 尝试获取用户组特定的限制 + finalTotalCount := setting.ModelRequestRateLimitCount + finalSuccessCount := setting.ModelRequestRateLimitSuccessCount + foundGroupLimit := false + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) - - // 确定最终的限制值 - finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制 - finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制 - if found { - // 如果找到用户组特定限制,则使用它们 finalTotalCount = groupTotalCount finalSuccessCount = groupSuccessCount + foundGroupLimit = true common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) - } else { + } + + if !foundGroupLimit { common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } - // 根据存储类型选择并执行限流处理器,传入最终确定的限制值 if common.RedisEnabled { redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } else { diff --git a/model/option.go b/model/option.go index 1f5fb3aa..79556737 100644 --- a/model/option.go +++ b/model/option.go @@ -94,11 +94,12 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + jsonBytes, _ := json.Marshal(map[string][2]int{}) + common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes) common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() - common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值 common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -153,47 +154,31 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - originalValue := value // 保存原始值以备后用 + originalValue := value - // Validate and format specific keys before saving - if key == setting.ModelRequestRateLimitGroupKey { + if key == "ModelRequestRateLimitGroup" { var cfg map[string][2]int - // Validate the JSON structure first using the original value err := json.Unmarshal([]byte(originalValue), &cfg) if err != nil { - // 提供更具体的错误信息 return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) } - // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。 - // if !isValidModelRequestRateLimitGroupConfig(cfg) { - // return fmt.Errorf("无效的配置值 for %s", key) - // } - // If valid, format the JSON before saving formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") if marshalErr != nil { - // This should ideally not happen if validation passed, but handle defensively return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) } - value = string(formattedValueBytes) // Use formatted JSON for saving and memory update + value = string(formattedValueBytes) } - // Save to database option := Option{ Key: key, } - // https://gorm.io/docs/update.html#Save-All-Fields DB.FirstOrCreate(&option, Option{Key: key}) option.Value = value - // Save is a combination function. - // If save value does not contain primary key, it will execute Create, - // otherwise it will execute Update (with all fields). if err := DB.Save(&option).Error; err != nil { - return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文 + return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) } - // Update OptionMap in memory using the potentially formatted value - // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新 return updateOptionMap(key, value) } @@ -370,6 +355,8 @@ func updateOptionMap(key string, value string) (err error) { setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitGroup": + err = setting.UpdateModelRequestRateLimitGroup(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": @@ -404,15 +391,6 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) - case setting.ModelRequestRateLimitGroupKey: - // Use the (potentially formatted) value passed from UpdateOption - // to update the actual configuration in memory. - // This is the single point where the memory state for this specific setting is updated. - err = setting.UpdateModelRequestRateLimitGroupConfig(value) - if err != nil { - // 添加错误上下文 - err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err) - } } return err } @@ -440,4 +418,4 @@ func handleConfigUpdate(key, value string) bool { config.UpdateConfigFromMap(cfg, configMap) return true // 已处理 -} +} \ No newline at end of file diff --git a/setting/rate_limit.go b/setting/rate_limit.go index c83885a6..5be75cc1 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -11,24 +11,17 @@ var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 +var ModelRequestRateLimitGroup map[string][2]int -// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键 -const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup" - -// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置 -// map[groupName][2]int{totalCount, successCount} -var ModelRequestRateLimitGroupConfig map[string][2]int var ModelRequestRateLimitGroupMutex sync.RWMutex -// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置 -func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { +func UpdateModelRequestRateLimitGroup(jsonStr string) error { ModelRequestRateLimitGroupMutex.Lock() defer ModelRequestRateLimitGroupMutex.Unlock() var newConfig map[string][2]int if jsonStr == "" || jsonStr == "{}" { - // 如果配置为空或空JSON对象,则清空内存配置 - ModelRequestRateLimitGroupConfig = make(map[string][2]int) + ModelRequestRateLimitGroup = make(map[string][2]int) common.SysLog("Model request rate limit group config cleared") return nil } @@ -38,37 +31,19 @@ func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) } - // 校验配置值 - for group, limits := range newConfig { - if len(limits) != 2 { - return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group) - } - if limits[1] <= 0 { // successCount must be greater than 0 - return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group) - } - if limits[0] < 0 { // totalCount can be 0 (no limit) or positive - return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group) - } - if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount - return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group) - } - } - - ModelRequestRateLimitGroupConfig = newConfig - common.SysLog("Model request rate limit group config updated") + ModelRequestRateLimitGroup = newConfig return nil } -// GetGroupRateLimit 安全地获取指定用户组的速率限制值 func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { ModelRequestRateLimitGroupMutex.RLock() defer ModelRequestRateLimitGroupMutex.RUnlock() - if ModelRequestRateLimitGroupConfig == nil { - return 0, 0, false // 配置尚未初始化 + if ModelRequestRateLimitGroup == nil { + return 0, 0, false } - limits, found := ModelRequestRateLimitGroupConfig[group] + limits, found := ModelRequestRateLimitGroup[group] if !found { return 0, 0, false } diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index ad6b53da..7e206672 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -9,59 +9,59 @@ import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimi const RateLimitSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: 0, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: {}, + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', }); - + let [loading, setLoading] = useState(false); - + const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true' ? true : false; - } else { - newInputs[item.key] = item.value; - } - }); - - setInputs(newInputs); - } else { - showError(message); - } + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + // 检查 key 是否在初始 inputs 中定义 + if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true'; + } else { + newInputs[item.key] = item.value; + } + } + }); + setInputs(newInputs); + } else { + showError(message); + } }; async function onRefresh() { - try { - setLoading(true); - await getOptions(); - // showSuccess('刷新成功'); - } catch (error) { - showError('刷新失败'); - } finally { - setLoading(false); - } + try { + setLoading(true); + await getOptions(); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } } - + useEffect(() => { - onRefresh(); + onRefresh(); }, []); - + return ( - <> - - {/* AI请求速率限制 */} - - - - - + <> + + + + + + ); -}; - -export default RateLimitSetting; + }; + + export default RateLimitSetting; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index ec1c2158..2434020e 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -14,209 +14,201 @@ export default function RequestRateLimit(props) { const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: -1, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值 + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); - + function onSubmit() { - const updateArray = compareObjects(inputs, inputsRow); - if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); - const requestQueue = updateArray.map((item) => { - let value = ''; - if (typeof inputs[item.key] === 'boolean') { - value = String(inputs[item.key]); - } else { - value = inputs[item.key]; - } - // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串 - if (item.key === 'ModelRequestRateLimitGroup') { - try { - JSON.parse(value); - } catch (e) { - showError(t('用户组速率限制配置不是有效的 JSON 格式!')); - // 阻止请求发送 - return Promise.reject('Invalid JSON format'); - } - } - return API.put('/api/option/', { - key: item.key, - value, - }); - }); - - // 过滤掉无效的请求(例如,无效的 JSON) - const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); - - if (validRequests.length === 0 && requestQueue.length > 0) { - // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行 - return; - } - - setLoading(true); - Promise.all(validRequests) - .then((res) => { - if (validRequests.length === 1) { - if (res.includes(undefined)) return; - } else if (validRequests.length > 1) { - if (res.includes(undefined)) - return showError(t('部分保存失败,请重试')); - } - showSuccess(t('保存成功')); - props.refresh(); - // 更新 inputsRow 以反映保存后的状态 - setInputsRow(structuredClone(inputs)); - }) - .catch((error) => { - // 检查是否是由于无效 JSON 导致的错误 - if (error !== 'Invalid JSON format') { - showError(t('保存失败,请重试')); - } - }) - .finally(() => { - setLoading(false); - }); + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + if (item.key === 'ModelRequestRateLimitGroup') { + try { + JSON.parse(value); + } catch (e) { + showError(t('用户组速率限制配置不是有效的 JSON 格式!')); + return Promise.reject('Invalid JSON format'); + } + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + + const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); + + if (validRequests.length === 0 && requestQueue.length > 0) { + return; + } + + setLoading(true); + Promise.all(validRequests) + .then((res) => { + if (validRequests.length === 1) { + if (res.includes(undefined)) return; + } else if (validRequests.length > 1) { + if (res.includes(undefined)) + return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + setInputsRow(structuredClone(inputs)); + }) + .catch((error) => { + if (error !== 'Invalid JSON format') { + showError(t('保存失败,请重试')); + } + }) + .finally(() => { + setLoading(false); + }); } - + useEffect(() => { - const currentInputs = {}; - for (let key in props.options) { - if (Object.keys(inputs).includes(key)) { - currentInputs[key] = props.options[key]; - } - } - setInputs(currentInputs); - setInputsRow(structuredClone(currentInputs)); - // 检查 refForm.current 是否存在 - if (refForm.current) { - refForm.current.setValues(currentInputs); - } - }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定 - + const currentInputs = {}; + for (let key in props.options) { + if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查 + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + if (refForm.current) { + refForm.current.setValues(currentInputs); + } + }, [props.options]); + return ( - <> - -
(refForm.current = formAPI)} - style={{ marginBottom: 15 }} - > - - - - { - setInputs({ - ...inputs, - ModelRequestRateLimitEnabled: value, - }); - }} - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitDurationMinutes: String(value), - }) - } - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitCount: String(value), - }) - } - /> - - - - setInputs({ - ...inputs, - ModelRequestRateLimitSuccessCount: String(value), - }) - } - /> - - - {/* 用户组速率限制配置项 */} - - - -

{t('说明:')}

-
    -
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • -
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • -
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • -
  • {t('此配置将优先于上方的全局限制设置。')}
  • -
  • {t('未在此处配置的用户组将使用全局限制。')}
  • -
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • -
  • {t('输入无效的 JSON 将无法保存。')}
  • -
- - } - autosize={{ minRows: 5, maxRows: 15 }} - style={{ fontFamily: 'monospace' }} - onChange={(value) => { - setInputs({ - ...inputs, - ModelRequestRateLimitGroup: value, // 直接更新字符串值 - }); - }} - /> - -
- - - -
-
-
- + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + +

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • +
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • +
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • +
  • {t('此配置将优先于上方的全局限制设置。')}
  • +
  • {t('未在此处配置的用户组将使用全局限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
  • {t('输入无效的 JSON 将无法保存。')}
  • +
+ + } + autosize={{ minRows: 5, maxRows: 15 }} + style={{ fontFamily: 'monospace' }} + onChange={(value) => { + setInputs({ + ...inputs, + ModelRequestRateLimitGroup: value, + }); + }} + /> + +
+ + + +
+
+
+ ); -} + } From b7fd1e4a203fb24d2b5a332ea4ec8abe3cdcecac Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 12:55:48 +0800 Subject: [PATCH 04/57] fix: Redis limit ignoring max eq 0 --- middleware/model-rate-limit.go | 36 ++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 581dc451..f81160fc 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -93,25 +93,27 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g } //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 - totalKey := fmt.Sprintf("rateLimit:%s", userId) - // 初始化 - tb := limiter.New(ctx, rdb) - allowed, err = tb.Allow( - ctx, - totalKey, - limiter.WithCapacity(int64(totalMaxCount)*duration), - limiter.WithRate(int64(totalMaxCount)), - limiter.WithRequested(duration), - ) + if totalMaxCount > 0 { + totalKey := fmt.Sprintf("rateLimit:%s", userId) + // 初始化 + tb := limiter.New(ctx, rdb) + allowed, err = tb.Allow( + ctx, + totalKey, + limiter.WithCapacity(int64(totalMaxCount)*duration), + limiter.WithRate(int64(totalMaxCount)), + limiter.WithRequested(duration), + ) - if err != nil { - fmt.Println("检查总请求数限制失败:", err.Error()) - abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") - return - } + if err != nil { + fmt.Println("检查总请求数限制失败:", err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") + return + } - if !allowed { - abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) + if !allowed { + abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) + } } // 4. 处理请求 From 1e1d24d1b075042473902991cbc3610f6c8bfff8 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 17:57:02 +0800 Subject: [PATCH 05/57] fix: rm debug file --- .old/option.go | 402 ------------------------------------------------- 1 file changed, 402 deletions(-) delete mode 100644 .old/option.go diff --git a/.old/option.go b/.old/option.go deleted file mode 100644 index f80f5cb3..00000000 --- a/.old/option.go +++ /dev/null @@ -1,402 +0,0 @@ -package model - -import ( - "one-api/common" - "one-api/setting" - "one-api/setting/config" - "one-api/setting/operation_setting" - "strconv" - "strings" - "time" -) - -type Option struct { - Key string `json:"key" gorm:"primaryKey"` - Value string `json:"value"` -} - -func AllOption() ([]*Option, error) { - var options []*Option - var err error - err = DB.Find(&options).Error - return options, err -} - -func InitOptionMap() { - common.OptionMapRWMutex.Lock() - common.OptionMap = make(map[string]string) - - // 添加原有的系统配置 - common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) - common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) - common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) - common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) - common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) - common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) - common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) - common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) - common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled) - common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) - common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) - common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) - common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) - common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) - common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) - common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) - common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) - common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) - common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) - common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) - common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) - common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) - common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) - common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled) - common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") - common.OptionMap["SMTPServer"] = "" - common.OptionMap["SMTPFrom"] = "" - common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) - common.OptionMap["SMTPAccount"] = "" - common.OptionMap["SMTPToken"] = "" - common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled) - common.OptionMap["Notice"] = "" - common.OptionMap["About"] = "" - common.OptionMap["HomePageContent"] = "" - common.OptionMap["Footer"] = common.Footer - common.OptionMap["SystemName"] = common.SystemName - common.OptionMap["Logo"] = common.Logo - common.OptionMap["ServerAddress"] = "" - common.OptionMap["WorkerUrl"] = setting.WorkerUrl - common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey - common.OptionMap["PayAddress"] = "" - common.OptionMap["CustomCallbackAddress"] = "" - common.OptionMap["EpayId"] = "" - common.OptionMap["EpayKey"] = "" - common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) - common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) - common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() - common.OptionMap["Chats"] = setting.Chats2JsonString() - common.OptionMap["GitHubClientId"] = "" - common.OptionMap["GitHubClientSecret"] = "" - common.OptionMap["TelegramBotToken"] = "" - common.OptionMap["TelegramBotName"] = "" - common.OptionMap["WeChatServerAddress"] = "" - common.OptionMap["WeChatServerToken"] = "" - common.OptionMap["WeChatAccountQRCodeImageURL"] = "" - common.OptionMap["TurnstileSiteKey"] = "" - common.OptionMap["TurnstileSecretKey"] = "" - common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) - common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) - common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) - common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) - common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) - common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) - common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() - common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() - common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() - common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() - common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() - common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() - common.OptionMap["TopUpLink"] = common.TopUpLink - //common.OptionMap["ChatLink"] = common.ChatLink - //common.OptionMap["ChatLink2"] = common.ChatLink2 - common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) - common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) - common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) - common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime - common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) - common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled) - common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled) - common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled) - common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) - common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) - common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) - common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled) - common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled) - common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) - common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) - common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) - common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() - common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) - common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() - - // 自动添加所有注册的模型配置 - modelConfigs := config.GlobalConfig.ExportAllConfigs() - for k, v := range modelConfigs { - common.OptionMap[k] = v - } - - common.OptionMapRWMutex.Unlock() - loadOptionsFromDatabase() -} - -func loadOptionsFromDatabase() { - options, _ := AllOption() - for _, option := range options { - err := updateOptionMap(option.Key, option.Value) - if err != nil { - common.SysError("failed to update option map: " + err.Error()) - } - } -} - -func SyncOptions(frequency int) { - for { - time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") - loadOptionsFromDatabase() - } -} - -func UpdateOption(key string, value string) error { - // Save to database first - option := Option{ - Key: key, - } - // https://gorm.io/docs/update.html#Save-All-Fields - DB.FirstOrCreate(&option, Option{Key: key}) - option.Value = value - // Save is a combination function. - // If save value does not contain primary key, it will execute Create, - // otherwise it will execute Update (with all fields). - DB.Save(&option) - // Update OptionMap - return updateOptionMap(key, value) -} - -func updateOptionMap(key string, value string) (err error) { - common.OptionMapRWMutex.Lock() - defer common.OptionMapRWMutex.Unlock() - common.OptionMap[key] = value - - // 检查是否是模型配置 - 使用更规范的方式处理 - if handleConfigUpdate(key, value) { - return nil // 已由配置系统处理 - } - - // 处理传统配置项... - if strings.HasSuffix(key, "Permission") { - intValue, _ := strconv.Atoi(value) - switch key { - case "FileUploadPermission": - common.FileUploadPermission = intValue - case "FileDownloadPermission": - common.FileDownloadPermission = intValue - case "ImageUploadPermission": - common.ImageUploadPermission = intValue - case "ImageDownloadPermission": - common.ImageDownloadPermission = intValue - } - } - if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { - boolValue := value == "true" - switch key { - case "PasswordRegisterEnabled": - common.PasswordRegisterEnabled = boolValue - case "PasswordLoginEnabled": - common.PasswordLoginEnabled = boolValue - case "EmailVerificationEnabled": - common.EmailVerificationEnabled = boolValue - case "GitHubOAuthEnabled": - common.GitHubOAuthEnabled = boolValue - case "LinuxDOOAuthEnabled": - common.LinuxDOOAuthEnabled = boolValue - case "WeChatAuthEnabled": - common.WeChatAuthEnabled = boolValue - case "TelegramOAuthEnabled": - common.TelegramOAuthEnabled = boolValue - case "TurnstileCheckEnabled": - common.TurnstileCheckEnabled = boolValue - case "RegisterEnabled": - common.RegisterEnabled = boolValue - case "EmailDomainRestrictionEnabled": - common.EmailDomainRestrictionEnabled = boolValue - case "EmailAliasRestrictionEnabled": - common.EmailAliasRestrictionEnabled = boolValue - case "AutomaticDisableChannelEnabled": - common.AutomaticDisableChannelEnabled = boolValue - case "AutomaticEnableChannelEnabled": - common.AutomaticEnableChannelEnabled = boolValue - case "LogConsumeEnabled": - common.LogConsumeEnabled = boolValue - case "DisplayInCurrencyEnabled": - common.DisplayInCurrencyEnabled = boolValue - case "DisplayTokenStatEnabled": - common.DisplayTokenStatEnabled = boolValue - case "DrawingEnabled": - common.DrawingEnabled = boolValue - case "TaskEnabled": - common.TaskEnabled = boolValue - case "DataExportEnabled": - common.DataExportEnabled = boolValue - case "DefaultCollapseSidebar": - common.DefaultCollapseSidebar = boolValue - case "MjNotifyEnabled": - setting.MjNotifyEnabled = boolValue - case "MjAccountFilterEnabled": - setting.MjAccountFilterEnabled = boolValue - case "MjModeClearEnabled": - setting.MjModeClearEnabled = boolValue - case "MjForwardUrlEnabled": - setting.MjForwardUrlEnabled = boolValue - case "MjActionCheckSuccessEnabled": - setting.MjActionCheckSuccessEnabled = boolValue - case "CheckSensitiveEnabled": - setting.CheckSensitiveEnabled = boolValue - case "DemoSiteEnabled": - operation_setting.DemoSiteEnabled = boolValue - case "SelfUseModeEnabled": - operation_setting.SelfUseModeEnabled = boolValue - case "CheckSensitiveOnPromptEnabled": - setting.CheckSensitiveOnPromptEnabled = boolValue - case "ModelRequestRateLimitEnabled": - setting.ModelRequestRateLimitEnabled = boolValue - case "StopOnSensitiveEnabled": - setting.StopOnSensitiveEnabled = boolValue - case "SMTPSSLEnabled": - common.SMTPSSLEnabled = boolValue - } - } - switch key { - case "EmailDomainWhitelist": - common.EmailDomainWhitelist = strings.Split(value, ",") - case "SMTPServer": - common.SMTPServer = value - case "SMTPPort": - intValue, _ := strconv.Atoi(value) - common.SMTPPort = intValue - case "SMTPAccount": - common.SMTPAccount = value - case "SMTPFrom": - common.SMTPFrom = value - case "SMTPToken": - common.SMTPToken = value - case "ServerAddress": - setting.ServerAddress = value - case "WorkerUrl": - setting.WorkerUrl = value - case "WorkerValidKey": - setting.WorkerValidKey = value - case "PayAddress": - setting.PayAddress = value - case "Chats": - err = setting.UpdateChatsByJsonString(value) - case "CustomCallbackAddress": - setting.CustomCallbackAddress = value - case "EpayId": - setting.EpayId = value - case "EpayKey": - setting.EpayKey = value - case "Price": - setting.Price, _ = strconv.ParseFloat(value, 64) - case "MinTopUp": - setting.MinTopUp, _ = strconv.Atoi(value) - case "TopupGroupRatio": - err = common.UpdateTopupGroupRatioByJSONString(value) - case "GitHubClientId": - common.GitHubClientId = value - case "GitHubClientSecret": - common.GitHubClientSecret = value - case "LinuxDOClientId": - common.LinuxDOClientId = value - case "LinuxDOClientSecret": - common.LinuxDOClientSecret = value - case "Footer": - common.Footer = value - case "SystemName": - common.SystemName = value - case "Logo": - common.Logo = value - case "WeChatServerAddress": - common.WeChatServerAddress = value - case "WeChatServerToken": - common.WeChatServerToken = value - case "WeChatAccountQRCodeImageURL": - common.WeChatAccountQRCodeImageURL = value - case "TelegramBotToken": - common.TelegramBotToken = value - case "TelegramBotName": - common.TelegramBotName = value - case "TurnstileSiteKey": - common.TurnstileSiteKey = value - case "TurnstileSecretKey": - common.TurnstileSecretKey = value - case "QuotaForNewUser": - common.QuotaForNewUser, _ = strconv.Atoi(value) - case "QuotaForInviter": - common.QuotaForInviter, _ = strconv.Atoi(value) - case "QuotaForInvitee": - common.QuotaForInvitee, _ = strconv.Atoi(value) - case "QuotaRemindThreshold": - common.QuotaRemindThreshold, _ = strconv.Atoi(value) - case "PreConsumedQuota": - common.PreConsumedQuota, _ = strconv.Atoi(value) - case "ModelRequestRateLimitCount": - setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) - case "ModelRequestRateLimitDurationMinutes": - setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) - case "ModelRequestRateLimitSuccessCount": - setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) - case "RetryTimes": - common.RetryTimes, _ = strconv.Atoi(value) - case "DataExportInterval": - common.DataExportInterval, _ = strconv.Atoi(value) - case "DataExportDefaultTime": - common.DataExportDefaultTime = value - case "ModelRatio": - err = operation_setting.UpdateModelRatioByJSONString(value) - case "GroupRatio": - err = setting.UpdateGroupRatioByJSONString(value) - case "UserUsableGroups": - err = setting.UpdateUserUsableGroupsByJSONString(value) - case "CompletionRatio": - err = operation_setting.UpdateCompletionRatioByJSONString(value) - case "ModelPrice": - err = operation_setting.UpdateModelPriceByJSONString(value) - case "CacheRatio": - err = operation_setting.UpdateCacheRatioByJSONString(value) - case "TopUpLink": - common.TopUpLink = value - //case "ChatLink": - // common.ChatLink = value - //case "ChatLink2": - // common.ChatLink2 = value - case "ChannelDisableThreshold": - common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) - case "QuotaPerUnit": - common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) - case "SensitiveWords": - setting.SensitiveWordsFromString(value) - case "AutomaticDisableKeywords": - operation_setting.AutomaticDisableKeywordsFromString(value) - case "StreamCacheQueueLength": - setting.StreamCacheQueueLength, _ = strconv.Atoi(value) - } - return err -} - -// handleConfigUpdate 处理分层配置更新,返回是否已处理 -func handleConfigUpdate(key, value string) bool { - parts := strings.SplitN(key, ".", 2) - if len(parts) != 2 { - return false // 不是分层配置 - } - - configName := parts[0] - configKey := parts[1] - - // 获取配置对象 - cfg := config.GlobalConfig.Get(configName) - if cfg == nil { - return false // 未注册的配置 - } - - // 更新配置 - configMap := map[string]string{ - configKey: value, - } - config.UpdateConfigFromMap(cfg, configMap) - - return true // 已处理 -} \ No newline at end of file From 1513ed78477044999e066d5eb3b1fc1762dce531 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 19:32:22 +0800 Subject: [PATCH 06/57] =?UTF-8?q?refactor:=20=E8=B0=83=E6=95=B4=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E7=AC=A6=E5=90=88=E9=A1=B9=E7=9B=AE=E7=8E=B0?= =?UTF-8?q?=E6=9C=89=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 54 +++++++++++++++++--------- model/option.go | 34 +++++----------- setting/rate_limit.go | 37 ++++++++---------- web/src/components/RateLimitSetting.js | 6 ++- 4 files changed, 65 insertions(+), 66 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index b0047b70..1ca5ace6 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/limiter" + "one-api/constant" "one-api/setting" "strconv" "time" @@ -19,20 +20,25 @@ const ( ModelRequestRateLimitSuccessCountMark = "MRRLS" ) +// 检查Redis中的请求限制 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { + // 如果maxCount为0,表示不限制 if maxCount == 0 { return true, nil } + // 获取当前计数 length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } + // 如果未达到限制,允许请求 if length < int64(maxCount) { return true, nil } + // 检查时间窗口 oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { @@ -44,6 +50,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max if err != nil { return false, err } + // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) @@ -53,7 +60,9 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max return true, nil } +// 记录Redis请求 func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { + // 如果maxCount为0,不记录请求 if maxCount == 0 { return } @@ -64,12 +73,14 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } +// Redis限流处理器 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB + // 1. 检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { @@ -82,7 +93,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } + //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) + // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, @@ -102,14 +115,17 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } + // 4. 处理请求 c.Next() + // 5. 如果请求成功,记录成功请求 if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } +// 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) @@ -118,12 +134,15 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId + // 1. 检查总请求数限制(当totalMaxCount为0时跳过) if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } + // 2. 检查成功请求数限制 + // 使用一个临时key来检查限制,这样可以避免实际记录 checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) @@ -131,51 +150,48 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) return } + // 3. 处理请求 c.Next() + // 4. 如果请求成功,记录到实际的成功请求计数中 if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } +// ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { + // 在每个请求时检查是否启用限流 if !setting.ModelRequestRateLimitEnabled { c.Next() return } + // 计算限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) + totalMaxCount := setting.ModelRequestRateLimitCount + successMaxCount := setting.ModelRequestRateLimitSuccessCount + // 获取分组 group := c.GetString("token_group") if group == "" { - group = c.GetString("group") - } - if group == "" { - group = "default" + group = c.GetString(constant.ContextKeyUserGroup) } - finalTotalCount := setting.ModelRequestRateLimitCount - finalSuccessCount := setting.ModelRequestRateLimitSuccessCount - foundGroupLimit := false - + //获取分组的限流配置 groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) if found { - finalTotalCount = groupTotalCount - finalSuccessCount = groupSuccessCount - foundGroupLimit = true - common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) - } - - if !foundGroupLimit { - common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) + totalMaxCount = groupTotalCount + successMaxCount = groupSuccessCount } + // 根据存储类型选择并执行限流处理器 if common.RedisEnabled { - redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } else { - memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } } -} +} \ No newline at end of file diff --git a/model/option.go b/model/option.go index 79556737..e9c129e1 100644 --- a/model/option.go +++ b/model/option.go @@ -1,8 +1,6 @@ package model import ( - "encoding/json" - "fmt" "one-api/common" "one-api/setting" "one-api/setting/config" @@ -94,8 +92,7 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - jsonBytes, _ := json.Marshal(map[string][2]int{}) - common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes) + common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString() common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() @@ -154,31 +151,18 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - originalValue := value - - if key == "ModelRequestRateLimitGroup" { - var cfg map[string][2]int - err := json.Unmarshal([]byte(originalValue), &cfg) - if err != nil { - return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) - } - - formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") - if marshalErr != nil { - return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) - } - value = string(formattedValueBytes) - } - + // Save to database first option := Option{ Key: key, } + // https://gorm.io/docs/update.html#Save-All-Fields DB.FirstOrCreate(&option, Option{Key: key}) option.Value = value - if err := DB.Save(&option).Error; err != nil { - return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) - } - + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap return updateOptionMap(key, value) } @@ -356,7 +340,7 @@ func updateOptionMap(key string, value string) (err error) { case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) case "ModelRequestRateLimitGroup": - err = setting.UpdateModelRequestRateLimitGroup(value) + err = setting.UpdateModelRequestRateLimitGroupByJSONString(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 5be75cc1..aab030cd 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -2,7 +2,6 @@ package setting import ( "encoding/json" - "fmt" "one-api/common" "sync" ) @@ -11,33 +10,31 @@ var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 -var ModelRequestRateLimitGroup map[string][2]int +var ModelRequestRateLimitGroup = map[string][2]int{} +var ModelRequestRateLimitMutex sync.RWMutex -var ModelRequestRateLimitGroupMutex sync.RWMutex +func ModelRequestRateLimitGroup2JSONString() string { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() -func UpdateModelRequestRateLimitGroup(jsonStr string) error { - ModelRequestRateLimitGroupMutex.Lock() - defer ModelRequestRateLimitGroupMutex.Unlock() - - var newConfig map[string][2]int - if jsonStr == "" || jsonStr == "{}" { - ModelRequestRateLimitGroup = make(map[string][2]int) - common.SysLog("Model request rate limit group config cleared") - return nil - } - - err := json.Unmarshal([]byte(jsonStr), &newConfig) + jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) + common.SysError("error marshalling model ratio: " + err.Error()) } + return string(jsonBytes) +} - ModelRequestRateLimitGroup = newConfig - return nil +func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() + + ModelRequestRateLimitGroup = make(map[string][2]int) + return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup) } func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { - ModelRequestRateLimitGroupMutex.RLock() - defer ModelRequestRateLimitGroupMutex.RUnlock() + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() if ModelRequestRateLimitGroup == nil { return 0, 0, false diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 7e206672..309b94de 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -24,7 +24,6 @@ const RateLimitSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - // 检查 key 是否在初始 inputs 中定义 if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { if (item.key.endsWith('Enabled')) { newInputs[item.key] = item.value === 'true'; @@ -33,6 +32,7 @@ const RateLimitSetting = () => { } } }); + setInputs(newInputs); } else { showError(message); @@ -42,6 +42,7 @@ const RateLimitSetting = () => { try { setLoading(true); await getOptions(); + // showSuccess('刷新成功'); } catch (error) { showError('刷新失败'); } finally { @@ -56,6 +57,7 @@ const RateLimitSetting = () => { return ( <> + {/* AI请求速率限制 */} @@ -64,4 +66,4 @@ const RateLimitSetting = () => { ); }; - export default RateLimitSetting; + export default RateLimitSetting; \ No newline at end of file From 88ed83f41927eacc43526b5739592016d2ae4c10 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 20:00:06 +0800 Subject: [PATCH 07/57] feat: Modellimitgroup check --- controller/option.go | 9 +++++++++ setting/rate_limit.go | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/controller/option.go b/controller/option.go index 81ef463c..250f16bb 100644 --- a/controller/option.go +++ b/controller/option.go @@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) { }) return } + case "ModelRequestRateLimitGroup": + err = setting.CheckModelRequestRateLimitGroup(option.Value) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } } err = model.UpdateOption(option.Key, option.Value) diff --git a/setting/rate_limit.go b/setting/rate_limit.go index aab030cd..14680791 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -2,6 +2,7 @@ package setting import ( "encoding/json" + "fmt" "one-api/common" "sync" ) @@ -46,3 +47,18 @@ func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) } return limits[0], limits[1], true } + +func CheckModelRequestRateLimitGroup(jsonStr string) error { + checkModelRequestRateLimitGroup := make(map[string][2]int) + err := json.Unmarshal([]byte(jsonStr), &checkModelRequestRateLimitGroup) + if err != nil { + return err + } + for group, limits := range checkModelRequestRateLimitGroup { + if limits[0] < 0 || limits[1] < 0 { + return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) + } + } + + return nil +} From 1cb4d750e471649da8fa5824942c43bffdc4705e Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 22:06:16 +0800 Subject: [PATCH 08/57] =?UTF-8?q?feat:=20=E5=88=86=E7=BB=84=E9=80=9F?= =?UTF-8?q?=E7=8E=87=E5=89=8D=E7=AB=AF=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/RateLimitSetting.js | 16 ++-- .../RateLimit/SettingsRequestRateLimit.js | 83 ++++++++----------- 2 files changed, 45 insertions(+), 54 deletions(-) diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 309b94de..4671317f 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -13,7 +13,7 @@ const RateLimitSetting = () => { ModelRequestRateLimitCount: 0, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', + ModelRequestRateLimitGroup: '', }); let [loading, setLoading] = useState(false); @@ -24,12 +24,14 @@ const RateLimitSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true'; - } else { - newInputs[item.key] = item.value; - } + if (item.key === 'ModelRequestRateLimitGroup') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; } }); diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 2434020e..b77c1e6a 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -6,6 +6,7 @@ import { showError, showSuccess, showWarning, + verifyJSON, } from '../../../helpers'; import { useTranslation } from 'react-i18next'; @@ -18,7 +19,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', + ModelRequestRateLimitGroup: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -33,43 +34,32 @@ export default function RequestRateLimit(props) { } else { value = inputs[item.key]; } - if (item.key === 'ModelRequestRateLimitGroup') { - try { - JSON.parse(value); - } catch (e) { - showError(t('用户组速率限制配置不是有效的 JSON 格式!')); - return Promise.reject('Invalid JSON format'); - } - } return API.put('/api/option/', { key: item.key, value, }); }); - - const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); - - if (validRequests.length === 0 && requestQueue.length > 0) { - return; - } - setLoading(true); - Promise.all(validRequests) + Promise.all(requestQueue) .then((res) => { - if (validRequests.length === 1) { + if (requestQueue.length === 1) { if (res.includes(undefined)) return; - } else if (validRequests.length > 1) { + } else if (requestQueue.length > 1) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } + + for (let i = 0; i < res.length; i++) { + if (!res[i].data.success) { + return showError(res[i].data.message); + } + } + showSuccess(t('保存成功')); props.refresh(); - setInputsRow(structuredClone(inputs)); }) - .catch((error) => { - if (error !== 'Invalid JSON format') { - showError(t('保存失败,请重试')); - } + .catch(() => { + showError(t('保存失败,请重试')); }) .finally(() => { setLoading(false); @@ -79,15 +69,13 @@ export default function RequestRateLimit(props) { useEffect(() => { const currentInputs = {}; for (let key in props.options) { - if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查 + if (Object.keys(inputs).includes(key)) { currentInputs[key] = props.options[key]; } } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - if (refForm.current) { refForm.current.setValues(currentInputs); - } }, [props.options]); return ( @@ -168,40 +156,41 @@ export default function RequestRateLimit(props) { />
- - + + verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} extraText={
-

{t('说明:')}

+

{t('说明:')}

    -
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • -
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • -
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • -
  • {t('此配置将优先于上方的全局限制设置。')}
  • -
  • {t('未在此处配置的用户组将使用全局限制。')}
  • +
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • +
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • +
  • {t('分组速率配置优先级高于全局速率限制。')}
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • -
  • {t('输入无效的 JSON 将无法保存。')}
} - autosize={{ minRows: 5, maxRows: 15 }} - style={{ fontFamily: 'monospace' }} onChange={(value) => { - setInputs({ - ...inputs, - ModelRequestRateLimitGroup: value, - }); + setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); }} />
- + @@ -211,4 +200,4 @@ export default function RequestRateLimit(props) { ); - } + } \ No newline at end of file From 0be3678c9ca8d687920ba52ff7d17d65afba23ca Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:41:43 +0800 Subject: [PATCH 09/57] =?UTF-8?q?fix:=20=E8=AF=B7=E6=B1=82=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E6=95=B0=E5=BF=85=E9=A1=BB=E5=A4=A7=E4=BA=8E=E7=AD=89?= =?UTF-8?q?=E4=BA=8E1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setting/rate_limit.go | 2 +- web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 14680791..53b53f88 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -55,7 +55,7 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error { return err } for group, limits := range checkModelRequestRateLimitGroup { - if limits[0] < 0 || limits[1] < 0 { + if limits[0] < 0 || limits[1] < 1 { return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) } } diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index b77c1e6a..ae54b1ef 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -179,6 +179,7 @@ export default function RequestRateLimit(props) {
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • +
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1')}
  • {t('分组速率配置优先级高于全局速率限制。')}
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
From bbab729619820b49706af49a48596e8cab105bde Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:48:15 +0800 Subject: [PATCH 10/57] fix: text --- web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index ae54b1ef..7003c279 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -179,7 +179,7 @@ export default function RequestRateLimit(props) {
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • -
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1')}
  • +
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • {t('分组速率配置优先级高于全局速率限制。')}
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
From 87188cd7d458464c7e83e3502eb0a11126e6f94e Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:53:05 +0800 Subject: [PATCH 11/57] =?UTF-8?q?fix:=20=E7=BC=A9=E8=BF=9B=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E8=BF=98=E5=8E=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 2 +- model/option.go | 2 +- web/src/components/RateLimitSetting.js | 92 ++--- .../RateLimit/SettingsRequestRateLimit.js | 344 +++++++++--------- 4 files changed, 220 insertions(+), 220 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 1ca5ace6..03ef0ff3 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -194,4 +194,4 @@ func ModelRequestRateLimit() func(c *gin.Context) { memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } } -} \ No newline at end of file +} diff --git a/model/option.go b/model/option.go index e9c129e1..d98a9d38 100644 --- a/model/option.go +++ b/model/option.go @@ -402,4 +402,4 @@ func handleConfigUpdate(key, value string) bool { config.UpdateConfigFromMap(cfg, configMap) return true // 已处理 -} \ No newline at end of file +} diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 4671317f..a0953db7 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -9,62 +9,62 @@ import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimi const RateLimitSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: 0, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '', + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '', }); - - let [loading, setLoading] = useState(false); - - const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key === 'ModelRequestRateLimitGroup') { - item.value = JSON.stringify(JSON.parse(item.value), null, 2); - } - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true' ? true : false; - } else { - newInputs[item.key] = item.value; - } - }); - - setInputs(newInputs); - } else { - showError(message); - } + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRequestRateLimitGroup') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } }; async function onRefresh() { - try { - setLoading(true); - await getOptions(); - // showSuccess('刷新成功'); - } catch (error) { - showError('刷新失败'); - } finally { - setLoading(false); - } + try { + setLoading(true); + await getOptions(); + // showSuccess('刷新成功'); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } } useEffect(() => { - onRefresh(); + onRefresh(); }, []); return ( - <> - - {/* AI请求速率限制 */} - - - - - + <> + + {/* AI请求速率限制 */} + + + + + ); }; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 7003c279..7c60bc47 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -15,190 +15,190 @@ export default function RequestRateLimit(props) { const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: -1, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '', + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); function onSubmit() { - const updateArray = compareObjects(inputs, inputsRow); - if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); - const requestQueue = updateArray.map((item) => { - let value = ''; - if (typeof inputs[item.key] === 'boolean') { - value = String(inputs[item.key]); - } else { - value = inputs[item.key]; - } - return API.put('/api/option/', { - key: item.key, - value, - }); - }); - setLoading(true); - Promise.all(requestQueue) - .then((res) => { - if (requestQueue.length === 1) { - if (res.includes(undefined)) return; - } else if (requestQueue.length > 1) { - if (res.includes(undefined)) - return showError(t('部分保存失败,请重试')); - } + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if (res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) + return showError(t('部分保存失败,请重试')); + } - for (let i = 0; i < res.length; i++) { - if (!res[i].data.success) { - return showError(res[i].data.message); - } - } + for (let i = 0; i < res.length; i++) { + if (!res[i].data.success) { + return showError(res[i].data.message); + } + } - showSuccess(t('保存成功')); - props.refresh(); - }) - .catch(() => { - showError(t('保存失败,请重试')); - }) - .finally(() => { - setLoading(false); - }); + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); } useEffect(() => { - const currentInputs = {}; - for (let key in props.options) { - if (Object.keys(inputs).includes(key)) { - currentInputs[key] = props.options[key]; - } - } - setInputs(currentInputs); - setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); + const currentInputs = {}; + for (let key in props.options) { + if (Object.keys(inputs).includes(key)) { + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + refForm.current.setValues(currentInputs); }, [props.options]); return ( - <> - -
(refForm.current = formAPI)} - style={{ marginBottom: 15 }} - > - - - - { - setInputs({ - ...inputs, - ModelRequestRateLimitEnabled: value, - }); - }} - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitDurationMinutes: String(value), - }) - } - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitCount: String(value), - }) - } - /> - - - - setInputs({ - ...inputs, - ModelRequestRateLimitSuccessCount: String(value), - }) - } - /> - - - - - verifyJSON(value), - message: t('不是合法的 JSON 字符串'), - }, - ]} - extraText={ -
-

{t('说明:')}

-
    -
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • -
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • -
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • -
  • {t('分组速率配置优先级高于全局速率限制。')}
  • -
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • -
-
- } - onChange={(value) => { - setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); - }} - /> - -
- - - -
-
-
- + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} + extraText={ +
+

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • +
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • +
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • +
  • {t('分组速率配置优先级高于全局速率限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
+
+ } + onChange={(value) => { + setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); + }} + /> + +
+ + + +
+
+
+ ); } \ No newline at end of file From 3d243c3ee2bc2a92d21d31f0155378ac5c188c39 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:56:15 +0800 Subject: [PATCH 12/57] =?UTF-8?q?fix:=20=E6=A0=B7=E5=BC=8F=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/RateLimitSetting.js | 16 ++++++++-------- .../RateLimit/SettingsRequestRateLimit.js | 10 +++++----- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index a0953db7..5f0200e1 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -34,7 +34,7 @@ const RateLimitSetting = () => { newInputs[item.key] = item.value; } }); - + setInputs(newInputs); } else { showError(message); @@ -44,28 +44,28 @@ const RateLimitSetting = () => { try { setLoading(true); await getOptions(); - // showSuccess('刷新成功'); + // showSuccess('刷新成功'); } catch (error) { showError('刷新失败'); } finally { setLoading(false); } } - + useEffect(() => { onRefresh(); }, []); - + return ( <> - {/* AI请求速率限制 */} + {/* AI请求速率限制 */} ); - }; - - export default RateLimitSetting; \ No newline at end of file +}; + +export default RateLimitSetting; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 7c60bc47..73626351 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -23,7 +23,7 @@ export default function RequestRateLimit(props) { }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); - + function onSubmit() { const updateArray = compareObjects(inputs, inputsRow); if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); @@ -65,7 +65,7 @@ export default function RequestRateLimit(props) { setLoading(false); }); } - + useEffect(() => { const currentInputs = {}; for (let key in props.options) { @@ -75,9 +75,9 @@ export default function RequestRateLimit(props) { } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); + refForm.current.setValues(currentInputs); }, [props.options]); - + return ( <> @@ -201,4 +201,4 @@ export default function RequestRateLimit(props) { ); - } \ No newline at end of file +} From 0cf4c59d227a90a8dd4b66927b7b563dc3cea72d Mon Sep 17 00:00:00 2001 From: skynono <6811626@qq.com> Date: Tue, 6 May 2025 14:18:15 +0800 Subject: [PATCH 13/57] feat: add original password verification when changing password --- controller/user.go | 26 +++++++++++++++++++++++++- model/user.go | 1 + web/src/components/PersonalSetting.js | 20 ++++++++++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/controller/user.go b/controller/user.go index e194f531..567c2aa7 100644 --- a/controller/user.go +++ b/controller/user.go @@ -592,7 +592,14 @@ func UpdateSelf(c *gin.Context) { user.Password = "" // rollback to what it should be cleanUser.Password = "" } - updatePassword := user.Password != "" + updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } if err := cleanUser.Update(updatePassword); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -608,6 +615,23 @@ func UpdateSelf(c *gin.Context) { return } +func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) { + if newPassword == "" { + return + } + var currentUser *model.User + currentUser, err = model.GetUserById(userId, true) + if err != nil { + return + } + if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) { + err = fmt.Errorf("原密码错误") + return + } + updatePassword = true + return +} + func DeleteUser(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { diff --git a/model/user.go b/model/user.go index 0aea2ff5..1a3372aa 100644 --- a/model/user.go +++ b/model/user.go @@ -18,6 +18,7 @@ type User struct { Id int `json:"id"` Username string `json:"username" gorm:"unique;index" validate:"max=12"` Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` + OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database! DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` Role int `json:"role" gorm:"type:int;default:1"` // admin, common Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index d1e03db2..fbd74536 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -57,6 +57,7 @@ const PersonalSetting = () => { email_verification_code: '', email: '', self_account_deletion_confirmation: '', + original_password: '', set_new_password: '', set_new_password_confirmation: '', }); @@ -239,11 +240,20 @@ const PersonalSetting = () => { }; const changePassword = async () => { + if (inputs.original_password === '') { + showError(t('请输入原密码!')); + return; + } + if (inputs.original_password === inputs.set_new_password) { + showError(t('新密码需要和原密码不一致!')); + return; + } if (inputs.set_new_password !== inputs.set_new_password_confirmation) { showError(t('两次输入的密码不一致!')); return; } const res = await API.put(`/api/user/self`, { + original_password: inputs.original_password, password: inputs.set_new_password, }); const { success, message } = res.data; @@ -1118,6 +1128,16 @@ const PersonalSetting = () => { >
+ handleInputChange('original_password', value) + } + /> + Date: Tue, 6 May 2025 18:41:01 +0800 Subject: [PATCH 14/57] feat: add support for DeepSeek channel in streamSupportedChannels --- relay/common/relay_info.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 915474e1..0135283d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -103,6 +103,7 @@ var streamSupportedChannels = map[int]bool{ common.ChannelTypeVolcEngine: true, common.ChannelTypeOllama: true, common.ChannelTypeXai: true, + common.ChannelTypeDeepSeek: true, } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { From 459c277c941ac61b81189f22b06637eff71485bf Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 6 May 2025 21:58:01 +0800 Subject: [PATCH 15/57] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20built=20in?= =?UTF-8?q?=20tools=20=E8=AE=A1=E8=B4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 增加非流的工具调用次数统计 - 添加 web search 和 file search 计费 --- dto/openai_response.go | 44 +++++++------- relay/channel/openai/relay_responses.go | 7 ++- relay/relay-text.go | 77 +++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 23 deletions(-) diff --git a/dto/openai_response.go b/dto/openai_response.go index c8f61b9d..790d4df8 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -195,28 +195,28 @@ type OutputTokenDetails struct { } type OpenAIResponsesResponse struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int `json:"created_at"` - Status string `json:"status"` - Error *OpenAIError `json:"error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` - Instructions string `json:"instructions"` - MaxOutputTokens int `json:"max_output_tokens"` - Model string `json:"model"` - Output []ResponsesOutput `json:"output"` - ParallelToolCalls bool `json:"parallel_tool_calls"` - PreviousResponseID string `json:"previous_response_id"` - Reasoning *Reasoning `json:"reasoning"` - Store bool `json:"store"` - Temperature float64 `json:"temperature"` - ToolChoice string `json:"tool_choice"` - Tools []interface{} `json:"tools"` - TopP float64 `json:"top_p"` - Truncation string `json:"truncation"` - Usage *Usage `json:"usage"` - User json.RawMessage `json:"user"` - Metadata json.RawMessage `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + Status string `json:"status"` + Error *OpenAIError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + Instructions string `json:"instructions"` + MaxOutputTokens int `json:"max_output_tokens"` + Model string `json:"model"` + Output []ResponsesOutput `json:"output"` + ParallelToolCalls bool `json:"parallel_tool_calls"` + PreviousResponseID string `json:"previous_response_id"` + Reasoning *Reasoning `json:"reasoning"` + Store bool `json:"store"` + Temperature float64 `json:"temperature"` + ToolChoice string `json:"tool_choice"` + Tools []ResponsesToolsCall `json:"tools"` + TopP float64 `json:"top_p"` + Truncation string `json:"truncation"` + Usage *Usage `json:"usage"` + User json.RawMessage `json:"user"` + Metadata json.RawMessage `json:"metadata"` } type IncompleteDetails struct { diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 6af8c676..1d1e060e 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -12,6 +11,8 @@ import ( "one-api/relay/helper" "one-api/service" "strings" + + "github.com/gin-gonic/gin" ) func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -61,6 +62,10 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon. usage.PromptTokens = responsesResponse.Usage.InputTokens usage.CompletionTokens = responsesResponse.Usage.OutputTokens usage.TotalTokens = responsesResponse.Usage.TotalTokens + // 解析 Tools 用量 + for _, tool := range responsesResponse.Tools { + info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++ + } return nil, &usage } diff --git a/relay/relay-text.go b/relay/relay-text.go index 4fdd435d..a528ec52 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -358,6 +358,67 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, ratio := dModelRatio.Mul(dGroupRatio) + // openai web search 工具计费 + var dWebSearchQuota decimal.Decimal + if relayInfo.ResponsesUsageInfo != nil { + if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 { + // 确定模型类型 + // https://platform.openai.com/docs/pricing Web search 价格按模型类型和 search context size 收费 + // gpt-4.1, gpt-4o, or gpt-4o-search-preview 更贵,gpt-4.1-mini, gpt-4o-mini, gpt-4o-mini-search-preview 更便宜 + isHighTierModel := (strings.HasPrefix(modelName, "gpt-4.1") || strings.HasPrefix(modelName, "gpt-4o")) && + !strings.Contains(modelName, "mini") + + // 确定 search context size 对应的价格 + var priceWebSearchPerThousandCalls float64 + switch webSearchTool.SearchContextSize { + case "low": + if isHighTierModel { + priceWebSearchPerThousandCalls = 30.0 + } else { + priceWebSearchPerThousandCalls = 25.0 + } + case "medium": + if isHighTierModel { + priceWebSearchPerThousandCalls = 35.0 + } else { + priceWebSearchPerThousandCalls = 27.5 + } + case "high": + if isHighTierModel { + priceWebSearchPerThousandCalls = 50.0 + } else { + priceWebSearchPerThousandCalls = 30.0 + } + default: + // search context size 默认为 medium + if isHighTierModel { + priceWebSearchPerThousandCalls = 35.0 + } else { + priceWebSearchPerThousandCalls = 27.5 + } + } + // 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000) + dWebSearchQuota = decimal.NewFromFloat(priceWebSearchPerThousandCalls). + Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). + Div(decimal.NewFromInt(1000)) + extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 $%s", + webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()) + } + } + // file search tool 计费 + var dFileSearchQuota decimal.Decimal + if relayInfo.ResponsesUsageInfo != nil { + if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 { + // file search tool 调用价格 $2.50/1k calls + // 计算 file search tool 调用的配额 (配额 = 价格 * 调用次数 / 1000) + dFileSearchQuota = decimal.NewFromFloat(2.5). + Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). + Div(decimal.NewFromInt(1000)) + extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s", + fileSearchTool.CallCount, dFileSearchQuota.String()) + } + } + var quotaCalculateDecimal decimal.Decimal if !priceData.UsePrice { nonCachedTokens := dPromptTokens.Sub(dCacheTokens) @@ -380,6 +441,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } else { quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) } + // 添加 responses tools call 调用的配额 + quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) + quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens @@ -430,6 +494,19 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["image_ratio"] = imageRatio other["image_output"] = imageTokens } + if !dWebSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil { + if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists { + other["web_search"] = true + other["web_search_call_count"] = webSearchTool.CallCount + other["web_search_context_size"] = webSearchTool.SearchContextSize + } + } + if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil { + if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists { + other["file_search"] = true + other["file_search_call_count"] = fileSearchTool.CallCount + } + } model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } From d859e3fa645672ca3e38e97654a2de30c6bbd577 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 6 May 2025 22:28:32 +0800 Subject: [PATCH 16/57] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=9C=AA?= =?UTF-8?q?=E8=BE=93=E5=85=A5=E6=96=B0=E5=AF=86=E7=A0=81=E6=97=B6=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E4=BF=AE=E6=94=B9=E6=88=90=E5=8A=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/user.go | 6 ++--- web/src/components/PersonalSetting.js | 35 ++++++++++++++++++++------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/controller/user.go b/controller/user.go index 567c2aa7..fd53e743 100644 --- a/controller/user.go +++ b/controller/user.go @@ -616,9 +616,6 @@ func UpdateSelf(c *gin.Context) { } func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) { - if newPassword == "" { - return - } var currentUser *model.User currentUser, err = model.GetUserById(userId, true) if err != nil { @@ -628,6 +625,9 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int err = fmt.Errorf("原密码错误") return } + if newPassword == "" { + return + } updatePassword = true return } diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index fbd74536..0f52c319 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -244,6 +244,10 @@ const PersonalSetting = () => { showError(t('请输入原密码!')); return; } + if (inputs.set_new_password === '') { + showError(t('请输入新密码!')); + return; + } if (inputs.original_password === inputs.set_new_password) { showError(t('新密码需要和原密码不一致!')); return; @@ -826,8 +830,8 @@ const PersonalSetting = () => {
- - + +
{t('通知方式')}
@@ -1003,23 +1007,36 @@ const PersonalSetting = () => {
- +
- {t('接受未设置价格模型')} + + {t('接受未设置价格模型')} +
handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)} + checked={ + notificationSettings.acceptUnsetModelRatioModel + } + onChange={(e) => + handleNotificationSettingChange( + 'acceptUnsetModelRatioModel', + e.target.checked, + ) + } > {t('接受未设置价格模型')} - - {t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')} + + {t( + '当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用', + )}
-
@@ -799,7 +812,13 @@ const SystemSetting = () => { onChange={(value) => setEmailToAdd(value)} style={{ marginTop: 16 }} suffix={ - + } onEnterPress={handleAddEmail} /> From ec615342569d0a58c93b2ca852b3bb9f51db2aab Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 9 May 2025 13:57:00 +0800 Subject: [PATCH 41/57] feat: send SSE ping before get response --- relay/channel/api_request.go | 66 ++++++++++++++++++- relay/helper/common.go | 18 ++++-- relay/helper/stream_scanner.go | 115 +++++++++++++-------------------- relay/relay-text.go | 10 +-- 4 files changed, 122 insertions(+), 87 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 8b2ca889..db5d4f44 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -1,16 +1,23 @@ package channel import ( + "context" "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" "net/http" common2 "one-api/common" "one-api/relay/common" "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" + "one-api/setting/operation_setting" + "sync" + "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { @@ -105,7 +112,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http } else { client = service.GetHttpClient() } + // 流式请求 ping 保活 + var stopPinger func() + generalSettings := operation_setting.GetGeneralSetting() + pingEnabled := generalSettings.PingIntervalEnabled + var pingerWg sync.WaitGroup + if info.IsStream { + helper.SetEventStreamHeaders(c) + pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second + var pingerCtx context.Context + pingerCtx, stopPinger = context.WithCancel(c.Request.Context()) + + if pingEnabled { + pingerWg.Add(1) + gopool.Go(func() { + defer pingerWg.Done() + if pingInterval <= 0 { + pingInterval = helper.DefaultPingInterval + } + + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() + var pingMutex sync.Mutex + if common2.DebugEnabled { + println("SSE ping goroutine started") + } + + for { + select { + case <-ticker.C: + pingMutex.Lock() + err2 := helper.PingData(c) + pingMutex.Unlock() + if err2 != nil { + common2.LogError(c, "SSE ping error: "+err.Error()) + return + } + if common2.DebugEnabled { + println("SSE ping data sent.") + } + case <-pingerCtx.Done(): + if common2.DebugEnabled { + println("SSE ping goroutine stopped.") + } + return + } + } + }) + } + } + resp, err := client.Do(req) + // request结束后停止ping + if info.IsStream && pingEnabled { + stopPinger() + pingerWg.Wait() + } if err != nil { return nil, err } diff --git a/relay/helper/common.go b/relay/helper/common.go index 0a3aba1e..35d983f7 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -12,11 +12,19 @@ import ( ) func SetEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + // 检查是否已经设置过头部 + if _, exists := c.Get("event_stream_headers_set"); exists { + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + + // 设置标志,表示头部已经设置过 + c.Set("event_stream_headers_set", true) } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index ce4d3a6d..c1bc0d6e 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -3,7 +3,6 @@ package helper import ( "bufio" "context" - "github.com/bytedance/gopkg/util/gopool" "io" "net/http" "one-api/common" @@ -14,6 +13,8 @@ import ( "sync" "time" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) @@ -23,76 +24,6 @@ const ( DefaultPingInterval = 10 * time.Second ) -type DoRequestFunc func(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) - -// Optional SSE Ping keep-alive mechanism -// -// Used to solve the problem of the connection with the client timing out due to no data being sent when the upstream -// channel response time is long (e.g., thinking model). -// When enabled, it will send ping data packets to the client via SSE at the specified interval to maintain the connection. -func DoStreamRequestWithPinger(doRequest DoRequestFunc, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - SetEventStreamHeaders(c) - - generalSettings := operation_setting.GetGeneralSetting() - pingEnabled := generalSettings.PingIntervalEnabled - pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second - - pingerCtx, stopPinger := context.WithCancel(c.Request.Context()) - var pingerWg sync.WaitGroup - var doRequestErr error - var resp any - - if pingEnabled { - pingerWg.Add(1) - - gopool.Go(func() { - defer pingerWg.Done() - - if pingInterval <= 0 { - pingInterval = DefaultPingInterval - } - - ticker := time.NewTicker(pingInterval) - defer ticker.Stop() - var pingMutex sync.Mutex - - if common.DebugEnabled { - println("SSE ping goroutine started.") - } - - for { - select { - case <-ticker.C: - pingMutex.Lock() - err := PingData(c) - pingMutex.Unlock() - if err != nil { - common.LogError(c, "SSE ping error: "+err.Error()) - return - } - if common.DebugEnabled { - println("SSE ping data sent.") - } - case <-pingerCtx.Done(): - if common.DebugEnabled { - println("SSE ping goroutine stopped.") - } - return - } - } - }) - } - - resp, doRequestErr = doRequest(c, info, requestBody) - - stopPinger() - if pingEnabled { - pingerWg.Wait() - } - - return resp, doRequestErr -} - func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) { if resp == nil || dataHandler == nil { @@ -111,11 +42,26 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon stopChan = make(chan bool, 2) scanner = bufio.NewScanner(resp.Body) ticker = time.NewTicker(streamingTimeout) + pingTicker *time.Ticker writeMutex sync.Mutex // Mutex to protect concurrent writes ) + generalSettings := operation_setting.GetGeneralSetting() + pingEnabled := generalSettings.PingIntervalEnabled + pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second + if pingInterval <= 0 { + pingInterval = DefaultPingInterval + } + + if pingEnabled { + pingTicker = time.NewTicker(pingInterval) + } + defer func() { ticker.Stop() + if pingTicker != nil { + pingTicker.Stop() + } close(stopChan) }() scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize) @@ -127,6 +73,33 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon ctx = context.WithValue(ctx, "stop_chan", stopChan) + // Handle ping data sending + if pingEnabled && pingTicker != nil { + gopool.Go(func() { + for { + select { + case <-pingTicker.C: + writeMutex.Lock() // Lock before writing + err := PingData(c) + writeMutex.Unlock() // Unlock after writing + if err != nil { + common.LogError(c, "ping data error: "+err.Error()) + common.SafeSendBool(stopChan, true) + return + } + if common.DebugEnabled { + println("ping data sent") + } + case <-ctx.Done(): + if common.DebugEnabled { + println("ping data goroutine stopped") + } + return + } + } + }) + } + common.RelayCtxGo(ctx, func() { for scanner.Scan() { ticker.Reset(streamingTimeout) diff --git a/relay/relay-text.go b/relay/relay-text.go index 69a48637..8d5cd384 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -193,15 +193,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } var httpResp *http.Response - var resp any - - if relayInfo.IsStream { - // Streaming requests can use SSE ping to keep alive and avoid connection timeout - // The judgment of whether ping is enabled will be made within the function - resp, err = helper.DoStreamRequestWithPinger(adaptor.DoRequest, c, relayInfo, requestBody) - } else { - resp, err = adaptor.DoRequest(c, relayInfo, requestBody) - } + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) From 40efa73a42e5d7cc943ca46a9f087ea11030f101 Mon Sep 17 00:00:00 2001 From: skynono <6811626@qq.com> Date: Fri, 9 May 2025 17:11:25 +0800 Subject: [PATCH 42/57] fix: correct formatting string in PriceData.ToSetting to handle ImageRatio as float instead of integer --- relay/helper/price.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/helper/price.go b/relay/helper/price.go index 899c72b9..89efa1da 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -23,7 +23,7 @@ type PriceData struct { } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { From 9ebfcaf6aa3ec55078121eba172a57530751dd81 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 9 May 2025 18:11:37 +0800 Subject: [PATCH 43/57] feat: change azure default api version to 2025-04-01-preview --- README.en.md | 2 +- README.md | 2 +- constant/env.go | 2 +- web/src/i18n/locales/en.json | 4 +- web/src/pages/Channel/EditChannel.js | 67 +++++++++++++++++----------- 5 files changed, 45 insertions(+), 32 deletions(-) diff --git a/README.en.md b/README.en.md index 23fdbe1f..4709bc5b 100644 --- a/README.en.md +++ b/README.en.md @@ -107,7 +107,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env - `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16` - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20` - `CRYPTO_SECRET`: Encryption key used for encrypting database content -- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2024-12-01-preview` +- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview` - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2` diff --git a/README.md b/README.md index 67af9916..a807b07d 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16` - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20` - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容 -- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2024-12-01-preview` +- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview` - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2` diff --git a/constant/env.go b/constant/env.go index fae48625..612f3e8b 100644 --- a/constant/env.go +++ b/constant/env.go @@ -31,7 +31,7 @@ func InitEnv() { GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) - AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview") + AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview") GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index eedf1196..916329e7 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1086,7 +1086,7 @@ "没有账户?": "No account? ", "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "Please enter AZURE_OPENAI_ENDPOINT, e.g.: https://docs-test-001.openai.azure.com", "默认 API 版本": "Default API Version", - "请输入默认 API 版本,例如:2024-12-01-preview": "Please enter default API version, e.g.: 2024-12-01-preview.", + "请输入默认 API 版本,例如:2025-04-01-preview": "Please enter default API version, e.g.: 2025-04-01-preview.", "请为渠道命名": "Please name the channel", "请选择可以使用该渠道的分组": "Please select groups that can use this channel", "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit Group ratios in system settings to add new groups:", @@ -1374,4 +1374,4 @@ "适用于展示系统功能的场景。": "Suitable for scenarios where the system functions are displayed.", "可在初始化后修改": "Can be modified after initialization", "初始化系统": "Initialize system" -} +} \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index cba787fc..fd96ffb6 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -24,7 +24,8 @@ import { TextArea, Checkbox, Banner, - Modal, ImagePreview + Modal, + ImagePreview, } from '@douyinfe/semi-ui'; import { getChannelModels, loadChannelModels } from '../../components/utils.js'; import { IconHelpCircle } from '@douyinfe/semi-icons'; @@ -306,7 +307,7 @@ const EditChannel = (props) => { fetchModels().then(); fetchGroups().then(); if (isEdit) { - loadChannel().then(() => { }); + loadChannel().then(() => {}); } else { setInputs(originInputs); let localModels = getChannelModels(inputs.type); @@ -477,7 +478,9 @@ const EditChannel = (props) => { type={'warning'} description={ <> - {t('2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的"."')} + {t( + '2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的"."', + )} {/*
*/} {/* { { handleInputChange('other', value); }} @@ -584,25 +587,35 @@ const EditChannel = (props) => { value={inputs.name} autoComplete='new-password' /> - {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && ( - <> -
- {t('API地址')}: -
- - { - handleInputChange('base_url', value); - }} - value={inputs.base_url} - autoComplete="new-password" - /> - - - )} + {inputs.type !== 3 && + inputs.type !== 8 && + inputs.type !== 22 && + inputs.type !== 36 && + inputs.type !== 45 && ( + <> +
+ {t('API地址')}: +
+ + { + handleInputChange('base_url', value); + }} + value={inputs.base_url} + autoComplete='new-password' + /> + + + )}
{t('密钥')}:
@@ -761,10 +774,10 @@ const EditChannel = (props) => { name='other' placeholder={t( '请输入部署地区,例如:us-central1\n支持使用模型映射格式\n' + - '{\n' + - ' "default": "us-central1",\n' + - ' "claude-3-5-sonnet-20240620": "europe-west1"\n' + - '}', + '{\n' + + ' "default": "us-central1",\n' + + ' "claude-3-5-sonnet-20240620": "europe-west1"\n' + + '}', )} autosize={{ minRows: 2 }} onChange={(value) => { From 0d929800cf40f483679684a48c430a163775cf48 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 9 May 2025 18:13:19 +0800 Subject: [PATCH 44/57] fix: GetRequestURL remove unnecessary case --- relay/channel/openai/adaptor.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index da92692b..f0cf073f 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -67,9 +67,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayFormat == relaycommon.RelayFormatClaude { return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } - if info.RelayMode == constant.RelayModeResponses { - return fmt.Sprintf("%s/v1/responses", info.BaseUrl), nil - } if info.RelayMode == constant.RelayModeRealtime { if strings.HasPrefix(info.BaseUrl, "https://") { baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") From 7b176015b82a3ec0276b24805f2498a49da48aa9 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 9 May 2025 18:57:06 +0800 Subject: [PATCH 45/57] feat: enhance OpenAI handler to support forced response formatting and add debug logging for request URLs --- relay/channel/api_request.go | 3 +++ relay/channel/openai/relay-openai.go | 39 ++++++++++++++++++---------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index db5d4f44..03eff9cf 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -62,6 +62,9 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) } + if common2.DebugEnabled { + println("fullRequestURL:", fullRequestURL) + } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, fmt.Errorf("new request failed: %w", err) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b9ed94e2..86c47a15 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -215,10 +215,35 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI StatusCode: resp.StatusCode, }, nil } + + forceFormat := false + if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { + forceFormat = forceFmt + } + + if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { + completionTokens := 0 + for _, choice := range simpleResponse.Choices { + ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + completionTokens += ctkm + } + simpleResponse.Usage = dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: completionTokens, + TotalTokens: info.PromptTokens + completionTokens, + } + } switch info.RelayFormat { case relaycommon.RelayFormatOpenAI: - break + if forceFormat { + responseBody, err = json.Marshal(simpleResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + } else { + break + } case relaycommon.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeRespStr, err := json.Marshal(claudeResp) @@ -244,18 +269,6 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI common.SysError("error copying response body: " + err.Error()) } resp.Body.Close() - if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { - completionTokens := 0 - for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) - completionTokens += ctkm - } - simpleResponse.Usage = dto.Usage{ - PromptTokens: info.PromptTokens, - CompletionTokens: completionTokens, - TotalTokens: info.PromptTokens + completionTokens, - } - } return nil, &simpleResponse.Usage } From 28cdfc0a14e95602b8263a6eab7e6a0a90088fe3 Mon Sep 17 00:00:00 2001 From: a37836323 <37836323@qq.com> Date: Sat, 10 May 2025 04:33:49 +0800 Subject: [PATCH 46/57] =?UTF-8?q?=E6=B7=BB=E5=8A=A0DALL-E=E5=9B=BE?= =?UTF-8?q?=E5=83=8F=E7=94=9F=E6=88=90=E8=AF=B7=E6=B1=82=E4=B8=AD=E7=9A=84?= =?UTF-8?q?Background=E5=92=8CModeration=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/dalle.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dto/dalle.go b/dto/dalle.go index 562d5f1a..44104d33 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -12,6 +12,8 @@ type ImageRequest struct { Style string `json:"style,omitempty"` User string `json:"user,omitempty"` ExtraFields json.RawMessage `json:"extra_fields,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` } type ImageResponse struct { From 58dc7ad770dcd6f5595aeac1c91194761fccfbc2 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 10 May 2025 15:52:41 +0800 Subject: [PATCH 47/57] feat: add moderation and background fields to ImageRequest struct in dalle.go #1052 --- dto/dalle.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/dto/dalle.go b/dto/dalle.go index 562d5f1a..ab2c94e1 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -1,17 +1,16 @@ package dto -import "encoding/json" - type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` - ExtraFields json.RawMessage `json:"extra_fields,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` + Moderation string `json:"moderation,omitempty"` + Background string `json:"background,omitempty"` } type ImageResponse struct { From d985563516a10806284254824fe7cb4ca9676ec4 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Sun, 11 May 2025 17:00:33 +0800 Subject: [PATCH 48/57] feat: add support for socks5h --- service/http_client.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/service/http_client.go b/service/http_client.go index c3f8df7a..64a361cf 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -3,12 +3,13 @@ package service import ( "context" "fmt" - "golang.org/x/net/proxy" "net" "net/http" "net/url" "one-api/common" "time" + + "golang.org/x/net/proxy" ) var httpClient *http.Client @@ -55,7 +56,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { }, }, nil - case "socks5": + case "socks5", "socks5h": // 获取认证信息 var auth *proxy.Auth if parsedURL.User != nil { @@ -69,6 +70,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { } // 创建 SOCKS5 代理拨号器 + // proxy.SOCKS5 使用 tcp 参数,所有 TCP 连接包括 DNS 查询都将通过代理进行。行为与 socks5h 相同 dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct) if err != nil { return nil, err From b2cad229520ab533f1981daefe9a478502ddb31f Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 13 May 2025 12:52:22 +0800 Subject: [PATCH 49/57] add coze request --- common/constants.go | 2 + relay/channel/coze/adaptor.go | 125 +++++++++++++++++++++++++++++++ relay/channel/coze/constants.go | 8 ++ relay/channel/coze/dto.go | 81 ++++++++++++++++++++ relay/channel/coze/relay-coze.go | 121 ++++++++++++++++++++++++++++++ relay/constant/api_type.go | 3 + relay/relay_adaptor.go | 3 + 7 files changed, 343 insertions(+) create mode 100644 relay/channel/coze/adaptor.go create mode 100644 relay/channel/coze/constants.go create mode 100644 relay/channel/coze/dto.go create mode 100644 relay/channel/coze/relay-coze.go diff --git a/common/constants.go b/common/constants.go index dd4f3b04..bee00506 100644 --- a/common/constants.go +++ b/common/constants.go @@ -240,6 +240,7 @@ const ( ChannelTypeBaiduV2 = 46 ChannelTypeXinference = 47 ChannelTypeXai = 48 + ChannelTypeCoze = 49 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{ "https://qianfan.baidubce.com", //46 "", //47 "https://api.x.ai", //48 + "https://api.coze.cn", //49 } diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go new file mode 100644 index 00000000..b14239a6 --- /dev/null +++ b/relay/channel/coze/adaptor.go @@ -0,0 +1,125 @@ +package coze + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/common" + "time" + + "github.com/gin-gonic/gin" +) + +type Adaptor struct { +} + +// ConvertAudioRequest implements channel.Adaptor. +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + return nil, errors.New("not implemented") +} + +// ConvertClaudeRequest implements channel.Adaptor. +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertEmbeddingRequest implements channel.Adaptor. +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements channel.Adaptor. +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertOpenAIRequest implements channel.Adaptor. +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return convertCozeChatRequest(*request), nil +} + +// ConvertOpenAIResponsesRequest implements channel.Adaptor. +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertRerankRequest implements channel.Adaptor. +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// DoRequest implements channel.Adaptor. +func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) { + // 首先发送创建消息请求,成功后再发送获取消息请求 + // 发送创建消息请求 + resp, err := channel.DoApiRequest(a, c, info, requestBody) + if err != nil { + return nil, err + } + // 解析 resp + var cozeResponse CozeChatResponse + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + err = json.Unmarshal(respBody, &cozeResponse) + if cozeResponse.Code != 0 { + return nil, errors.New(cozeResponse.Msg) + } + c.Set("coze_conversation_id", cozeResponse.Data.ConversationId) + c.Set("coze_chat_id", cozeResponse.Data.Id) + // 轮询检查消息是否完成 + for { + err, isComplete := checkIfChatComplete(a, c, info) + if err != nil { + return nil, err + } else { + if isComplete { + break + } + } + time.Sleep(time.Second * 1) + } + // 发送获取消息请求 + return channel.DoApiRequest(a, c, info, requestBody) +} + +// DoResponse implements channel.Adaptor. +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + err, usage = cozeChatHandler(c, resp, info) + return +} + +// GetChannelName implements channel.Adaptor. +func (a *Adaptor) GetChannelName() string { + return ChannelName +} + +// GetModelList implements channel.Adaptor. +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +// GetRequestURL implements channel.Adaptor. +func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl), nil +} + +// Init implements channel.Adaptor. +func (a *Adaptor) Init(info *common.RelayInfo) { + +} + +// SetupRequestHeader implements channel.Adaptor. +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go new file mode 100644 index 00000000..da28cb83 --- /dev/null +++ b/relay/channel/coze/constants.go @@ -0,0 +1,8 @@ +package coze + +var ModelList = []string{ + // TODO: 完整列表 + "deepseek-v3", +} + +var ChannelName = "coze" diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go new file mode 100644 index 00000000..fb92289a --- /dev/null +++ b/relay/channel/coze/dto.go @@ -0,0 +1,81 @@ +package coze + +import "encoding/json" + +// type CozeResponse struct { +// Code int `json:"code"` +// Message string `json:"message"` +// Data CozeConversationData `json:"data"` +// Detail CozeConversationData `json:"detail"` +// } + +// type CozeConversationData struct { +// Id string `json:"id"` +// CreatedAt int64 `json:"created_at"` +// MetaData json.RawMessage `json:"meta_data"` +// LastSectionId string `json:"last_section_id"` +// } + +// type CozeResponseDetail struct { +// Logid string `json:"logid"` +// } + +type CozeError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// type CozeErrorWithStatusCode struct { +// Error CozeError `json:"error"` +// StatusCode int +// LocalError bool +// } + +type CozeRequest struct { + BotId string `json:"bot_id,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + Messages []CozeEnterMessage `json:"messages,omitempty"` +} + +type CozeEnterMessage struct { + Role string `json:"role"` + Type string `json:"type,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + ContentType string `json:"content_type,omitempty"` +} + +type CozeChatRequest struct { + BotId string `json:"bot_id"` + UserId string `json:"user_id"` + AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"` + Stream bool `json:"stream,omitempty"` + CustomVariables json.RawMessage `json:"custom_variables,omitempty"` + AutoSaveHistory bool `json:"auto_save_history,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + ExtraParams json.RawMessage `json:"extra_params,omitempty"` + ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` +} + +type CozeChatResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data CozeChatResponseData `json:"data"` +} + +type CozeChatResponseData struct { + Id string `json:"id"` + ConversationId string `json:"conversation_id"` + BotId string `json:"bot_id"` + CreatedAt int64 `json:"created_at"` + LastError CozeError `json:"last_error"` + Status string `json:"status"` + Usage CozeChatUsage `json:"usage"` +} + +type CozeChatUsage struct { + TokenCount int `json:"token_count"` + OutputCount int `json:"output_count"` + InputCount int `json:"input_count"` +} diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go new file mode 100644 index 00000000..49a3ac15 --- /dev/null +++ b/relay/channel/coze/relay-coze.go @@ -0,0 +1,121 @@ +package coze + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/dto" + "one-api/relay/common" + relaycommon "one-api/relay/common" + "one-api/service" + + "github.com/gin-gonic/gin" +) + +func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest { + var messages []CozeEnterMessage + // 将 request的messages的role为user的content转换为CozeMessage + for _, message := range request.Messages { + if message.Role == "user" { + messages = append(messages, CozeEnterMessage{ + Role: "user", + Content: message.Content, + // TODO: support more content type + ContentType: "text", + }) + } + } + cozeRequest := &CozeRequest{ + // TODO: model to botid + BotId: "1", + Messages: messages, + } + return cozeRequest +} + +func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + // convert coze response to openai response + var response dto.TextResponse + var cozeResponse CozeChatResponse + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + response.Model = info.UpstreamModelName + // TODO: 处理 cozeResponse + return nil, nil +} + +func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { + requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) + + requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") + // 将 conversationId和chatId作为参数发送get请求 + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return err, false + } + err = a.SetupRequestHeader(c, &req.Header, info) + if err != nil { + return err, false + } + + resp, err := doRequest(req, info) // 调用 doRequest + if err != nil { + return err, false + } + if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic + return fmt.Errorf("resp is nil"), false + } + defer resp.Body.Close() // 确保响应体被关闭 + + // 解析 resp 到 CozeChatResponse + var cozeResponse CozeChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response body failed: %w", err), false + } + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return fmt.Errorf("unmarshal response body failed: %w", err), false + } + if cozeResponse.Data.Status == "completed" { + // 在上下文设置 usage + c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount) + c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount) + c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount) + return nil, true + } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" { + return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false + } else { + return nil, false + } +} + +func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) { + var client *http.Client + var err error // 声明 err 变量 + if proxyURL, ok := info.ChannelSetting["proxy"]; ok { + client, err = service.NewProxyHttpClient(proxyURL.(string)) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + } else { + client = service.GetHttpClient() + } + resp, err := client.Do(req) + if err != nil { // 增加对 client.Do(req) 返回错误的检查 + return nil, fmt.Errorf("client.Do failed: %w", err) + } + _ = resp.Body.Close() + return resp, nil +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index fef38f23..3f1ecd78 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -33,6 +33,7 @@ const ( APITypeOpenRouter APITypeXinference APITypeXai + APITypeCoze APITypeDummy // this one is only for count, do not add any channel after this ) @@ -95,6 +96,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeXinference case common.ChannelTypeXai: apiType = APITypeXai + case common.ChannelTypeCoze: + apiType = APITypeCoze } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 8b4afcb3..7bf0da9f 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel/claude" "one-api/relay/channel/cloudflare" "one-api/relay/channel/cohere" + "one-api/relay/channel/coze" "one-api/relay/channel/deepseek" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" @@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &openai.Adaptor{} case constant.APITypeXai: return &xai.Adaptor{} + case constant.APITypeCoze: + return &coze.Adaptor{} } return nil } From f17f38e56906936ce1e000e6842371fd85520eff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B0=B8=E6=8C=AF?= Date: Tue, 13 May 2025 13:39:44 +0800 Subject: [PATCH 50/57] fix: ALI completions api path error --- relay/channel/ali/adaptor.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 8e34fd80..ab632d22 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -33,6 +33,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) case constant.RelayModeImagesGenerations: fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + case constant.RelayModeCompletions: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) default: fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) } From b2499b0a7ed0d902ad7ae4653dd0d0ab7e81055a Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 13 May 2025 21:13:34 +0800 Subject: [PATCH 51/57] DoRequest --- relay/channel/coze/adaptor.go | 6 +++--- relay/channel/coze/dto.go | 30 ------------------------------ relay/channel/coze/relay-coze.go | 29 +++++++++++++++++++++++++---- 3 files changed, 28 insertions(+), 37 deletions(-) diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index b14239a6..34931cc6 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -42,7 +42,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, r if request == nil { return nil, errors.New("request is nil") } - return convertCozeChatRequest(*request), nil + return convertCozeChatRequest(c, *request), nil } // ConvertOpenAIResponsesRequest implements channel.Adaptor. @@ -88,7 +88,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody time.Sleep(time.Second * 1) } // 发送获取消息请求 - return channel.DoApiRequest(a, c, info, requestBody) + return getChatDetail(a, c, info) } // DoResponse implements channel.Adaptor. @@ -109,7 +109,7 @@ func (a *Adaptor) GetModelList() []string { // GetRequestURL implements channel.Adaptor. func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl), nil + return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil } // Init implements channel.Adaptor. diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go index fb92289a..38fc2f16 100644 --- a/relay/channel/coze/dto.go +++ b/relay/channel/coze/dto.go @@ -2,41 +2,11 @@ package coze import "encoding/json" -// type CozeResponse struct { -// Code int `json:"code"` -// Message string `json:"message"` -// Data CozeConversationData `json:"data"` -// Detail CozeConversationData `json:"detail"` -// } - -// type CozeConversationData struct { -// Id string `json:"id"` -// CreatedAt int64 `json:"created_at"` -// MetaData json.RawMessage `json:"meta_data"` -// LastSectionId string `json:"last_section_id"` -// } - -// type CozeResponseDetail struct { -// Logid string `json:"logid"` -// } - type CozeError struct { Code int `json:"code"` Message string `json:"message"` } -// type CozeErrorWithStatusCode struct { -// Error CozeError `json:"error"` -// StatusCode int -// LocalError bool -// } - -type CozeRequest struct { - BotId string `json:"bot_id,omitempty"` - MetaData json.RawMessage `json:"meta_data,omitempty"` - Messages []CozeEnterMessage `json:"messages,omitempty"` -} - type CozeEnterMessage struct { Role string `json:"role"` Type string `json:"type,omitempty"` diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 49a3ac15..7c16763e 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -13,7 +13,7 @@ import ( "github.com/gin-gonic/gin" ) -func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest { +func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest { var messages []CozeEnterMessage // 将 request的messages的role为user的content转换为CozeMessage for _, message := range request.Messages { @@ -26,10 +26,12 @@ func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest { }) } } - cozeRequest := &CozeRequest{ + cozeRequest := &CozeChatRequest{ // TODO: model to botid - BotId: "1", - Messages: messages, + BotId: "1", + UserId: c.GetString("id"), + AdditionalMessages: messages, + Stream: request.Stream, } return cozeRequest } @@ -101,6 +103,25 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo } } +func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { + requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) + + requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + err = a.SetupRequestHeader(c, &req.Header, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(req, info) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) { var client *http.Client var err error // 声明 err 变量 From 29c95c598e380dbe5ff80cd0690a1c4c3770f93d Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 13 May 2025 22:01:12 +0800 Subject: [PATCH 52/57] cozeChatHelper --- relay/channel/coze/dto.go | 27 ++++++++++++++++++++ relay/channel/coze/relay-coze.go | 43 +++++++++++++++++++++++++++++--- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go index 38fc2f16..4e9afa23 100644 --- a/relay/channel/coze/dto.go +++ b/relay/channel/coze/dto.go @@ -49,3 +49,30 @@ type CozeChatUsage struct { OutputCount int `json:"output_count"` InputCount int `json:"input_count"` } + +type CozeChatDetailResponse struct { + Data []CozeChatV3MessageDetail `json:"data"` + Code int `json:"code"` + Msg string `json:"msg"` + Detail CozeResponseDetail `json:"detail"` +} + +type CozeChatV3MessageDetail struct { + Id string `json:"id"` + Role string `json:"role"` + Type string `json:"type"` + BotId string `json:"bot_id"` + ChatId string `json:"chat_id"` + Content json.RawMessage `json:"content"` + MetaData json.RawMessage `json:"meta_data"` + CreatedAt int64 `json:"created_at"` + SectionId string `json:"section_id"` + UpdatedAt int64 `json:"updated_at"` + ContentType string `json:"content_type"` + ConversationId string `json:"conversation_id"` + ReasoningContent string `json:"reasoning_content"` +} + +type CozeResponseDetail struct { + Logid string `json:"logid"` +} diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 7c16763e..fe630ef6 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -2,12 +2,14 @@ package coze import ( "encoding/json" + "errors" "fmt" "io" "net/http" "one-api/dto" "one-api/relay/common" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "github.com/gin-gonic/gin" @@ -47,14 +49,47 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } // convert coze response to openai response var response dto.TextResponse - var cozeResponse CozeChatResponse + var cozeResponse CozeChatDetailResponse + response.Model = info.UpstreamModelName err = json.Unmarshal(responseBody, &cozeResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - response.Model = info.UpstreamModelName - // TODO: 处理 cozeResponse - return nil, nil + if cozeResponse.Code != 0 { + return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil + } + // 从上下文获取 usage + var usage dto.Usage + usage.PromptTokens = c.GetInt("coze_input_count") + usage.CompletionTokens = c.GetInt("coze_output_count") + usage.TotalTokens = c.GetInt("coze_token_count") + response.Usage = usage + response.Id = helper.GetResponseID(c) + + var responseContent json.RawMessage + for _, data := range cozeResponse.Data { + if data.Type == "answer" { + responseContent = data.Content + response.Created = data.CreatedAt + } + } + // 添加 response.Choices + response.Choices = []dto.OpenAITextResponseChoice{ + { + Index: 0, + Message: dto.Message{Role: "assistant", Content: responseContent}, + FinishReason: "stop", + }, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + return nil, &usage } func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { From 108b67be6cc269778c17e24d38b5bc1971d11919 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 13 May 2025 22:23:38 +0800 Subject: [PATCH 53/57] use channel bot id --- middleware/distributor.go | 2 ++ relay/channel/coze/relay-coze.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 34882381..fdda8dda 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -240,5 +240,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("api_version", channel.Other) case common.ChannelTypeMokaAI: c.Set("api_version", channel.Other) + case common.ChannelTypeCoze: + c.Set("bot_id", channel.Other) } } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index fe630ef6..8e9b8e3e 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -30,7 +30,7 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C } cozeRequest := &CozeChatRequest{ // TODO: model to botid - BotId: "1", + BotId: c.GetString("bot_id"), UserId: c.GetString("id"), AdditionalMessages: messages, Stream: request.Stream, From ea04e6bcc53e38e0f8f2776d12daadf67e32de52 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 14 May 2025 17:01:50 +0800 Subject: [PATCH 54/57] fix: update model selection logic for image edits in distributor middleware --- middleware/distributor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 34882381..755a477d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -185,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { - modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "gpt-image-1") + modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { relayMode := relayconstant.RelayModeAudioSpeech From 4825404d375622dff567deefdd69dd7495fa8c35 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 15 May 2025 14:51:33 +0800 Subject: [PATCH 55/57] feat: enhance image decoding logic to handle base64 file types and improve error handling --- service/token_counter.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/service/token_counter.go b/service/token_counter.go index 21b882af..d63b54ad 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -120,11 +120,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m var config image.Config var err error var format string + var b64str string if strings.HasPrefix(imageUrl.Url, "http") { config, format, err = DecodeUrlImageData(imageUrl.Url) } else { common.SysLog(fmt.Sprintf("decoding image")) - config, format, _, err = DecodeBase64ImageData(imageUrl.Url) + config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url) } if err != nil { return 0, err @@ -132,7 +133,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m imageUrl.MimeType = format if config.Width == 0 || config.Height == 0 { - return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) + // not an image + if format != "" && b64str != "" { + // file type + return 3 * baseTokens, nil + } + return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url)) } shortSide := config.Width From 59aabb43119059bca2e26fd2059904294b6e0ce3 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Thu, 15 May 2025 20:00:59 +0800 Subject: [PATCH 56/57] add frontend display, more model --- relay/channel/coze/constants.go | 24 +++++++++++++++++++++++- relay/channel/coze/relay-coze.go | 9 ++++++--- web/src/constants/channel.constants.js | 9 +++++++-- web/src/pages/Channel/EditChannel.js | 16 ++++++++++++++++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go index da28cb83..873ffe24 100644 --- a/relay/channel/coze/constants.go +++ b/relay/channel/coze/constants.go @@ -1,8 +1,30 @@ package coze var ModelList = []string{ - // TODO: 完整列表 + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "Baichuan4", + "abab6.5s-chat-pro", + "glm-4-0520", + "qwen-max", + "deepseek-r1", "deepseek-v3", + "deepseek-r1-distill-qwen-32b", + "deepseek-r1-distill-qwen-7b", + "step-1v-8k", + "step-1.5v-mini", + "Doubao-pro-32k", + "Doubao-pro-256k", + "Doubao-lite-128k", + "Doubao-lite-32k", + "Doubao-vision-lite-32k", + "Doubao-vision-pro-32k", + "Doubao-1.5-pro-vision-32k", + "Doubao-1.5-lite-32k", + "Doubao-1.5-pro-32k", + "Doubao-1.5-thinking-pro", + "Doubao-1.5-pro-256k", } var ChannelName = "coze" diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 8e9b8e3e..1ebdb7c1 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -28,10 +28,13 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C }) } } + user := request.User + if user == "" { + user = helper.GetResponseID(c) + } cozeRequest := &CozeChatRequest{ - // TODO: model to botid BotId: c.GetString("bot_id"), - UserId: c.GetString("id"), + UserId: user, AdditionalMessages: messages, Stream: request.Stream, } @@ -172,6 +175,6 @@ func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error if err != nil { // 增加对 client.Do(req) 返回错误的检查 return nil, fmt.Errorf("client.Do failed: %w", err) } - _ = resp.Body.Close() + // _ = resp.Body.Close() return resp, nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index fa59bcce..054da535 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -118,6 +118,11 @@ export const CHANNEL_OPTIONS = [ { value: 48, color: 'blue', - label: 'xAI' - } + label: 'xAI', + }, + { + value: 49, + color: 'blue', + label: 'Coze', + }, ]; diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index fd96ffb6..f7fab057 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -838,6 +838,22 @@ const EditChannel = (props) => { /> )} + {inputs.type === 49 && ( + <> +
+ 智能体ID: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )}
{t('模型')}:
From e379ee8f66c1d3f85c89a26994b88227564ffa10 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 16 May 2025 10:27:07 +0800 Subject: [PATCH 57/57] coze stream --- relay/channel/coze/adaptor.go | 9 ++- relay/channel/coze/relay-coze.go | 124 ++++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 3 deletions(-) diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index 34931cc6..80441a51 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -57,6 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt // DoRequest implements channel.Adaptor. func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) { + if info.IsStream { + return channel.DoApiRequest(a, c, info, requestBody) + } // 首先发送创建消息请求,成功后再发送获取消息请求 // 发送创建消息请求 resp, err := channel.DoApiRequest(a, c, info, requestBody) @@ -93,7 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody // DoResponse implements channel.Adaptor. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - err, usage = cozeChatHandler(c, resp, info) + if info.IsStream { + err, usage = cozeChatStreamHandler(c, resp, info) + } else { + err, usage = cozeChatHandler(c, resp, info) + } return } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 1ebdb7c1..6db40213 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -1,16 +1,18 @@ package coze import ( + "bufio" "encoding/json" "errors" "fmt" "io" "net/http" + "one-api/common" "one-api/dto" - "one-api/relay/common" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "strings" "github.com/gin-gonic/gin" ) @@ -95,6 +97,124 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela return nil, &usage } +func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + helper.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) + var responseText string + + var currentEvent string + var currentData string + var usage dto.Usage + + for scanner.Scan() { + line := scanner.Text() + + if line == "" { + if currentEvent != "" && currentData != "" { + // handle last event + handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + currentEvent = "" + currentData = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(line[6:]) + continue + } + + if strings.HasPrefix(line, "data:") { + currentData = strings.TrimSpace(line[5:]) + continue + } + } + + // Last event + if currentEvent != "" && currentData != "" { + handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + } + + if err := scanner.Err(); err != nil { + return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil + } + helper.Done(c) + + if usage.TotalTokens == 0 { + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + + return nil, &usage +} + +func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { + switch event { + case "conversation.chat.completed": + // 将 data 解析为 CozeChatResponseData + var chatData CozeChatResponseData + err := json.Unmarshal([]byte(data), &chatData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + usage.PromptTokens = chatData.Usage.InputCount + usage.CompletionTokens = chatData.Usage.OutputCount + usage.TotalTokens = chatData.Usage.TokenCount + + finishReason := "stop" + stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason) + helper.ObjectData(c, stopResponse) + + case "conversation.message.delta": + // 将 data 解析为 CozeChatV3MessageDetail + var messageData CozeChatV3MessageDetail + err := json.Unmarshal([]byte(data), &messageData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + var content string + err = json.Unmarshal(messageData.Content, &content) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + *responseText += content + + openaiResponse := dto.ChatCompletionsStreamResponse{ + Id: id, + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + } + + choice := dto.ChatCompletionsStreamResponseChoice{ + Index: 0, + } + choice.Delta.SetContentString(content) + openaiResponse.Choices = append(openaiResponse.Choices, choice) + + helper.ObjectData(c, openaiResponse) + + case "error": + var errorData CozeError + err := json.Unmarshal([]byte(data), &errorData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + } +} + func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) @@ -160,7 +280,7 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht return resp, nil } -func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) { +func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) { var client *http.Client var err error // 声明 err 变量 if proxyURL, ok := info.ChannelSetting["proxy"]; ok {