feat: 优化代码,去除多余注释和修改

This commit is contained in:
tbphp
2025-05-05 11:34:57 +08:00
parent 6c3fb7777e
commit 7e7d6112ca
6 changed files with 663 additions and 341 deletions

402
.old/option.go Normal file
View File

@@ -0,0 +1,402 @@
package model
import (
"one-api/common"
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
"strconv"
"strings"
"time"
)
type Option struct {
Key string `json:"key" gorm:"primaryKey"`
Value string `json:"value"`
}
func AllOption() ([]*Option, error) {
var options []*Option
var err error
err = DB.Find(&options).Error
return options, err
}
func InitOptionMap() {
common.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string)
// 添加原有的系统配置
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
common.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPToken"] = ""
common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
common.OptionMap["Notice"] = ""
common.OptionMap["About"] = ""
common.OptionMap["HomePageContent"] = ""
common.OptionMap["Footer"] = common.Footer
common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
common.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled)
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled)
common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled)
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
for k, v := range modelConfigs {
common.OptionMap[k] = v
}
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
err := updateOptionMap(option.Key, option.Value)
if err != nil {
common.SysError("failed to update option map: " + err.Error())
}
}
}
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing options from database")
loadOptionsFromDatabase()
}
}
func UpdateOption(key string, value string) error {
// Save to database first
option := Option{
Key: key,
}
// https://gorm.io/docs/update.html#Save-All-Fields
DB.FirstOrCreate(&option, Option{Key: key})
option.Value = value
// Save is a combination function.
// If save value does not contain primary key, it will execute Create,
// otherwise it will execute Update (with all fields).
DB.Save(&option)
// Update OptionMap
return updateOptionMap(key, value)
}
func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value
// 检查是否是模型配置 - 使用更规范的方式处理
if handleConfigUpdate(key, value) {
return nil // 已由配置系统处理
}
// 处理传统配置项...
if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value)
switch key {
case "FileUploadPermission":
common.FileUploadPermission = intValue
case "FileDownloadPermission":
common.FileDownloadPermission = intValue
case "ImageUploadPermission":
common.ImageUploadPermission = intValue
case "ImageDownloadPermission":
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
common.PasswordRegisterEnabled = boolValue
case "PasswordLoginEnabled":
common.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled":
common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
case "LinuxDOOAuthEnabled":
common.LinuxDOOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled":
common.TelegramOAuthEnabled = boolValue
case "TurnstileCheckEnabled":
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled":
common.EmailAliasRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
common.AutomaticEnableChannelEnabled = boolValue
case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled":
common.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled":
common.DisplayTokenStatEnabled = boolValue
case "DrawingEnabled":
common.DrawingEnabled = boolValue
case "TaskEnabled":
common.TaskEnabled = boolValue
case "DataExportEnabled":
common.DataExportEnabled = boolValue
case "DefaultCollapseSidebar":
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
setting.MjNotifyEnabled = boolValue
case "MjAccountFilterEnabled":
setting.MjAccountFilterEnabled = boolValue
case "MjModeClearEnabled":
setting.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":
setting.MjForwardUrlEnabled = boolValue
case "MjActionCheckSuccessEnabled":
setting.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled":
setting.CheckSensitiveEnabled = boolValue
case "DemoSiteEnabled":
operation_setting.DemoSiteEnabled = boolValue
case "SelfUseModeEnabled":
operation_setting.SelfUseModeEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue
case "ModelRequestRateLimitEnabled":
setting.ModelRequestRateLimitEnabled = boolValue
case "StopOnSensitiveEnabled":
setting.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue
}
}
switch key {
case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer":
common.SMTPServer = value
case "SMTPPort":
intValue, _ := strconv.Atoi(value)
common.SMTPPort = intValue
case "SMTPAccount":
common.SMTPAccount = value
case "SMTPFrom":
common.SMTPFrom = value
case "SMTPToken":
common.SMTPToken = value
case "ServerAddress":
setting.ServerAddress = value
case "WorkerUrl":
setting.WorkerUrl = value
case "WorkerValidKey":
setting.WorkerValidKey = value
case "PayAddress":
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
setting.EpayId = value
case "EpayKey":
setting.EpayKey = value
case "Price":
setting.Price, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
setting.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
common.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
case "LinuxDOClientId":
common.LinuxDOClientId = value
case "LinuxDOClientSecret":
common.LinuxDOClientSecret = value
case "Footer":
common.Footer = value
case "SystemName":
common.SystemName = value
case "Logo":
common.Logo = value
case "WeChatServerAddress":
common.WeChatServerAddress = value
case "WeChatServerToken":
common.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL":
common.WeChatAccountQRCodeImageURL = value
case "TelegramBotToken":
common.TelegramBotToken = value
case "TelegramBotName":
common.TelegramBotName = value
case "TurnstileSiteKey":
common.TurnstileSiteKey = value
case "TurnstileSecretKey":
common.TurnstileSecretKey = value
case "QuotaForNewUser":
common.QuotaForNewUser, _ = strconv.Atoi(value)
case "QuotaForInviter":
common.QuotaForInviter, _ = strconv.Atoi(value)
case "QuotaForInvitee":
common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRequestRateLimitCount":
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitDurationMinutes":
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
common.DataExportInterval, _ = strconv.Atoi(value)
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = operation_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = setting.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = operation_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = operation_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = operation_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
// common.ChatLink = value
//case "ChatLink2":
// common.ChatLink2 = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
return err
}
// handleConfigUpdate 处理分层配置更新,返回是否已处理
func handleConfigUpdate(key, value string) bool {
parts := strings.SplitN(key, ".", 2)
if len(parts) != 2 {
return false // 不是分层配置
}
configName := parts[0]
configKey := parts[1]
// 获取配置对象
cfg := config.GlobalConfig.Get(configName)
if cfg == nil {
return false // 未注册的配置
}
// 更新配置
configMap := map[string]string{
configKey: value,
}
config.UpdateConfigFromMap(cfg, configMap)
return true // 已处理
}

