fix(限流): 原子化 Redis 限流并支持故障策略
使用 Lua 脚本原子设置计数与过期,修复 TTL 缺失\n支持 fail-open/fail-close 并对优惠码验证启用 fail-close\n新增单元与集成测试覆盖关键分支\n\n测试:go test ./...
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -8,6 +9,33 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RateLimitFailureMode Redis 故障策略
|
||||
type RateLimitFailureMode int
|
||||
|
||||
const (
|
||||
RateLimitFailOpen RateLimitFailureMode = iota
|
||||
RateLimitFailClose
|
||||
)
|
||||
|
||||
// RateLimitOptions 限流可选配置
|
||||
type RateLimitOptions struct {
|
||||
FailureMode RateLimitFailureMode
|
||||
}
|
||||
|
||||
var rateLimitScript = redis.NewScript(`
|
||||
local current = redis.call('INCR', KEYS[1])
|
||||
local ttl = redis.call('PTTL', KEYS[1])
|
||||
if current == 1 or ttl == -1 then
|
||||
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||
end
|
||||
return current
|
||||
`)
|
||||
|
||||
// rateLimitRun 允许测试覆写脚本执行逻辑
|
||||
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
||||
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
|
||||
}
|
||||
|
||||
// RateLimiter Redis 速率限制器
|
||||
type RateLimiter struct {
|
||||
redis *redis.Client
|
||||
@@ -27,34 +55,57 @@ func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
||||
// limit: 时间窗口内最大请求数
|
||||
// window: 时间窗口
|
||||
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
||||
return r.LimitWithOptions(key, limit, window, RateLimitOptions{})
|
||||
}
|
||||
|
||||
// LimitWithOptions 返回速率限制中间件(带可选配置)
|
||||
func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc {
|
||||
failureMode := opts.FailureMode
|
||||
if failureMode != RateLimitFailClose {
|
||||
failureMode = RateLimitFailOpen
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
redisKey := r.prefix + key + ":" + ip
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 使用 INCR 原子操作增加计数
|
||||
count, err := r.redis.Incr(ctx, redisKey).Result()
|
||||
windowMillis := windowTTLMillis(window)
|
||||
|
||||
// 使用 Lua 脚本原子操作增加计数并设置过期
|
||||
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||
if err != nil {
|
||||
if failureMode == RateLimitFailClose {
|
||||
abortRateLimit(c)
|
||||
return
|
||||
}
|
||||
// Redis 错误时放行,避免影响正常服务
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 首次访问时设置过期时间
|
||||
if count == 1 {
|
||||
r.redis.Expire(ctx, redisKey, window)
|
||||
}
|
||||
|
||||
// 超过限制
|
||||
if count > int64(limit) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
abortRateLimit(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func windowTTLMillis(window time.Duration) int64 {
|
||||
ttl := window.Milliseconds()
|
||||
if ttl < 1 {
|
||||
return 1
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
func abortRateLimit(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user