refactor: 调整代码,符合项目现有规范
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/limiter"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -19,20 +20,25 @@ const (
|
||||
ModelRequestRateLimitSuccessCountMark = "MRRLS"
|
||||
)
|
||||
|
||||
// 检查Redis中的请求限制
|
||||
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
|
||||
// 如果maxCount为0,表示不限制
|
||||
if maxCount == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 获取当前计数
|
||||
length, err := rdb.LLen(ctx, key).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 如果未达到限制,允许请求
|
||||
if length < int64(maxCount) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 检查时间窗口
|
||||
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
||||
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
||||
if err != nil {
|
||||
@@ -44,6 +50,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// 如果在时间窗口内已达到限制,拒绝请求
|
||||
subTime := nowTime.Sub(oldTime).Seconds()
|
||||
if int64(subTime) < duration {
|
||||
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
|
||||
@@ -53,7 +60,9 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 记录Redis请求
|
||||
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
|
||||
// 如果maxCount为0,不记录请求
|
||||
if maxCount == 0 {
|
||||
return
|
||||
}
|
||||
@@ -64,12 +73,14 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC
|
||||
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
|
||||
}
|
||||
|
||||
// Redis限流处理器
|
||||
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userId := strconv.Itoa(c.GetInt("id"))
|
||||
ctx := context.Background()
|
||||
rdb := common.RDB
|
||||
|
||||
// 1. 检查成功请求数限制
|
||||
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
|
||||
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
|
||||
if err != nil {
|
||||
@@ -82,7 +93,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
||||
return
|
||||
}
|
||||
|
||||
//2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
|
||||
totalKey := fmt.Sprintf("rateLimit:%s", userId)
|
||||
// 初始化
|
||||
tb := limiter.New(ctx, rdb)
|
||||
allowed, err = tb.Allow(
|
||||
ctx,
|
||||
@@ -102,14 +115,17 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
||||
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
|
||||
}
|
||||
|
||||
// 4. 处理请求
|
||||
c.Next()
|
||||
|
||||
// 5. 如果请求成功,记录成功请求
|
||||
if c.Writer.Status() < 400 {
|
||||
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 内存限流处理器
|
||||
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
|
||||
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
|
||||
|
||||
@@ -118,12 +134,15 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
|
||||
totalKey := ModelRequestRateLimitCountMark + userId
|
||||
successKey := ModelRequestRateLimitSuccessCountMark + userId
|
||||
|
||||
// 1. 检查总请求数限制(当totalMaxCount为0时跳过)
|
||||
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
|
||||
c.Status(http.StatusTooManyRequests)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 检查成功请求数限制
|
||||
// 使用一个临时key来检查限制,这样可以避免实际记录
|
||||
checkKey := successKey + "_check"
|
||||
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
|
||||
c.Status(http.StatusTooManyRequests)
|
||||
@@ -131,51 +150,48 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 处理请求
|
||||
c.Next()
|
||||
|
||||
// 4. 如果请求成功,记录到实际的成功请求计数中
|
||||
if c.Writer.Status() < 400 {
|
||||
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ModelRequestRateLimit 模型请求限流中间件
|
||||
func ModelRequestRateLimit() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
// 在每个请求时检查是否启用限流
|
||||
if !setting.ModelRequestRateLimitEnabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 计算限流参数
|
||||
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
|
||||
totalMaxCount := setting.ModelRequestRateLimitCount
|
||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||
|
||||
// 获取分组
|
||||
group := c.GetString("token_group")
|
||||
if group == "" {
|
||||
group = c.GetString("group")
|
||||
}
|
||||
if group == "" {
|
||||
group = "default"
|
||||
group = c.GetString(constant.ContextKeyUserGroup)
|
||||
}
|
||||
|
||||
finalTotalCount := setting.ModelRequestRateLimitCount
|
||||
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount
|
||||
foundGroupLimit := false
|
||||
|
||||
//获取分组的限流配置
|
||||
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
|
||||
if found {
|
||||
finalTotalCount = groupTotalCount
|
||||
finalSuccessCount = groupSuccessCount
|
||||
foundGroupLimit = true
|
||||
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
|
||||
}
|
||||
|
||||
if !foundGroupLimit {
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
|
||||
totalMaxCount = groupTotalCount
|
||||
successMaxCount = groupSuccessCount
|
||||
}
|
||||
|
||||
// 根据存储类型选择并执行限流处理器
|
||||
if common.RedisEnabled {
|
||||
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
|
||||
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
||||
} else {
|
||||
memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
|
||||
memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
@@ -94,8 +92,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
|
||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||
jsonBytes, _ := json.Marshal(map[string][2]int{})
|
||||
common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes)
|
||||
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||
@@ -154,31 +151,18 @@ func SyncOptions(frequency int) {
|
||||
}
|
||||
|
||||
func UpdateOption(key string, value string) error {
|
||||
originalValue := value
|
||||
|
||||
if key == "ModelRequestRateLimitGroup" {
|
||||
var cfg map[string][2]int
|
||||
err := json.Unmarshal([]byte(originalValue), &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
|
||||
}
|
||||
|
||||
formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ")
|
||||
if marshalErr != nil {
|
||||
return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
|
||||
}
|
||||
value = string(formattedValueBytes)
|
||||
}
|
||||
|
||||
// Save to database first
|
||||
option := Option{
|
||||
Key: key,
|
||||
}
|
||||
// https://gorm.io/docs/update.html#Save-All-Fields
|
||||
DB.FirstOrCreate(&option, Option{Key: key})
|
||||
option.Value = value
|
||||
if err := DB.Save(&option).Error; err != nil {
|
||||
return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err)
|
||||
}
|
||||
|
||||
// Save is a combination function.
|
||||
// If save value does not contain primary key, it will execute Create,
|
||||
// otherwise it will execute Update (with all fields).
|
||||
DB.Save(&option)
|
||||
// Update OptionMap
|
||||
return updateOptionMap(key, value)
|
||||
}
|
||||
|
||||
@@ -356,7 +340,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "ModelRequestRateLimitSuccessCount":
|
||||
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
|
||||
case "ModelRequestRateLimitGroup":
|
||||
err = setting.UpdateModelRequestRateLimitGroup(value)
|
||||
err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
|
||||
case "RetryTimes":
|
||||
common.RetryTimes, _ = strconv.Atoi(value)
|
||||
case "DataExportInterval":
|
||||
|
||||
@@ -2,7 +2,6 @@ package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
)
|
||||
@@ -11,33 +10,31 @@ var ModelRequestRateLimitEnabled = false
|
||||
var ModelRequestRateLimitDurationMinutes = 1
|
||||
var ModelRequestRateLimitCount = 0
|
||||
var ModelRequestRateLimitSuccessCount = 1000
|
||||
var ModelRequestRateLimitGroup map[string][2]int
|
||||
var ModelRequestRateLimitGroup = map[string][2]int{}
|
||||
var ModelRequestRateLimitMutex sync.RWMutex
|
||||
|
||||
var ModelRequestRateLimitGroupMutex sync.RWMutex
|
||||
func ModelRequestRateLimitGroup2JSONString() string {
|
||||
ModelRequestRateLimitMutex.RLock()
|
||||
defer ModelRequestRateLimitMutex.RUnlock()
|
||||
|
||||
func UpdateModelRequestRateLimitGroup(jsonStr string) error {
|
||||
ModelRequestRateLimitGroupMutex.Lock()
|
||||
defer ModelRequestRateLimitGroupMutex.Unlock()
|
||||
|
||||
var newConfig map[string][2]int
|
||||
if jsonStr == "" || jsonStr == "{}" {
|
||||
ModelRequestRateLimitGroup = make(map[string][2]int)
|
||||
common.SysLog("Model request rate limit group config cleared")
|
||||
return nil
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(jsonStr), &newConfig)
|
||||
jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
|
||||
common.SysError("error marshalling model ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
ModelRequestRateLimitGroup = newConfig
|
||||
return nil
|
||||
func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error {
|
||||
ModelRequestRateLimitMutex.RLock()
|
||||
defer ModelRequestRateLimitMutex.RUnlock()
|
||||
|
||||
ModelRequestRateLimitGroup = make(map[string][2]int)
|
||||
return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup)
|
||||
}
|
||||
|
||||
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
|
||||
ModelRequestRateLimitGroupMutex.RLock()
|
||||
defer ModelRequestRateLimitGroupMutex.RUnlock()
|
||||
ModelRequestRateLimitMutex.RLock()
|
||||
defer ModelRequestRateLimitMutex.RUnlock()
|
||||
|
||||
if ModelRequestRateLimitGroup == nil {
|
||||
return 0, 0, false
|
||||
|
||||
@@ -24,7 +24,6 @@ const RateLimitSetting = () => {
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
// 检查 key 是否在初始 inputs 中定义
|
||||
if (Object.prototype.hasOwnProperty.call(inputs, item.key)) {
|
||||
if (item.key.endsWith('Enabled')) {
|
||||
newInputs[item.key] = item.value === 'true';
|
||||
@@ -33,6 +32,7 @@ const RateLimitSetting = () => {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
setInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -42,6 +42,7 @@ const RateLimitSetting = () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
await getOptions();
|
||||
// showSuccess('刷新成功');
|
||||
} catch (error) {
|
||||
showError('刷新失败');
|
||||
} finally {
|
||||
@@ -56,6 +57,7 @@ const RateLimitSetting = () => {
|
||||
return (
|
||||
<>
|
||||
<Spin spinning={loading} size='large'>
|
||||
{/* AI请求速率限制 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<RequestRateLimit options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
@@ -64,4 +66,4 @@ const RateLimitSetting = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default RateLimitSetting;
|
||||
export default RateLimitSetting;
|
||||
Reference in New Issue
Block a user