View File

@@ -19,25 +19,20 @@ const (
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
// 如果maxCount为0表示不限制
if maxCount == 0 {
return true, nil
}
// 获取当前计数
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
// 如果未达到限制,允许请求
if length < int64(maxCount) {
return true, nil
}
// 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
@@ -49,7 +44,6 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
if err != nil {
return false, err
}
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
@@ -59,9 +53,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
return true, nil
}
// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
// 如果maxCount为0不记录请求
if maxCount == 0 {
return
}
@@ -72,14 +64,12 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
}
// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
// 1. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
@@ -92,9 +82,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
return
}
//2.检查总请求数限制并记录总请求当totalMaxCount为0时会自动跳过使用令牌桶限流器
totalKey := fmt.Sprintf("rateLimit:%s", userId)
// 初始化
tb := limiter.New(ctx, rdb)
allowed, err = tb.Allow(
ctx,
@@ -114,17 +102,14 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 4. 处理请求
c.Next()
// 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
@@ -133,15 +118,12 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制当totalMaxCount为0时跳过
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制这样可以避免实际记录
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
@@ -149,54 +131,47 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
return
}
// 3. 处理请求
c.Next()
// 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) {
return func(c *gin.Context) {
// 在每个请求时检查是否启用限流
if !setting.ModelRequestRateLimitEnabled {
c.Next()
return
}
// 计算通用限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
// 获取用户组
group := c.GetString("token_group")
if group == "" {
group = c.GetString("group")
}
if group == "" {
group = "default" // 默认组
group = "default"
}
// 尝试获取用户组特定的限制
finalTotalCount := setting.ModelRequestRateLimitCount
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount
foundGroupLimit := false
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
// 确定最终的限制值
finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制
if found {
// 如果找到用户组特定限制,则使用它们
finalTotalCount = groupTotalCount
finalSuccessCount = groupSuccessCount
foundGroupLimit = true
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
} else {
}
if !foundGroupLimit {
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
}
// 根据存储类型选择并执行限流处理器,传入最终确定的限制值
if common.RedisEnabled {
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
} else {

View File

@@ -94,11 +94,12 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
jsonBytes, _ := json.Marshal(map[string][2]int{})
common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes)
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
@@ -153,47 +154,31 @@ func SyncOptions(frequency int) {
}
func UpdateOption(key string, value string) error {
originalValue := value // 保存原始值以备后用
originalValue := value
// Validate and format specific keys before saving
if key == setting.ModelRequestRateLimitGroupKey {
if key == "ModelRequestRateLimitGroup" {
var cfg map[string][2]int
// Validate the JSON structure first using the original value
err := json.Unmarshal([]byte(originalValue), &cfg)
if err != nil {
// 提供更具体的错误信息
return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
}
// TODO: 可以添加更细致的结构验证例如检查数组长度是否为2值是否为非负数等。
// if !isValidModelRequestRateLimitGroupConfig(cfg) {
// return fmt.Errorf("无效的配置值 for %s", key)
// }
// If valid, format the JSON before saving
formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ")
if marshalErr != nil {
// This should ideally not happen if validation passed, but handle defensively
return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
}
value = string(formattedValueBytes) // Use formatted JSON for saving and memory update
value = string(formattedValueBytes)
}
// Save to database
option := Option{
Key: key,
}
// https://gorm.io/docs/update.html#Save-All-Fields
DB.FirstOrCreate(&option, Option{Key: key})
option.Value = value
// Save is a combination function.
// If save value does not contain primary key, it will execute Create,
// otherwise it will execute Update (with all fields).
if err := DB.Save(&option).Error; err != nil {
return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文
return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err)
}
// Update OptionMap in memory using the potentially formatted value
// updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新
return updateOptionMap(key, value)
}
@@ -370,6 +355,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitGroup":
err = setting.UpdateModelRequestRateLimitGroup(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
@@ -404,15 +391,6 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case setting.ModelRequestRateLimitGroupKey:
// Use the (potentially formatted) value passed from UpdateOption
// to update the actual configuration in memory.
// This is the single point where the memory state for this specific setting is updated.
err = setting.UpdateModelRequestRateLimitGroupConfig(value)
if err != nil {
// 添加错误上下文
err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err)
}
}
return err
}
@@ -440,4 +418,4 @@ func handleConfigUpdate(key, value string) bool {
config.UpdateConfigFromMap(cfg, configMap)
return true // 已处理
}
}

View File

@@ -11,24 +11,17 @@ var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000
var ModelRequestRateLimitGroup map[string][2]int
// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键
const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup"
// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置
// map[groupName][2]int{totalCount, successCount}
var ModelRequestRateLimitGroupConfig map[string][2]int
var ModelRequestRateLimitGroupMutex sync.RWMutex
// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置
func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
func UpdateModelRequestRateLimitGroup(jsonStr string) error {
ModelRequestRateLimitGroupMutex.Lock()
defer ModelRequestRateLimitGroupMutex.Unlock()
var newConfig map[string][2]int
if jsonStr == "" || jsonStr == "{}" {
// 如果配置为空或空JSON对象则清空内存配置
ModelRequestRateLimitGroupConfig = make(map[string][2]int)
ModelRequestRateLimitGroup = make(map[string][2]int)
common.SysLog("Model request rate limit group config cleared")
return nil
}
@@ -38,37 +31,19 @@ func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
}
// 校验配置值
for group, limits := range newConfig {
if len(limits) != 2 {
return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group)
}
if limits[1] <= 0 { // successCount must be greater than 0
return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group)
}
if limits[0] < 0 { // totalCount can be 0 (no limit) or positive
return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group)
}
if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount
return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group)
}
}
ModelRequestRateLimitGroupConfig = newConfig
common.SysLog("Model request rate limit group config updated")
ModelRequestRateLimitGroup = newConfig
return nil
}
// GetGroupRateLimit 安全地获取指定用户组的速率限制值
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
ModelRequestRateLimitGroupMutex.RLock()
defer ModelRequestRateLimitGroupMutex.RUnlock()
if ModelRequestRateLimitGroupConfig == nil {
return 0, 0, false // 配置尚未初始化
if ModelRequestRateLimitGroup == nil {
return 0, 0, false
}
limits, found := ModelRequestRateLimitGroupConfig[group]
limits, found := ModelRequestRateLimitGroup[group]
if !found {
return 0, 0, false
}

View File

@@ -9,59 +9,59 @@ import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimi
const RateLimitSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
ModelRequestRateLimitGroup: {},
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
ModelRequestRateLimitGroup: '{}',
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key.endsWith('Enabled')) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
// 检查 key 是否在初始 inputs 中定义
if (Object.prototype.hasOwnProperty.call(inputs, item.key)) {
if (item.key.endsWith('Enabled')) {
newInputs[item.key] = item.value === 'true';
} else {
newInputs[item.key] = item.value;
}
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
try {
setLoading(true);
await getOptions();
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* AI请求速率限制 */}
<Card style={{ marginTop: '10px' }}>
<RequestRateLimit options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
<>
<Spin spinning={loading} size='large'>
<Card style={{ marginTop: '10px' }}>
<RequestRateLimit options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default RateLimitSetting;
};
export default RateLimitSetting;

