From 7e7d6112ca460be5c30a6c89fb4165346a6d5651 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 11:34:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=EF=BC=8C=E5=8E=BB=E9=99=A4=E5=A4=9A=E4=BD=99=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=E5=92=8C=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .old/option.go | 402 ++++++++++++++++++ middleware/model-rate-limit.go | 43 +- model/option.go | 40 +- setting/rate_limit.go | 39 +- web/src/components/RateLimitSetting.js | 92 ++-- .../RateLimit/SettingsRequestRateLimit.js | 388 +++++++++-------- 6 files changed, 663 insertions(+), 341 deletions(-) create mode 100644 .old/option.go diff --git a/.old/option.go b/.old/option.go new file mode 100644 index 00000000..f80f5cb3 --- /dev/null +++ b/.old/option.go @@ -0,0 +1,402 @@ +package model + +import ( + "one-api/common" + "one-api/setting" + "one-api/setting/config" + "one-api/setting/operation_setting" + "strconv" + "strings" + "time" +) + +type Option struct { + Key string `json:"key" gorm:"primaryKey"` + Value string `json:"value"` +} + +func AllOption() ([]*Option, error) { + var options []*Option + var err error + err = DB.Find(&options).Error + return options, err +} + +func InitOptionMap() { + common.OptionMapRWMutex.Lock() + common.OptionMap = make(map[string]string) + + // 添加原有的系统配置 + common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) + common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) + common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) + common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) + common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) + common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) + common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) + common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) + common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled) + common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) + common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) + common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) + common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) + common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) + common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) + common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) + common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) + common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) + common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) + common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) + common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) + common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) + common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled) + common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") + common.OptionMap["SMTPServer"] = "" + common.OptionMap["SMTPFrom"] = "" + common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) + common.OptionMap["SMTPAccount"] = "" + common.OptionMap["SMTPToken"] = "" + common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled) + common.OptionMap["Notice"] = "" + common.OptionMap["About"] = "" + common.OptionMap["HomePageContent"] = "" + common.OptionMap["Footer"] = common.Footer + common.OptionMap["SystemName"] = common.SystemName + common.OptionMap["Logo"] = common.Logo + common.OptionMap["ServerAddress"] = "" + common.OptionMap["WorkerUrl"] = setting.WorkerUrl + common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey + common.OptionMap["PayAddress"] = "" + common.OptionMap["CustomCallbackAddress"] = "" + common.OptionMap["EpayId"] = "" + common.OptionMap["EpayKey"] = "" + common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) + common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) + common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() + common.OptionMap["Chats"] = setting.Chats2JsonString() + common.OptionMap["GitHubClientId"] = "" + common.OptionMap["GitHubClientSecret"] = "" + common.OptionMap["TelegramBotToken"] = "" + common.OptionMap["TelegramBotName"] = "" + common.OptionMap["WeChatServerAddress"] = "" + common.OptionMap["WeChatServerToken"] = "" + common.OptionMap["WeChatAccountQRCodeImageURL"] = "" + common.OptionMap["TurnstileSiteKey"] = "" + common.OptionMap["TurnstileSecretKey"] = "" + common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) + common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) + common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) + common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) + common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) + common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) + common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() + common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() + common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() + common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() + common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() + common.OptionMap["TopUpLink"] = common.TopUpLink + //common.OptionMap["ChatLink"] = common.ChatLink + //common.OptionMap["ChatLink2"] = common.ChatLink2 + common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) + common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) + common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) + common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime + common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) + common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled) + common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled) + common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled) + common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) + common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) + common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) + common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled) + common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled) + common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) + common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) + common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) + common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() + common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) + common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + + // 自动添加所有注册的模型配置 + modelConfigs := config.GlobalConfig.ExportAllConfigs() + for k, v := range modelConfigs { + common.OptionMap[k] = v + } + + common.OptionMapRWMutex.Unlock() + loadOptionsFromDatabase() +} + +func loadOptionsFromDatabase() { + options, _ := AllOption() + for _, option := range options { + err := updateOptionMap(option.Key, option.Value) + if err != nil { + common.SysError("failed to update option map: " + err.Error()) + } + } +} + +func SyncOptions(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + common.SysLog("syncing options from database") + loadOptionsFromDatabase() + } +} + +func UpdateOption(key string, value string) error { + // Save to database first + option := Option{ + Key: key, + } + // https://gorm.io/docs/update.html#Save-All-Fields + DB.FirstOrCreate(&option, Option{Key: key}) + option.Value = value + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap + return updateOptionMap(key, value) +} + +func updateOptionMap(key string, value string) (err error) { + common.OptionMapRWMutex.Lock() + defer common.OptionMapRWMutex.Unlock() + common.OptionMap[key] = value + + // 检查是否是模型配置 - 使用更规范的方式处理 + if handleConfigUpdate(key, value) { + return nil // 已由配置系统处理 + } + + // 处理传统配置项... + if strings.HasSuffix(key, "Permission") { + intValue, _ := strconv.Atoi(value) + switch key { + case "FileUploadPermission": + common.FileUploadPermission = intValue + case "FileDownloadPermission": + common.FileDownloadPermission = intValue + case "ImageUploadPermission": + common.ImageUploadPermission = intValue + case "ImageDownloadPermission": + common.ImageDownloadPermission = intValue + } + } + if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { + boolValue := value == "true" + switch key { + case "PasswordRegisterEnabled": + common.PasswordRegisterEnabled = boolValue + case "PasswordLoginEnabled": + common.PasswordLoginEnabled = boolValue + case "EmailVerificationEnabled": + common.EmailVerificationEnabled = boolValue + case "GitHubOAuthEnabled": + common.GitHubOAuthEnabled = boolValue + case "LinuxDOOAuthEnabled": + common.LinuxDOOAuthEnabled = boolValue + case "WeChatAuthEnabled": + common.WeChatAuthEnabled = boolValue + case "TelegramOAuthEnabled": + common.TelegramOAuthEnabled = boolValue + case "TurnstileCheckEnabled": + common.TurnstileCheckEnabled = boolValue + case "RegisterEnabled": + common.RegisterEnabled = boolValue + case "EmailDomainRestrictionEnabled": + common.EmailDomainRestrictionEnabled = boolValue + case "EmailAliasRestrictionEnabled": + common.EmailAliasRestrictionEnabled = boolValue + case "AutomaticDisableChannelEnabled": + common.AutomaticDisableChannelEnabled = boolValue + case "AutomaticEnableChannelEnabled": + common.AutomaticEnableChannelEnabled = boolValue + case "LogConsumeEnabled": + common.LogConsumeEnabled = boolValue + case "DisplayInCurrencyEnabled": + common.DisplayInCurrencyEnabled = boolValue + case "DisplayTokenStatEnabled": + common.DisplayTokenStatEnabled = boolValue + case "DrawingEnabled": + common.DrawingEnabled = boolValue + case "TaskEnabled": + common.TaskEnabled = boolValue + case "DataExportEnabled": + common.DataExportEnabled = boolValue + case "DefaultCollapseSidebar": + common.DefaultCollapseSidebar = boolValue + case "MjNotifyEnabled": + setting.MjNotifyEnabled = boolValue + case "MjAccountFilterEnabled": + setting.MjAccountFilterEnabled = boolValue + case "MjModeClearEnabled": + setting.MjModeClearEnabled = boolValue + case "MjForwardUrlEnabled": + setting.MjForwardUrlEnabled = boolValue + case "MjActionCheckSuccessEnabled": + setting.MjActionCheckSuccessEnabled = boolValue + case "CheckSensitiveEnabled": + setting.CheckSensitiveEnabled = boolValue + case "DemoSiteEnabled": + operation_setting.DemoSiteEnabled = boolValue + case "SelfUseModeEnabled": + operation_setting.SelfUseModeEnabled = boolValue + case "CheckSensitiveOnPromptEnabled": + setting.CheckSensitiveOnPromptEnabled = boolValue + case "ModelRequestRateLimitEnabled": + setting.ModelRequestRateLimitEnabled = boolValue + case "StopOnSensitiveEnabled": + setting.StopOnSensitiveEnabled = boolValue + case "SMTPSSLEnabled": + common.SMTPSSLEnabled = boolValue + } + } + switch key { + case "EmailDomainWhitelist": + common.EmailDomainWhitelist = strings.Split(value, ",") + case "SMTPServer": + common.SMTPServer = value + case "SMTPPort": + intValue, _ := strconv.Atoi(value) + common.SMTPPort = intValue + case "SMTPAccount": + common.SMTPAccount = value + case "SMTPFrom": + common.SMTPFrom = value + case "SMTPToken": + common.SMTPToken = value + case "ServerAddress": + setting.ServerAddress = value + case "WorkerUrl": + setting.WorkerUrl = value + case "WorkerValidKey": + setting.WorkerValidKey = value + case "PayAddress": + setting.PayAddress = value + case "Chats": + err = setting.UpdateChatsByJsonString(value) + case "CustomCallbackAddress": + setting.CustomCallbackAddress = value + case "EpayId": + setting.EpayId = value + case "EpayKey": + setting.EpayKey = value + case "Price": + setting.Price, _ = strconv.ParseFloat(value, 64) + case "MinTopUp": + setting.MinTopUp, _ = strconv.Atoi(value) + case "TopupGroupRatio": + err = common.UpdateTopupGroupRatioByJSONString(value) + case "GitHubClientId": + common.GitHubClientId = value + case "GitHubClientSecret": + common.GitHubClientSecret = value + case "LinuxDOClientId": + common.LinuxDOClientId = value + case "LinuxDOClientSecret": + common.LinuxDOClientSecret = value + case "Footer": + common.Footer = value + case "SystemName": + common.SystemName = value + case "Logo": + common.Logo = value + case "WeChatServerAddress": + common.WeChatServerAddress = value + case "WeChatServerToken": + common.WeChatServerToken = value + case "WeChatAccountQRCodeImageURL": + common.WeChatAccountQRCodeImageURL = value + case "TelegramBotToken": + common.TelegramBotToken = value + case "TelegramBotName": + common.TelegramBotName = value + case "TurnstileSiteKey": + common.TurnstileSiteKey = value + case "TurnstileSecretKey": + common.TurnstileSecretKey = value + case "QuotaForNewUser": + common.QuotaForNewUser, _ = strconv.Atoi(value) + case "QuotaForInviter": + common.QuotaForInviter, _ = strconv.Atoi(value) + case "QuotaForInvitee": + common.QuotaForInvitee, _ = strconv.Atoi(value) + case "QuotaRemindThreshold": + common.QuotaRemindThreshold, _ = strconv.Atoi(value) + case "PreConsumedQuota": + common.PreConsumedQuota, _ = strconv.Atoi(value) + case "ModelRequestRateLimitCount": + setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitDurationMinutes": + setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) + case "ModelRequestRateLimitSuccessCount": + setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "RetryTimes": + common.RetryTimes, _ = strconv.Atoi(value) + case "DataExportInterval": + common.DataExportInterval, _ = strconv.Atoi(value) + case "DataExportDefaultTime": + common.DataExportDefaultTime = value + case "ModelRatio": + err = operation_setting.UpdateModelRatioByJSONString(value) + case "GroupRatio": + err = setting.UpdateGroupRatioByJSONString(value) + case "UserUsableGroups": + err = setting.UpdateUserUsableGroupsByJSONString(value) + case "CompletionRatio": + err = operation_setting.UpdateCompletionRatioByJSONString(value) + case "ModelPrice": + err = operation_setting.UpdateModelPriceByJSONString(value) + case "CacheRatio": + err = operation_setting.UpdateCacheRatioByJSONString(value) + case "TopUpLink": + common.TopUpLink = value + //case "ChatLink": + // common.ChatLink = value + //case "ChatLink2": + // common.ChatLink2 = value + case "ChannelDisableThreshold": + common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + case "QuotaPerUnit": + common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + case "SensitiveWords": + setting.SensitiveWordsFromString(value) + case "AutomaticDisableKeywords": + operation_setting.AutomaticDisableKeywordsFromString(value) + case "StreamCacheQueueLength": + setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + } + return err +} + +// handleConfigUpdate 处理分层配置更新,返回是否已处理 +func handleConfigUpdate(key, value string) bool { + parts := strings.SplitN(key, ".", 2) + if len(parts) != 2 { + return false // 不是分层配置 + } + + configName := parts[0] + configKey := parts[1] + + // 获取配置对象 + cfg := config.GlobalConfig.Get(configName) + if cfg == nil { + return false // 未注册的配置 + } + + // 更新配置 + configMap := map[string]string{ + configKey: value, + } + config.UpdateConfigFromMap(cfg, configMap) + + return true // 已处理 +} \ No newline at end of file diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index d4199ece..b0047b70 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -19,25 +19,20 @@ const ( ModelRequestRateLimitSuccessCountMark = "MRRLS" ) -// 检查Redis中的请求限制 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { - // 如果maxCount为0,表示不限制 if maxCount == 0 { return true, nil } - // 获取当前计数 length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } - // 如果未达到限制,允许请求 if length < int64(maxCount) { return true, nil } - // 检查时间窗口 oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { @@ -49,7 +44,6 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max if err != nil { return false, err } - // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) @@ -59,9 +53,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max return true, nil } -// 记录Redis请求 func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { - // 如果maxCount为0,不记录请求 if maxCount == 0 { return } @@ -72,14 +64,12 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } -// Redis限流处理器 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB - // 1. 检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { @@ -92,9 +82,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } - //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) - // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, @@ -114,17 +102,14 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } - // 4. 处理请求 c.Next() - // 5. 如果请求成功,记录成功请求 if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } -// 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) @@ -133,15 +118,12 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId - // 1. 检查总请求数限制(当totalMaxCount为0时跳过) if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } - // 2. 检查成功请求数限制 - // 使用一个临时key来检查限制,这样可以避免实际记录 checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) @@ -149,54 +131,47 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) return } - // 3. 处理请求 c.Next() - // 4. 如果请求成功,记录到实际的成功请求计数中 if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } -// ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { - // 在每个请求时检查是否启用限流 if !setting.ModelRequestRateLimitEnabled { c.Next() return } - // 计算通用限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - // 获取用户组 group := c.GetString("token_group") if group == "" { group = c.GetString("group") } if group == "" { - group = "default" // 默认组 + group = "default" } - // 尝试获取用户组特定的限制 + finalTotalCount := setting.ModelRequestRateLimitCount + finalSuccessCount := setting.ModelRequestRateLimitSuccessCount + foundGroupLimit := false + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) - - // 确定最终的限制值 - finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制 - finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制 - if found { - // 如果找到用户组特定限制,则使用它们 finalTotalCount = groupTotalCount finalSuccessCount = groupSuccessCount + foundGroupLimit = true common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) - } else { + } + + if !foundGroupLimit { common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } - // 根据存储类型选择并执行限流处理器,传入最终确定的限制值 if common.RedisEnabled { redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } else { diff --git a/model/option.go b/model/option.go index 1f5fb3aa..79556737 100644 --- a/model/option.go +++ b/model/option.go @@ -94,11 +94,12 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + jsonBytes, _ := json.Marshal(map[string][2]int{}) + common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes) common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() - common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值 common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -153,47 +154,31 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - originalValue := value // 保存原始值以备后用 + originalValue := value - // Validate and format specific keys before saving - if key == setting.ModelRequestRateLimitGroupKey { + if key == "ModelRequestRateLimitGroup" { var cfg map[string][2]int - // Validate the JSON structure first using the original value err := json.Unmarshal([]byte(originalValue), &cfg) if err != nil { - // 提供更具体的错误信息 return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) } - // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。 - // if !isValidModelRequestRateLimitGroupConfig(cfg) { - // return fmt.Errorf("无效的配置值 for %s", key) - // } - // If valid, format the JSON before saving formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") if marshalErr != nil { - // This should ideally not happen if validation passed, but handle defensively return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) } - value = string(formattedValueBytes) // Use formatted JSON for saving and memory update + value = string(formattedValueBytes) } - // Save to database option := Option{ Key: key, } - // https://gorm.io/docs/update.html#Save-All-Fields DB.FirstOrCreate(&option, Option{Key: key}) option.Value = value - // Save is a combination function. - // If save value does not contain primary key, it will execute Create, - // otherwise it will execute Update (with all fields). if err := DB.Save(&option).Error; err != nil { - return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文 + return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) } - // Update OptionMap in memory using the potentially formatted value - // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新 return updateOptionMap(key, value) } @@ -370,6 +355,8 @@ func updateOptionMap(key string, value string) (err error) { setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitGroup": + err = setting.UpdateModelRequestRateLimitGroup(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": @@ -404,15 +391,6 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) - case setting.ModelRequestRateLimitGroupKey: - // Use the (potentially formatted) value passed from UpdateOption - // to update the actual configuration in memory. - // This is the single point where the memory state for this specific setting is updated. - err = setting.UpdateModelRequestRateLimitGroupConfig(value) - if err != nil { - // 添加错误上下文 - err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err) - } } return err } @@ -440,4 +418,4 @@ func handleConfigUpdate(key, value string) bool { config.UpdateConfigFromMap(cfg, configMap) return true // 已处理 -} +} \ No newline at end of file diff --git a/setting/rate_limit.go b/setting/rate_limit.go index c83885a6..5be75cc1 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -11,24 +11,17 @@ var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 +var ModelRequestRateLimitGroup map[string][2]int -// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键 -const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup" - -// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置 -// map[groupName][2]int{totalCount, successCount} -var ModelRequestRateLimitGroupConfig map[string][2]int var ModelRequestRateLimitGroupMutex sync.RWMutex -// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置 -func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { +func UpdateModelRequestRateLimitGroup(jsonStr string) error { ModelRequestRateLimitGroupMutex.Lock() defer ModelRequestRateLimitGroupMutex.Unlock() var newConfig map[string][2]int if jsonStr == "" || jsonStr == "{}" { - // 如果配置为空或空JSON对象,则清空内存配置 - ModelRequestRateLimitGroupConfig = make(map[string][2]int) + ModelRequestRateLimitGroup = make(map[string][2]int) common.SysLog("Model request rate limit group config cleared") return nil } @@ -38,37 +31,19 @@ func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) } - // 校验配置值 - for group, limits := range newConfig { - if len(limits) != 2 { - return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group) - } - if limits[1] <= 0 { // successCount must be greater than 0 - return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group) - } - if limits[0] < 0 { // totalCount can be 0 (no limit) or positive - return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group) - } - if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount - return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group) - } - } - - ModelRequestRateLimitGroupConfig = newConfig - common.SysLog("Model request rate limit group config updated") + ModelRequestRateLimitGroup = newConfig return nil } -// GetGroupRateLimit 安全地获取指定用户组的速率限制值 func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { ModelRequestRateLimitGroupMutex.RLock() defer ModelRequestRateLimitGroupMutex.RUnlock() - if ModelRequestRateLimitGroupConfig == nil { - return 0, 0, false // 配置尚未初始化 + if ModelRequestRateLimitGroup == nil { + return 0, 0, false } - limits, found := ModelRequestRateLimitGroupConfig[group] + limits, found := ModelRequestRateLimitGroup[group] if !found { return 0, 0, false } diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index ad6b53da..7e206672 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -9,59 +9,59 @@ import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimi const RateLimitSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: 0, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: {}, + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', }); - + let [loading, setLoading] = useState(false); - + const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true' ? true : false; - } else { - newInputs[item.key] = item.value; - } - }); - - setInputs(newInputs); - } else { - showError(message); - } + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + // 检查 key 是否在初始 inputs 中定义 + if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true'; + } else { + newInputs[item.key] = item.value; + } + } + }); + setInputs(newInputs); + } else { + showError(message); + } }; async function onRefresh() { - try { - setLoading(true); - await getOptions(); - // showSuccess('刷新成功'); - } catch (error) { - showError('刷新失败'); - } finally { - setLoading(false); - } + try { + setLoading(true); + await getOptions(); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } } - + useEffect(() => { - onRefresh(); + onRefresh(); }, []); - + return ( - <> - - {/* AI请求速率限制 */} - - - - - + <> + + + + + + ); -}; - -export default RateLimitSetting; + }; + + export default RateLimitSetting; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index ec1c2158..2434020e 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -14,209 +14,201 @@ export default function RequestRateLimit(props) { const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: -1, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值 + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); - + function onSubmit() { - const updateArray = compareObjects(inputs, inputsRow); - if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); - const requestQueue = updateArray.map((item) => { - let value = ''; - if (typeof inputs[item.key] === 'boolean') { - value = String(inputs[item.key]); - } else { - value = inputs[item.key]; - } - // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串 - if (item.key === 'ModelRequestRateLimitGroup') { - try { - JSON.parse(value); - } catch (e) { - showError(t('用户组速率限制配置不是有效的 JSON 格式!')); - // 阻止请求发送 - return Promise.reject('Invalid JSON format'); - } - } - return API.put('/api/option/', { - key: item.key, - value, - }); - }); - - // 过滤掉无效的请求(例如,无效的 JSON) - const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); - - if (validRequests.length === 0 && requestQueue.length > 0) { - // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行 - return; - } - - setLoading(true); - Promise.all(validRequests) - .then((res) => { - if (validRequests.length === 1) { - if (res.includes(undefined)) return; - } else if (validRequests.length > 1) { - if (res.includes(undefined)) - return showError(t('部分保存失败,请重试')); - } - showSuccess(t('保存成功')); - props.refresh(); - // 更新 inputsRow 以反映保存后的状态 - setInputsRow(structuredClone(inputs)); - }) - .catch((error) => { - // 检查是否是由于无效 JSON 导致的错误 - if (error !== 'Invalid JSON format') { - showError(t('保存失败,请重试')); - } - }) - .finally(() => { - setLoading(false); - }); + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + if (item.key === 'ModelRequestRateLimitGroup') { + try { + JSON.parse(value); + } catch (e) { + showError(t('用户组速率限制配置不是有效的 JSON 格式!')); + return Promise.reject('Invalid JSON format'); + } + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + + const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); + + if (validRequests.length === 0 && requestQueue.length > 0) { + return; + } + + setLoading(true); + Promise.all(validRequests) + .then((res) => { + if (validRequests.length === 1) { + if (res.includes(undefined)) return; + } else if (validRequests.length > 1) { + if (res.includes(undefined)) + return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + setInputsRow(structuredClone(inputs)); + }) + .catch((error) => { + if (error !== 'Invalid JSON format') { + showError(t('保存失败,请重试')); + } + }) + .finally(() => { + setLoading(false); + }); } - + useEffect(() => { - const currentInputs = {}; - for (let key in props.options) { - if (Object.keys(inputs).includes(key)) { - currentInputs[key] = props.options[key]; - } - } - setInputs(currentInputs); - setInputsRow(structuredClone(currentInputs)); - // 检查 refForm.current 是否存在 - if (refForm.current) { - refForm.current.setValues(currentInputs); - } - }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定 - + const currentInputs = {}; + for (let key in props.options) { + if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查 + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + if (refForm.current) { + refForm.current.setValues(currentInputs); + } + }, [props.options]); + return ( - <> - -
(refForm.current = formAPI)} - style={{ marginBottom: 15 }} - > - - - - { - setInputs({ - ...inputs, - ModelRequestRateLimitEnabled: value, - }); - }} - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitDurationMinutes: String(value), - }) - } - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitCount: String(value), - }) - } - /> - - - - setInputs({ - ...inputs, - ModelRequestRateLimitSuccessCount: String(value), - }) - } - /> - - - {/* 用户组速率限制配置项 */} - - - -

{t('说明:')}

-
    -
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • -
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • -
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • -
  • {t('此配置将优先于上方的全局限制设置。')}
  • -
  • {t('未在此处配置的用户组将使用全局限制。')}
  • -
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • -
  • {t('输入无效的 JSON 将无法保存。')}
  • -
- - } - autosize={{ minRows: 5, maxRows: 15 }} - style={{ fontFamily: 'monospace' }} - onChange={(value) => { - setInputs({ - ...inputs, - ModelRequestRateLimitGroup: value, // 直接更新字符串值 - }); - }} - /> - -
- - - -
-
-
- + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + +

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • +
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • +
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • +
  • {t('此配置将优先于上方的全局限制设置。')}
  • +
  • {t('未在此处配置的用户组将使用全局限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
  • {t('输入无效的 JSON 将无法保存。')}
  • +
+ + } + autosize={{ minRows: 5, maxRows: 15 }} + style={{ fontFamily: 'monospace' }} + onChange={(value) => { + setInputs({ + ...inputs, + ModelRequestRateLimitGroup: value, + }); + }} + /> + +
+ + + +
+
+
+ ); -} + }