package middleware import ( "context" "fmt" "net/http" "one-api/common" "one-api/common/limiter" "one-api/setting" "strconv" "time" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" ) const ( ModelRequestRateLimitCountMark = "MRRL" ModelRequestRateLimitSuccessCountMark = "MRRLS" ) func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { if maxCount == 0 { return true, nil } length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } if length < int64(maxCount) { return true, nil } oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { return false, err } nowTimeStr := time.Now().Format(timeFormat) nowTime, err := time.Parse(timeFormat, nowTimeStr) if err != nil { return false, err } subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) return false, nil } return true, nil } func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { if maxCount == 0 { return } now := time.Now().Format(timeFormat) rdb.LPush(ctx, key, now) rdb.LTrim(ctx, key, 0, int64(maxCount-1)) rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { fmt.Println("检查成功请求数限制失败:", err.Error()) abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") return } if !allowed { abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount)) return } totalKey := fmt.Sprintf("rateLimit:%s", userId) tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, totalKey, limiter.WithCapacity(int64(totalMaxCount)*duration), limiter.WithRate(int64(totalMaxCount)), limiter.WithRequested(duration), ) if err != nil { fmt.Println("检查总请求数限制失败:", err.Error()) abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") return } if !allowed { abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } c.Next() if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } c.Next() if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { if !setting.ModelRequestRateLimitEnabled { c.Next() return } duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) group := c.GetString("token_group") if group == "" { group = c.GetString("group") } if group == "" { group = "default" } finalTotalCount := setting.ModelRequestRateLimitCount finalSuccessCount := setting.ModelRequestRateLimitSuccessCount foundGroupLimit := false groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) if found { finalTotalCount = groupTotalCount finalSuccessCount = groupSuccessCount foundGroupLimit = true common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } if !foundGroupLimit { common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } if common.RedisEnabled { redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } else { memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } } }