diff --git a/.old/option.go b/.old/option.go
new file mode 100644
index 00000000..f80f5cb3
--- /dev/null
+++ b/.old/option.go
@@ -0,0 +1,402 @@
+package model
+
+import (
+ "one-api/common"
+ "one-api/setting"
+ "one-api/setting/config"
+ "one-api/setting/operation_setting"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type Option struct {
+ Key string `json:"key" gorm:"primaryKey"`
+ Value string `json:"value"`
+}
+
+func AllOption() ([]*Option, error) {
+ var options []*Option
+ var err error
+ err = DB.Find(&options).Error
+ return options, err
+}
+
+func InitOptionMap() {
+ common.OptionMapRWMutex.Lock()
+ common.OptionMap = make(map[string]string)
+
+ // 添加原有的系统配置
+ common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
+ common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
+ common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
+ common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
+ common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
+ common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
+ common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
+ common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
+ common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled)
+ common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
+ common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
+ common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
+ common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
+ common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
+ common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
+ common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
+ common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
+ common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
+ common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
+ common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
+ common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
+ common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
+ common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
+ common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
+ common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
+ common.OptionMap["SMTPServer"] = ""
+ common.OptionMap["SMTPFrom"] = ""
+ common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
+ common.OptionMap["SMTPAccount"] = ""
+ common.OptionMap["SMTPToken"] = ""
+ common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
+ common.OptionMap["Notice"] = ""
+ common.OptionMap["About"] = ""
+ common.OptionMap["HomePageContent"] = ""
+ common.OptionMap["Footer"] = common.Footer
+ common.OptionMap["SystemName"] = common.SystemName
+ common.OptionMap["Logo"] = common.Logo
+ common.OptionMap["ServerAddress"] = ""
+ common.OptionMap["WorkerUrl"] = setting.WorkerUrl
+ common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
+ common.OptionMap["PayAddress"] = ""
+ common.OptionMap["CustomCallbackAddress"] = ""
+ common.OptionMap["EpayId"] = ""
+ common.OptionMap["EpayKey"] = ""
+ common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
+ common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
+ common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
+ common.OptionMap["Chats"] = setting.Chats2JsonString()
+ common.OptionMap["GitHubClientId"] = ""
+ common.OptionMap["GitHubClientSecret"] = ""
+ common.OptionMap["TelegramBotToken"] = ""
+ common.OptionMap["TelegramBotName"] = ""
+ common.OptionMap["WeChatServerAddress"] = ""
+ common.OptionMap["WeChatServerToken"] = ""
+ common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
+ common.OptionMap["TurnstileSiteKey"] = ""
+ common.OptionMap["TurnstileSecretKey"] = ""
+ common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
+ common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
+ common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
+ common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
+ common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
+ common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
+ common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
+ common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
+ common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
+ common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
+ common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
+ common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
+ common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
+ common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
+ common.OptionMap["TopUpLink"] = common.TopUpLink
+ //common.OptionMap["ChatLink"] = common.ChatLink
+ //common.OptionMap["ChatLink2"] = common.ChatLink2
+ common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
+ common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
+ common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
+ common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
+ common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
+ common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled)
+ common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled)
+ common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled)
+ common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
+ common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
+ common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
+ common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled)
+ common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled)
+ common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
+ common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
+ common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
+ common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
+ common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
+ common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
+
+ // 自动添加所有注册的模型配置
+ modelConfigs := config.GlobalConfig.ExportAllConfigs()
+ for k, v := range modelConfigs {
+ common.OptionMap[k] = v
+ }
+
+ common.OptionMapRWMutex.Unlock()
+ loadOptionsFromDatabase()
+}
+
+func loadOptionsFromDatabase() {
+ options, _ := AllOption()
+ for _, option := range options {
+ err := updateOptionMap(option.Key, option.Value)
+ if err != nil {
+ common.SysError("failed to update option map: " + err.Error())
+ }
+ }
+}
+
+func SyncOptions(frequency int) {
+ for {
+ time.Sleep(time.Duration(frequency) * time.Second)
+ common.SysLog("syncing options from database")
+ loadOptionsFromDatabase()
+ }
+}
+
+func UpdateOption(key string, value string) error {
+ // Save to database first
+ option := Option{
+ Key: key,
+ }
+ // https://gorm.io/docs/update.html#Save-All-Fields
+ DB.FirstOrCreate(&option, Option{Key: key})
+ option.Value = value
+ // Save is a combination function.
+ // If save value does not contain primary key, it will execute Create,
+ // otherwise it will execute Update (with all fields).
+ DB.Save(&option)
+ // Update OptionMap
+ return updateOptionMap(key, value)
+}
+
+func updateOptionMap(key string, value string) (err error) {
+ common.OptionMapRWMutex.Lock()
+ defer common.OptionMapRWMutex.Unlock()
+ common.OptionMap[key] = value
+
+ // 检查是否是模型配置 - 使用更规范的方式处理
+ if handleConfigUpdate(key, value) {
+ return nil // 已由配置系统处理
+ }
+
+ // 处理传统配置项...
+ if strings.HasSuffix(key, "Permission") {
+ intValue, _ := strconv.Atoi(value)
+ switch key {
+ case "FileUploadPermission":
+ common.FileUploadPermission = intValue
+ case "FileDownloadPermission":
+ common.FileDownloadPermission = intValue
+ case "ImageUploadPermission":
+ common.ImageUploadPermission = intValue
+ case "ImageDownloadPermission":
+ common.ImageDownloadPermission = intValue
+ }
+ }
+ if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
+ boolValue := value == "true"
+ switch key {
+ case "PasswordRegisterEnabled":
+ common.PasswordRegisterEnabled = boolValue
+ case "PasswordLoginEnabled":
+ common.PasswordLoginEnabled = boolValue
+ case "EmailVerificationEnabled":
+ common.EmailVerificationEnabled = boolValue
+ case "GitHubOAuthEnabled":
+ common.GitHubOAuthEnabled = boolValue
+ case "LinuxDOOAuthEnabled":
+ common.LinuxDOOAuthEnabled = boolValue
+ case "WeChatAuthEnabled":
+ common.WeChatAuthEnabled = boolValue
+ case "TelegramOAuthEnabled":
+ common.TelegramOAuthEnabled = boolValue
+ case "TurnstileCheckEnabled":
+ common.TurnstileCheckEnabled = boolValue
+ case "RegisterEnabled":
+ common.RegisterEnabled = boolValue
+ case "EmailDomainRestrictionEnabled":
+ common.EmailDomainRestrictionEnabled = boolValue
+ case "EmailAliasRestrictionEnabled":
+ common.EmailAliasRestrictionEnabled = boolValue
+ case "AutomaticDisableChannelEnabled":
+ common.AutomaticDisableChannelEnabled = boolValue
+ case "AutomaticEnableChannelEnabled":
+ common.AutomaticEnableChannelEnabled = boolValue
+ case "LogConsumeEnabled":
+ common.LogConsumeEnabled = boolValue
+ case "DisplayInCurrencyEnabled":
+ common.DisplayInCurrencyEnabled = boolValue
+ case "DisplayTokenStatEnabled":
+ common.DisplayTokenStatEnabled = boolValue
+ case "DrawingEnabled":
+ common.DrawingEnabled = boolValue
+ case "TaskEnabled":
+ common.TaskEnabled = boolValue
+ case "DataExportEnabled":
+ common.DataExportEnabled = boolValue
+ case "DefaultCollapseSidebar":
+ common.DefaultCollapseSidebar = boolValue
+ case "MjNotifyEnabled":
+ setting.MjNotifyEnabled = boolValue
+ case "MjAccountFilterEnabled":
+ setting.MjAccountFilterEnabled = boolValue
+ case "MjModeClearEnabled":
+ setting.MjModeClearEnabled = boolValue
+ case "MjForwardUrlEnabled":
+ setting.MjForwardUrlEnabled = boolValue
+ case "MjActionCheckSuccessEnabled":
+ setting.MjActionCheckSuccessEnabled = boolValue
+ case "CheckSensitiveEnabled":
+ setting.CheckSensitiveEnabled = boolValue
+ case "DemoSiteEnabled":
+ operation_setting.DemoSiteEnabled = boolValue
+ case "SelfUseModeEnabled":
+ operation_setting.SelfUseModeEnabled = boolValue
+ case "CheckSensitiveOnPromptEnabled":
+ setting.CheckSensitiveOnPromptEnabled = boolValue
+ case "ModelRequestRateLimitEnabled":
+ setting.ModelRequestRateLimitEnabled = boolValue
+ case "StopOnSensitiveEnabled":
+ setting.StopOnSensitiveEnabled = boolValue
+ case "SMTPSSLEnabled":
+ common.SMTPSSLEnabled = boolValue
+ }
+ }
+ switch key {
+ case "EmailDomainWhitelist":
+ common.EmailDomainWhitelist = strings.Split(value, ",")
+ case "SMTPServer":
+ common.SMTPServer = value
+ case "SMTPPort":
+ intValue, _ := strconv.Atoi(value)
+ common.SMTPPort = intValue
+ case "SMTPAccount":
+ common.SMTPAccount = value
+ case "SMTPFrom":
+ common.SMTPFrom = value
+ case "SMTPToken":
+ common.SMTPToken = value
+ case "ServerAddress":
+ setting.ServerAddress = value
+ case "WorkerUrl":
+ setting.WorkerUrl = value
+ case "WorkerValidKey":
+ setting.WorkerValidKey = value
+ case "PayAddress":
+ setting.PayAddress = value
+ case "Chats":
+ err = setting.UpdateChatsByJsonString(value)
+ case "CustomCallbackAddress":
+ setting.CustomCallbackAddress = value
+ case "EpayId":
+ setting.EpayId = value
+ case "EpayKey":
+ setting.EpayKey = value
+ case "Price":
+ setting.Price, _ = strconv.ParseFloat(value, 64)
+ case "MinTopUp":
+ setting.MinTopUp, _ = strconv.Atoi(value)
+ case "TopupGroupRatio":
+ err = common.UpdateTopupGroupRatioByJSONString(value)
+ case "GitHubClientId":
+ common.GitHubClientId = value
+ case "GitHubClientSecret":
+ common.GitHubClientSecret = value
+ case "LinuxDOClientId":
+ common.LinuxDOClientId = value
+ case "LinuxDOClientSecret":
+ common.LinuxDOClientSecret = value
+ case "Footer":
+ common.Footer = value
+ case "SystemName":
+ common.SystemName = value
+ case "Logo":
+ common.Logo = value
+ case "WeChatServerAddress":
+ common.WeChatServerAddress = value
+ case "WeChatServerToken":
+ common.WeChatServerToken = value
+ case "WeChatAccountQRCodeImageURL":
+ common.WeChatAccountQRCodeImageURL = value
+ case "TelegramBotToken":
+ common.TelegramBotToken = value
+ case "TelegramBotName":
+ common.TelegramBotName = value
+ case "TurnstileSiteKey":
+ common.TurnstileSiteKey = value
+ case "TurnstileSecretKey":
+ common.TurnstileSecretKey = value
+ case "QuotaForNewUser":
+ common.QuotaForNewUser, _ = strconv.Atoi(value)
+ case "QuotaForInviter":
+ common.QuotaForInviter, _ = strconv.Atoi(value)
+ case "QuotaForInvitee":
+ common.QuotaForInvitee, _ = strconv.Atoi(value)
+ case "QuotaRemindThreshold":
+ common.QuotaRemindThreshold, _ = strconv.Atoi(value)
+ case "PreConsumedQuota":
+ common.PreConsumedQuota, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitCount":
+ setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitDurationMinutes":
+ setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitSuccessCount":
+ setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
+ case "RetryTimes":
+ common.RetryTimes, _ = strconv.Atoi(value)
+ case "DataExportInterval":
+ common.DataExportInterval, _ = strconv.Atoi(value)
+ case "DataExportDefaultTime":
+ common.DataExportDefaultTime = value
+ case "ModelRatio":
+ err = operation_setting.UpdateModelRatioByJSONString(value)
+ case "GroupRatio":
+ err = setting.UpdateGroupRatioByJSONString(value)
+ case "UserUsableGroups":
+ err = setting.UpdateUserUsableGroupsByJSONString(value)
+ case "CompletionRatio":
+ err = operation_setting.UpdateCompletionRatioByJSONString(value)
+ case "ModelPrice":
+ err = operation_setting.UpdateModelPriceByJSONString(value)
+ case "CacheRatio":
+ err = operation_setting.UpdateCacheRatioByJSONString(value)
+ case "TopUpLink":
+ common.TopUpLink = value
+ //case "ChatLink":
+ // common.ChatLink = value
+ //case "ChatLink2":
+ // common.ChatLink2 = value
+ case "ChannelDisableThreshold":
+ common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
+ case "QuotaPerUnit":
+ common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
+ case "SensitiveWords":
+ setting.SensitiveWordsFromString(value)
+ case "AutomaticDisableKeywords":
+ operation_setting.AutomaticDisableKeywordsFromString(value)
+ case "StreamCacheQueueLength":
+ setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
+ }
+ return err
+}
+
+// handleConfigUpdate 处理分层配置更新,返回是否已处理
+func handleConfigUpdate(key, value string) bool {
+ parts := strings.SplitN(key, ".", 2)
+ if len(parts) != 2 {
+ return false // 不是分层配置
+ }
+
+ configName := parts[0]
+ configKey := parts[1]
+
+ // 获取配置对象
+ cfg := config.GlobalConfig.Get(configName)
+ if cfg == nil {
+ return false // 未注册的配置
+ }
+
+ // 更新配置
+ configMap := map[string]string{
+ configKey: value,
+ }
+ config.UpdateConfigFromMap(cfg, configMap)
+
+ return true // 已处理
+}
\ No newline at end of file
diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
index d4199ece..b0047b70 100644
--- a/middleware/model-rate-limit.go
+++ b/middleware/model-rate-limit.go
@@ -19,25 +19,20 @@ const (
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
-// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
- // 如果maxCount为0,表示不限制
if maxCount == 0 {
return true, nil
}
- // 获取当前计数
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
- // 如果未达到限制,允许请求
if length < int64(maxCount) {
return true, nil
}
- // 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
@@ -49,7 +44,6 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
if err != nil {
return false, err
}
- // 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
@@ -59,9 +53,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
return true, nil
}
-// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
- // 如果maxCount为0,不记录请求
if maxCount == 0 {
return
}
@@ -72,14 +64,12 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
}
-// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
- // 1. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
@@ -92,9 +82,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
return
}
- //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
totalKey := fmt.Sprintf("rateLimit:%s", userId)
- // 初始化
tb := limiter.New(ctx, rdb)
allowed, err = tb.Allow(
ctx,
@@ -114,17 +102,14 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
- // 4. 处理请求
c.Next()
- // 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
-// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
@@ -133,15 +118,12 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
- // 1. 检查总请求数限制(当totalMaxCount为0时跳过)
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
- // 2. 检查成功请求数限制
- // 使用一个临时key来检查限制,这样可以避免实际记录
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
@@ -149,54 +131,47 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
return
}
- // 3. 处理请求
c.Next()
- // 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
-// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) {
return func(c *gin.Context) {
- // 在每个请求时检查是否启用限流
if !setting.ModelRequestRateLimitEnabled {
c.Next()
return
}
- // 计算通用限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
- // 获取用户组
group := c.GetString("token_group")
if group == "" {
group = c.GetString("group")
}
if group == "" {
- group = "default" // 默认组
+ group = "default"
}
- // 尝试获取用户组特定的限制
+ finalTotalCount := setting.ModelRequestRateLimitCount
+ finalSuccessCount := setting.ModelRequestRateLimitSuccessCount
+ foundGroupLimit := false
+
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
-
- // 确定最终的限制值
- finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制
- finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制
-
if found {
- // 如果找到用户组特定限制,则使用它们
finalTotalCount = groupTotalCount
finalSuccessCount = groupSuccessCount
+ foundGroupLimit = true
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
- } else {
+ }
+
+ if !foundGroupLimit {
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
}
- // 根据存储类型选择并执行限流处理器,传入最终确定的限制值
if common.RedisEnabled {
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
} else {
diff --git a/model/option.go b/model/option.go
index 1f5fb3aa..79556737 100644
--- a/model/option.go
+++ b/model/option.go
@@ -94,11 +94,12 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
+ jsonBytes, _ := json.Marshal(map[string][2]int{})
+ common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes)
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
- common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
@@ -153,47 +154,31 @@ func SyncOptions(frequency int) {
}
func UpdateOption(key string, value string) error {
- originalValue := value // 保存原始值以备后用
+ originalValue := value
- // Validate and format specific keys before saving
- if key == setting.ModelRequestRateLimitGroupKey {
+ if key == "ModelRequestRateLimitGroup" {
var cfg map[string][2]int
- // Validate the JSON structure first using the original value
err := json.Unmarshal([]byte(originalValue), &cfg)
if err != nil {
- // 提供更具体的错误信息
return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
}
- // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。
- // if !isValidModelRequestRateLimitGroupConfig(cfg) {
- // return fmt.Errorf("无效的配置值 for %s", key)
- // }
- // If valid, format the JSON before saving
formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ")
if marshalErr != nil {
- // This should ideally not happen if validation passed, but handle defensively
return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
}
- value = string(formattedValueBytes) // Use formatted JSON for saving and memory update
+ value = string(formattedValueBytes)
}
- // Save to database
option := Option{
Key: key,
}
- // https://gorm.io/docs/update.html#Save-All-Fields
DB.FirstOrCreate(&option, Option{Key: key})
option.Value = value
- // Save is a combination function.
- // If save value does not contain primary key, it will execute Create,
- // otherwise it will execute Update (with all fields).
if err := DB.Save(&option).Error; err != nil {
- return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文
+ return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err)
}
- // Update OptionMap in memory using the potentially formatted value
- // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新
return updateOptionMap(key, value)
}
@@ -370,6 +355,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitGroup":
+ err = setting.UpdateModelRequestRateLimitGroup(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
@@ -404,15 +391,6 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
- case setting.ModelRequestRateLimitGroupKey:
- // Use the (potentially formatted) value passed from UpdateOption
- // to update the actual configuration in memory.
- // This is the single point where the memory state for this specific setting is updated.
- err = setting.UpdateModelRequestRateLimitGroupConfig(value)
- if err != nil {
- // 添加错误上下文
- err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err)
- }
}
return err
}
@@ -440,4 +418,4 @@ func handleConfigUpdate(key, value string) bool {
config.UpdateConfigFromMap(cfg, configMap)
return true // 已处理
-}
+}
\ No newline at end of file
diff --git a/setting/rate_limit.go b/setting/rate_limit.go
index c83885a6..5be75cc1 100644
--- a/setting/rate_limit.go
+++ b/setting/rate_limit.go
@@ -11,24 +11,17 @@ var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000
+var ModelRequestRateLimitGroup map[string][2]int
-// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键
-const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup"
-
-// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置
-// map[groupName][2]int{totalCount, successCount}
-var ModelRequestRateLimitGroupConfig map[string][2]int
var ModelRequestRateLimitGroupMutex sync.RWMutex
-// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置
-func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
+func UpdateModelRequestRateLimitGroup(jsonStr string) error {
ModelRequestRateLimitGroupMutex.Lock()
defer ModelRequestRateLimitGroupMutex.Unlock()
var newConfig map[string][2]int
if jsonStr == "" || jsonStr == "{}" {
- // 如果配置为空或空JSON对象,则清空内存配置
- ModelRequestRateLimitGroupConfig = make(map[string][2]int)
+ ModelRequestRateLimitGroup = make(map[string][2]int)
common.SysLog("Model request rate limit group config cleared")
return nil
}
@@ -38,37 +31,19 @@ func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
}
- // 校验配置值
- for group, limits := range newConfig {
- if len(limits) != 2 {
- return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group)
- }
- if limits[1] <= 0 { // successCount must be greater than 0
- return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group)
- }
- if limits[0] < 0 { // totalCount can be 0 (no limit) or positive
- return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group)
- }
- if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount
- return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group)
- }
- }
-
- ModelRequestRateLimitGroupConfig = newConfig
- common.SysLog("Model request rate limit group config updated")
+ ModelRequestRateLimitGroup = newConfig
return nil
}
-// GetGroupRateLimit 安全地获取指定用户组的速率限制值
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
ModelRequestRateLimitGroupMutex.RLock()
defer ModelRequestRateLimitGroupMutex.RUnlock()
- if ModelRequestRateLimitGroupConfig == nil {
- return 0, 0, false // 配置尚未初始化
+ if ModelRequestRateLimitGroup == nil {
+ return 0, 0, false
}
- limits, found := ModelRequestRateLimitGroupConfig[group]
+ limits, found := ModelRequestRateLimitGroup[group]
if !found {
return 0, 0, false
}
diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js
index ad6b53da..7e206672 100644
--- a/web/src/components/RateLimitSetting.js
+++ b/web/src/components/RateLimitSetting.js
@@ -9,59 +9,59 @@ import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimi
const RateLimitSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
- ModelRequestRateLimitEnabled: false,
- ModelRequestRateLimitCount: 0,
- ModelRequestRateLimitSuccessCount: 1000,
- ModelRequestRateLimitDurationMinutes: 1,
- ModelRequestRateLimitGroup: {},
+ ModelRequestRateLimitEnabled: false,
+ ModelRequestRateLimitCount: 0,
+ ModelRequestRateLimitSuccessCount: 1000,
+ ModelRequestRateLimitDurationMinutes: 1,
+ ModelRequestRateLimitGroup: '{}',
});
-
+
let [loading, setLoading] = useState(false);
-
+
const getOptions = async () => {
- const res = await API.get('/api/option/');
- const { success, message, data } = res.data;
- if (success) {
- let newInputs = {};
- data.forEach((item) => {
- if (item.key.endsWith('Enabled')) {
- newInputs[item.key] = item.value === 'true' ? true : false;
- } else {
- newInputs[item.key] = item.value;
- }
- });
-
- setInputs(newInputs);
- } else {
- showError(message);
- }
+ const res = await API.get('/api/option/');
+ const { success, message, data } = res.data;
+ if (success) {
+ let newInputs = {};
+ data.forEach((item) => {
+ // 检查 key 是否在初始 inputs 中定义
+ if (Object.prototype.hasOwnProperty.call(inputs, item.key)) {
+ if (item.key.endsWith('Enabled')) {
+ newInputs[item.key] = item.value === 'true';
+ } else {
+ newInputs[item.key] = item.value;
+ }
+ }
+ });
+ setInputs(newInputs);
+ } else {
+ showError(message);
+ }
};
async function onRefresh() {
- try {
- setLoading(true);
- await getOptions();
- // showSuccess('刷新成功');
- } catch (error) {
- showError('刷新失败');
- } finally {
- setLoading(false);
- }
+ try {
+ setLoading(true);
+ await getOptions();
+ } catch (error) {
+ showError('刷新失败');
+ } finally {
+ setLoading(false);
+ }
}
-
+
useEffect(() => {
- onRefresh();
+ onRefresh();
}, []);
-
+
return (
- <>
-
- {/* AI请求速率限制 */}
-
-
-
-
- >
+ <>
+
+
+
+
+
+ >
);
-};
-
-export default RateLimitSetting;
+ };
+
+ export default RateLimitSetting;
diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
index ec1c2158..2434020e 100644
--- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
+++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
@@ -14,209 +14,201 @@ export default function RequestRateLimit(props) {
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
- ModelRequestRateLimitEnabled: false,
- ModelRequestRateLimitCount: -1,
- ModelRequestRateLimitSuccessCount: 1000,
- ModelRequestRateLimitDurationMinutes: 1,
- ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值
+ ModelRequestRateLimitEnabled: false,
+ ModelRequestRateLimitCount: -1,
+ ModelRequestRateLimitSuccessCount: 1000,
+ ModelRequestRateLimitDurationMinutes: 1,
+ ModelRequestRateLimitGroup: '{}',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
-
+
function onSubmit() {
- const updateArray = compareObjects(inputs, inputsRow);
- if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
- const requestQueue = updateArray.map((item) => {
- let value = '';
- if (typeof inputs[item.key] === 'boolean') {
- value = String(inputs[item.key]);
- } else {
- value = inputs[item.key];
- }
- // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串
- if (item.key === 'ModelRequestRateLimitGroup') {
- try {
- JSON.parse(value);
- } catch (e) {
- showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
- // 阻止请求发送
- return Promise.reject('Invalid JSON format');
- }
- }
- return API.put('/api/option/', {
- key: item.key,
- value,
- });
- });
-
- // 过滤掉无效的请求(例如,无效的 JSON)
- const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
-
- if (validRequests.length === 0 && requestQueue.length > 0) {
- // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行
- return;
- }
-
- setLoading(true);
- Promise.all(validRequests)
- .then((res) => {
- if (validRequests.length === 1) {
- if (res.includes(undefined)) return;
- } else if (validRequests.length > 1) {
- if (res.includes(undefined))
- return showError(t('部分保存失败,请重试'));
- }
- showSuccess(t('保存成功'));
- props.refresh();
- // 更新 inputsRow 以反映保存后的状态
- setInputsRow(structuredClone(inputs));
- })
- .catch((error) => {
- // 检查是否是由于无效 JSON 导致的错误
- if (error !== 'Invalid JSON format') {
- showError(t('保存失败,请重试'));
- }
- })
- .finally(() => {
- setLoading(false);
- });
+ const updateArray = compareObjects(inputs, inputsRow);
+ if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
+ const requestQueue = updateArray.map((item) => {
+ let value = '';
+ if (typeof inputs[item.key] === 'boolean') {
+ value = String(inputs[item.key]);
+ } else {
+ value = inputs[item.key];
+ }
+ if (item.key === 'ModelRequestRateLimitGroup') {
+ try {
+ JSON.parse(value);
+ } catch (e) {
+ showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
+ return Promise.reject('Invalid JSON format');
+ }
+ }
+ return API.put('/api/option/', {
+ key: item.key,
+ value,
+ });
+ });
+
+ const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
+
+ if (validRequests.length === 0 && requestQueue.length > 0) {
+ return;
+ }
+
+ setLoading(true);
+ Promise.all(validRequests)
+ .then((res) => {
+ if (validRequests.length === 1) {
+ if (res.includes(undefined)) return;
+ } else if (validRequests.length > 1) {
+ if (res.includes(undefined))
+ return showError(t('部分保存失败,请重试'));
+ }
+ showSuccess(t('保存成功'));
+ props.refresh();
+ setInputsRow(structuredClone(inputs));
+ })
+ .catch((error) => {
+ if (error !== 'Invalid JSON format') {
+ showError(t('保存失败,请重试'));
+ }
+ })
+ .finally(() => {
+ setLoading(false);
+ });
}
-
+
useEffect(() => {
- const currentInputs = {};
- for (let key in props.options) {
- if (Object.keys(inputs).includes(key)) {
- currentInputs[key] = props.options[key];
- }
- }
- setInputs(currentInputs);
- setInputsRow(structuredClone(currentInputs));
- // 检查 refForm.current 是否存在
- if (refForm.current) {
- refForm.current.setValues(currentInputs);
- }
- }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定
-
+ const currentInputs = {};
+ for (let key in props.options) {
+ if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查
+ currentInputs[key] = props.options[key];
+ }
+ }
+ setInputs(currentInputs);
+ setInputsRow(structuredClone(currentInputs));
+ if (refForm.current) {
+ refForm.current.setValues(currentInputs);
+ }
+ }, [props.options]);
+
return (
- <>
-
-
-
-
- {
- setInputs({
- ...inputs,
- ModelRequestRateLimitEnabled: value,
- });
- }}
- />
-
-
-
-
-
- setInputs({
- ...inputs,
- ModelRequestRateLimitDurationMinutes: String(value),
- })
- }
- />
-
-
-
-
-
- setInputs({
- ...inputs,
- ModelRequestRateLimitCount: String(value),
- })
- }
- />
-
-
-
- setInputs({
- ...inputs,
- ModelRequestRateLimitSuccessCount: String(value),
- })
- }
- />
-
-
- {/* 用户组速率限制配置项 */}
-
-
-
- {t('说明:')}
-
- - {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
- - {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
- - {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
- - {t('此配置将优先于上方的全局限制设置。')}
- - {t('未在此处配置的用户组将使用全局限制。')}
- - {t('限制周期统一使用上方配置的“限制周期”值。')}
- - {t('输入无效的 JSON 将无法保存。')}
-
-
- }
- autosize={{ minRows: 5, maxRows: 15 }}
- style={{ fontFamily: 'monospace' }}
- onChange={(value) => {
- setInputs({
- ...inputs,
- ModelRequestRateLimitGroup: value, // 直接更新字符串值
- });
- }}
- />
-
-
-
-
-
-
-
-
- >
+ <>
+
+
+
+
+ {
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitEnabled: value,
+ });
+ }}
+ />
+
+
+
+
+
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitDurationMinutes: String(value),
+ })
+ }
+ />
+
+
+
+
+
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitCount: String(value),
+ })
+ }
+ />
+
+
+
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitSuccessCount: String(value),
+ })
+ }
+ />
+
+
+
+
+
+ {t('说明:')}
+
+ - {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
+ - {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
+ - {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
+ - {t('此配置将优先于上方的全局限制设置。')}
+ - {t('未在此处配置的用户组将使用全局限制。')}
+ - {t('限制周期统一使用上方配置的“限制周期”值。')}
+ - {t('输入无效的 JSON 将无法保存。')}
+
+
+ }
+ autosize={{ minRows: 5, maxRows: 15 }}
+ style={{ fontFamily: 'monospace' }}
+ onChange={(value) => {
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitGroup: value,
+ });
+ }}
+ />
+
+
+
+
+
+
+
+
+ >
);
-}
+ }