Reason: Steps 1 and 3 of the original redisRateLimitHandler were not atomic, so the limit was imprecise under heavy concurrency. For example, with the rate limit set to 60, sending 200 concurrent requests blocked none of them, while roughly 140 should have been intercepted. Solution: I decided against merging steps 1 and 3 into a single Lua script, because one atomic operation spanning reads, writes, and deletes could become a bottleneck under high concurrency. Instead, I switched to a token bucket algorithm, which shrinks the atomic operation to a single read-and-write step and also cuts the memory footprint significantly.
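For reference, here is a minimal, hypothetical in-memory sketch of the token-bucket accounting this change relies on. It assumes (not quoted from the one-api/common/limiter package) that WithCapacity sets the bucket size to totalMaxCount*duration tokens, WithRate refills totalMaxCount tokens per second, and WithRequested makes each request cost duration tokens; the real limiter performs the equivalent check-and-refill atomically in Redis, so treat this only as an illustration of the accounting, not as the actual implementation.

package main

import (
	"fmt"
	"time"
)

// bucket is a toy token bucket: tokens remaining plus the last refill time.
type bucket struct {
	tokens   float64
	capacity float64 // maximum number of tokens the bucket can hold
	rate     float64 // tokens added back per second
	last     time.Time
}

// allow refills the bucket for the elapsed time, then tries to take `requested` tokens.
func (b *bucket) allow(now time.Time, requested float64) bool {
	b.tokens += now.Sub(b.last).Seconds() * b.rate
	if b.tokens > b.capacity {
		b.tokens = b.capacity
	}
	b.last = now
	if b.tokens < requested {
		return false // not enough tokens left: reject
	}
	b.tokens -= requested
	return true
}

func main() {
	// Numbers from the example above: a limit of 60 requests per 60-second window.
	// capacity = 60*60 tokens, refill = 60 tokens/s, each request costs 60 tokens.
	const maxCount, duration = 60, 60
	b := &bucket{
		tokens:   maxCount * duration,
		capacity: maxCount * duration,
		rate:     maxCount,
		last:     time.Now(),
	}

	allowed := 0
	now := time.Now()
	for i := 0; i < 200; i++ { // 200 near-simultaneous requests
		if b.allow(now, duration) {
			allowed++
		}
	}
	fmt.Println("allowed:", allowed) // prints 60; the other 140 are rejected
}

Under these assumptions a full bucket admits exactly maxCount requests in one burst and then refills at maxCount requests per window, which is where the roughly 140 rejections in the example above come from.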
184 lines · 5.4 KiB · Go
package middleware

import (
	"context"
	"fmt"
	"net/http"
	"one-api/common"
	"one-api/common/limiter"
	"one-api/setting"
	"strconv"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/go-redis/redis/v8"
)

const (
	ModelRequestRateLimitCountMark        = "MRRL"
	ModelRequestRateLimitSuccessCountMark = "MRRLS"
)

// checkRedisRateLimit checks the request limit recorded in Redis
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
	// maxCount == 0 means no limit
	if maxCount == 0 {
		return true, nil
	}

	// get the current count
	length, err := rdb.LLen(ctx, key).Result()
	if err != nil {
		return false, err
	}

	// allow the request if the limit has not been reached
	if length < int64(maxCount) {
		return true, nil
	}

	// check the time window
	oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
	oldTime, err := time.Parse(timeFormat, oldTimeStr)
	if err != nil {
		return false, err
	}

	nowTimeStr := time.Now().Format(timeFormat)
	nowTime, err := time.Parse(timeFormat, nowTimeStr)
	if err != nil {
		return false, err
	}
	// reject the request if the limit has been reached within the time window
	subTime := nowTime.Sub(oldTime).Seconds()
	if int64(subTime) < duration {
		rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
		return false, nil
	}

	return true, nil
}

// recordRedisRequest records a request in Redis
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
	// maxCount == 0 means requests are not recorded
	if maxCount == 0 {
		return
	}

	now := time.Now().Format(timeFormat)
	rdb.LPush(ctx, key, now)
	rdb.LTrim(ctx, key, 0, int64(maxCount-1))
	rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
}

// redisRateLimitHandler is the Redis-based rate limit handler
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
	return func(c *gin.Context) {
		userId := strconv.Itoa(c.GetInt("id"))
		ctx := context.Background()
		rdb := common.RDB

		// 1. check the successful-request limit
		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
		allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
		if err != nil {
			fmt.Println("failed to check the successful-request limit:", err.Error())
			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
			return
		}
		if !allowed {
			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("You have reached the request limit: within %d minutes, at most %d requests are allowed", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
			return
		}
		// check the total-request limit and record the request (skipped automatically when totalMaxCount is 0); uses the token bucket limiter
		totalKey := fmt.Sprintf("rateLimit:%s", userId)
		//allowed, err = checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
		// initialize the token bucket limiter
		tb := limiter.New(ctx, rdb)
		allowed, err = tb.Allow(
			ctx,
			totalKey,
			limiter.WithCapacity(int64(totalMaxCount)*duration),
			limiter.WithRate(int64(totalMaxCount)),
			limiter.WithRequested(duration),
		)

		if err != nil {
			fmt.Println("failed to check the total-request limit:", err.Error())
			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
			return
		}
		if !allowed {
			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("You have reached the total request limit: within %d minutes, at most %d requests are allowed (including failed requests); please check that your requests are correct", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
			return
		}

		// 4. process the request
		c.Next()

		// 5. if the request succeeded, record a successful request
		if c.Writer.Status() < 400 {
			recordRedisRequest(ctx, rdb, successKey, successMaxCount)
		}
	}
}

// memoryRateLimitHandler is the in-memory rate limit handler
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
	inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)

	return func(c *gin.Context) {
		userId := strconv.Itoa(c.GetInt("id"))
		totalKey := ModelRequestRateLimitCountMark + userId
		successKey := ModelRequestRateLimitSuccessCountMark + userId

		// 1. check the total-request limit (skipped when totalMaxCount is 0)
		if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
			c.Status(http.StatusTooManyRequests)
			c.Abort()
			return
		}

		// 2. check the successful-request limit
		// use a temporary key for the check so the real success counter is not touched
		checkKey := successKey + "_check"
		if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
			c.Status(http.StatusTooManyRequests)
			c.Abort()
			return
		}

		// 3. process the request
		c.Next()

		// 4. if the request succeeded, record it in the real successful-request counter
		if c.Writer.Status() < 400 {
			inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
		}
	}
}

// ModelRequestRateLimit is the model request rate limit middleware
func ModelRequestRateLimit() func(c *gin.Context) {
	return func(c *gin.Context) {
		// check on every request whether rate limiting is enabled
		if !setting.ModelRequestRateLimitEnabled {
			c.Next()
			return
		}

		// compute the rate limit parameters
		duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
		totalMaxCount := setting.ModelRequestRateLimitCount
		successMaxCount := setting.ModelRequestRateLimitSuccessCount

		// choose and run the rate limit handler based on the storage type
		if common.RedisEnabled {
			redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
		} else {
			memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
		}
	}
}
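For completeness, below is a minimal sketch of how this middleware might be mounted on a gin engine. The /v1 group, the placeholder auth handler that sets "id", and the example endpoint are hypothetical and only illustrate that the user id must already be in the context before the limiter runs.

package main

import (
	"net/http"

	"github.com/gin-gonic/gin"

	"one-api/middleware"
)

func main() {
	r := gin.Default()

	// Hypothetical wiring: an auth middleware is assumed to have stored the user id
	// under "id", because the limiter keys its counters on c.GetInt("id").
	api := r.Group("/v1")
	api.Use(func(c *gin.Context) { c.Set("id", 1); c.Next() }) // placeholder auth
	api.Use(middleware.ModelRequestRateLimit())
	api.POST("/chat/completions", func(c *gin.Context) { // placeholder handler
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})

	_ = r.Run(":8080")
}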