View File

@@ -14,209 +14,201 @@ export default function RequestRateLimit(props) {
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
ModelRequestRateLimitGroup: '{}',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
// 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串
if (item.key === 'ModelRequestRateLimitGroup') {
try {
JSON.parse(value);
} catch (e) {
showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
// 阻止请求发送
return Promise.reject('Invalid JSON format');
}
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
// 过滤掉无效的请求(例如,无效的 JSON
const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
if (validRequests.length === 0 && requestQueue.length > 0) {
// 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行
return;
}
setLoading(true);
Promise.all(validRequests)
.then((res) => {
if (validRequests.length === 1) {
if (res.includes(undefined)) return;
} else if (validRequests.length > 1) {
if (res.includes(undefined))
return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
// 更新 inputsRow 以反映保存后的状态
setInputsRow(structuredClone(inputs));
})
.catch((error) => {
// 检查是否是由于无效 JSON 导致的错误
if (error !== 'Invalid JSON format') {
showError(t('保存失败,请重试'));
}
})
.finally(() => {
setLoading(false);
});
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
if (item.key === 'ModelRequestRateLimitGroup') {
try {
JSON.parse(value);
} catch (e) {
showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
return Promise.reject('Invalid JSON format');
}
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
if (validRequests.length === 0 && requestQueue.length > 0) {
return;
}
setLoading(true);
Promise.all(validRequests)
.then((res) => {
if (validRequests.length === 1) {
if (res.includes(undefined)) return;
} else if (validRequests.length > 1) {
if (res.includes(undefined))
return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
setInputsRow(structuredClone(inputs));
})
.catch((error) => {
if (error !== 'Invalid JSON format') {
showError(t('保存失败,请重试'));
}
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
// 检查 refForm.current 是否存在
if (refForm.current) {
refForm.current.setValues(currentInputs);
}
}, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定
const currentInputs = {};
for (let key in props.options) {
if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
if (refForm.current) {
refForm.current.setValues(currentInputs);
}
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('模型请求速率限制')}>
<Row gutter={16}>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.Switch
field={'ModelRequestRateLimitEnabled'}
label={t('启用用户模型请求速率限制(可能会影响高并发性能)')}
size='default'
checkedText=''
uncheckedText=''
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitEnabled: value,
});
}}
/>
</Col>
</Row>
<Row>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('限制周期')}
step={1}
min={0}
suffix={t('分钟')}
extraText={t('频率限制的周期(分钟)')}
field={'ModelRequestRateLimitDurationMinutes'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitDurationMinutes: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('用户每周期最多请求次数')}
step={1}
min={0}
suffix={t('次')}
extraText={t('包括失败请求的次数0代表不限制')}
field={'ModelRequestRateLimitCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitCount: String(value),
})
}
/>
</Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('用户每周期最多请求完成次数')}
step={1}
min={1}
suffix={t('次')}
extraText={t('只包括请求成功的次数')}
field={'ModelRequestRateLimitSuccessCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitSuccessCount: String(value),
})
}
/>
</Col>
</Row>
{/* 用户组速率限制配置项 */}
<Row>
<Col span={24}>
<Form.TextArea
label={t('用户组速率限制 (JSON)')}
field={'ModelRequestRateLimitGroup'}
placeholder={t( // 更新 placeholder
'请输入 JSON 格式的用户组限制,例如:\n{\n "default": [200, 100],\n "vip": [1000, 500]\n}',
)}
extraText={ // 更新 extraText
<div>
<p>{t('说明:')}</p>
<ul>
<li>{t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}</li>
<li>{t('次数限制: 周期内允许的请求次数 (含失败)0 代表不限制。')}</li>
<li>{t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}</li>
<li>{t('此配置将优先于上方的全局限制设置。')}</li>
<li>{t('未在此处配置的用户组将使用全局限制。')}</li>
<li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
<li>{t('输入无效的 JSON 将无法保存。')}</li>
</ul>
</div>
}
autosize={{ minRows: 5, maxRows: 15 }}
style={{ fontFamily: 'monospace' }}
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitGroup: value, // 直接更新字符串值
});
}}
/>
</Col>
</Row>
<Row style={{ marginTop: 15 }}>
<Button size='default' onClick={onSubmit}>
{t('保存模型速率限制')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('模型请求速率限制')}>
<Row gutter={16}>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.Switch
field={'ModelRequestRateLimitEnabled'}
label={t('启用用户模型请求速率限制(可能会影响高并发性能)')}
size='default'
checkedText=''
uncheckedText=''
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitEnabled: value,
});
}}
/>
</Col>
</Row>
<Row>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('限制周期')}
step={1}
min={0}
suffix={t('分钟')}
extraText={t('频率限制的周期(分钟)')}
field={'ModelRequestRateLimitDurationMinutes'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitDurationMinutes: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('用户每周期最多请求次数')}
step={1}
min={0}
suffix={t('次')}
extraText={t('包括失败请求的次数0代表不限制')}
field={'ModelRequestRateLimitCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitCount: String(value),
})
}
/>
</Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('用户每周期最多请求完成次数')}
step={1}
min={1}
suffix={t('次')}
extraText={t('只包括请求成功的次数')}
field={'ModelRequestRateLimitSuccessCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitSuccessCount: String(value),
})
}
/>
</Col>
</Row>
<Row style={{ marginTop: 15 }}>
<Col span={24}>
<Form.TextArea
label={t('用户组速率限制 (JSON)')}
field={'ModelRequestRateLimitGroup'}
placeholder={t(
'请输入 JSON 格式的用户组限制,例如:\n{\n "default": [200, 100],\n "vip": [1000, 500]\n}',
)}
extraText={
<div>
<p>{t('说明:')}</p>
<ul>
<li>{t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}</li>
<li>{t('总次数限制: 周期内允许的总请求次数 (含失败)0 代表不限制。')}</li>
<li>{t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}</li>
<li>{t('此配置将优先于上方的全局限制设置。')}</li>
<li>{t('未在此处配置的用户组将使用全局限制。')}</li>
<li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
<li>{t('输入无效的 JSON 将无法保存。')}</li>
</ul>
</div>
}
autosize={{ minRows: 5, maxRows: 15 }}
style={{ fontFamily: 'monospace' }}
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitGroup: value,
});
}}
/>
</Col>
</Row>
<Row style={{ marginTop: 15 }}>
<Button size='default' onClick={onSubmit}>
{t('保存模型速率限制')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}
}