162 lines
3.7 KiB
Go
162 lines
3.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// RateLimitFailureMode selects what the rate-limit middleware does when the
// Redis call itself fails.
type RateLimitFailureMode int

const (
	// RateLimitFailOpen lets the request through when Redis returns an error.
	// It is the zero value of RateLimitFailureMode.
	RateLimitFailOpen RateLimitFailureMode = 0
	// RateLimitFailClose rejects the request when Redis returns an error.
	RateLimitFailClose RateLimitFailureMode = 1
)
|
|
|
|
// RateLimitOptions 限流可选配置
|
|
type RateLimitOptions struct {
|
|
FailureMode RateLimitFailureMode
|
|
}
|
|
|
|
// rateLimitScript is the Lua script behind the rate limiter. It increments the
// counter stored at KEYS[1] and applies ARGV[1] (milliseconds) as the expiry
// when the counter was just created. If the counter already existed but PTTL
// reports -1, the expiry is applied again and the repaired flag is set.
// The script returns {current, repaired}.
var rateLimitScript = redis.NewScript(`
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
local repaired = 0
if current == 1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
elseif ttl == -1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
repaired = 1
end
return {current, repaired}
`)
|
|
|
|
// rateLimitRun 允许测试覆写脚本执行逻辑
|
|
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
|
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
if len(values) < 2 {
|
|
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
|
|
}
|
|
count, err := parseInt64(values[0])
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
repaired, err := parseInt64(values[1])
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
return count, repaired == 1, nil
|
|
}
|
|
|
|
// RateLimiter Redis 速率限制器
|
|
type RateLimiter struct {
|
|
redis *redis.Client
|
|
prefix string
|
|
}
|
|
|
|
// NewRateLimiter 创建速率限制器实例
|
|
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
|
return &RateLimiter{
|
|
redis: redisClient,
|
|
prefix: "rate_limit:",
|
|
}
|
|
}
|
|
|
|
// Limit 返回速率限制中间件
|
|
// key: 限制类型标识
|
|
// limit: 时间窗口内最大请求数
|
|
// window: 时间窗口
|
|
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
|
return r.LimitWithOptions(key, limit, window, RateLimitOptions{})
|
|
}
|
|
|
|
// LimitWithOptions 返回速率限制中间件(带可选配置)
|
|
func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc {
|
|
failureMode := opts.FailureMode
|
|
if failureMode != RateLimitFailClose {
|
|
failureMode = RateLimitFailOpen
|
|
}
|
|
|
|
return func(c *gin.Context) {
|
|
ip := c.ClientIP()
|
|
redisKey := r.prefix + key + ":" + ip
|
|
|
|
ctx := c.Request.Context()
|
|
|
|
windowMillis := windowTTLMillis(window)
|
|
|
|
// 使用 Lua 脚本原子操作增加计数并设置过期
|
|
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
|
if err != nil {
|
|
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
|
|
if failureMode == RateLimitFailClose {
|
|
abortRateLimit(c)
|
|
return
|
|
}
|
|
// Redis 错误时放行,避免影响正常服务
|
|
c.Next()
|
|
return
|
|
}
|
|
if repaired {
|
|
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
|
|
}
|
|
|
|
// 超过限制
|
|
if count > int64(limit) {
|
|
abortRateLimit(c)
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// windowTTLMillis converts window to whole milliseconds, never returning less
// than 1 so that zero, negative and sub-millisecond windows still yield a
// positive expiry.
func windowTTLMillis(window time.Duration) int64 {
	if ms := window.Milliseconds(); ms >= 1 {
		return ms
	}
	return 1
}
|
|
|
|
func abortRateLimit(c *gin.Context) {
|
|
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
|
"error": "rate limit exceeded",
|
|
"message": "Too many requests, please try again later",
|
|
})
|
|
}
|
|
|
|
func failureModeLabel(mode RateLimitFailureMode) string {
|
|
if mode == RateLimitFailClose {
|
|
return "fail-close"
|
|
}
|
|
return "fail-open"
|
|
}
|
|
|
|
// parseInt64 converts a script reply element to int64. It accepts int64, int
// and base-10 strings; any other dynamic type yields an error naming that
// type. On every error path the returned number is 0.
func parseInt64(value any) (int64, error) {
	if n, ok := value.(int64); ok {
		return n, nil
	}
	if n, ok := value.(int); ok {
		return int64(n), nil
	}

	text, ok := value.(string)
	if !ok {
		return 0, fmt.Errorf("unexpected value type %T", value)
	}
	n, err := strconv.ParseInt(text, 10, 64)
	if err != nil {
		// Discard ParseInt's partial result (e.g. the clamped value on
		// out-of-range input) so callers always see 0 alongside an error.
		return 0, err
	}
	return n, nil
}
|