Files
new-api/middleware/model-rate-limit.go

182 lines
5.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
"one-api/common/limiter"
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
const (
ModelRequestRateLimitCountMark = "MRRL"
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
if maxCount == 0 {
return true, nil
}
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
if length < int64(maxCount) {
return true, nil
}
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
return false, err
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
return false, err
}
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
return false, nil
}
return true, nil
}
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
if maxCount == 0 {
return
}
now := time.Now().Format(timeFormat)
rdb.LPush(ctx, key, now)
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
}
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
return
}
totalKey := fmt.Sprintf("rateLimit:%s", userId)
tb := limiter.New(ctx, rdb)
allowed, err = tb.Allow(
ctx,
totalKey,
limiter.WithCapacity(int64(totalMaxCount)*duration),
limiter.WithRate(int64(totalMaxCount)),
limiter.WithRequested(duration),
)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
c.Next()
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
c.Next()
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
func ModelRequestRateLimit() func(c *gin.Context) {
return func(c *gin.Context) {
if !setting.ModelRequestRateLimitEnabled {
c.Next()
return
}
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
group := c.GetString("token_group")
if group == "" {
group = c.GetString("group")
}
if group == "" {
group = "default"
}
finalTotalCount := setting.ModelRequestRateLimitCount
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount
foundGroupLimit := false
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
if found {
finalTotalCount = groupTotalCount
finalSuccessCount = groupSuccessCount
foundGroupLimit = true
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
}
if !foundGroupLimit {
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
}
if common.RedisEnabled {
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
} else {
memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
}
}
}