fix(限流): 原子化 Redis 限流并支持故障策略
使用 Lua 脚本原子设置计数与过期,修复 TTL 缺失\n支持 fail-open/fail-close 并对优惠码验证启用 fail-close\n新增单元与集成测试覆盖关键分支\n\n测试:go test ./...
This commit is contained in:
100
backend/internal/middleware/rate_limiter_test.go
Normal file
100
backend/internal/middleware/rate_limiter_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWindowTTLMillis(t *testing.T) {
|
||||
require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond))
|
||||
require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond))
|
||||
require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond))
|
||||
}
|
||||
|
||||
func TestRateLimiterFailureModes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
limiter := NewRateLimiter(rdb)
|
||||
|
||||
failOpenRouter := gin.New()
|
||||
failOpenRouter.Use(limiter.Limit("test", 1, time.Second))
|
||||
failOpenRouter.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder := httptest.NewRecorder()
|
||||
failOpenRouter.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
failCloseRouter := gin.New()
|
||||
failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{
|
||||
FailureMode: RateLimitFailClose,
|
||||
}))
|
||||
failCloseRouter.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder = httptest.NewRecorder()
|
||||
failCloseRouter.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
||||
}
|
||||
|
||||
func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
originalRun := rateLimitRun
|
||||
counts := []int64{1, 2}
|
||||
callIndex := 0
|
||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
||||
if callIndex >= len(counts) {
|
||||
return counts[len(counts)-1], nil
|
||||
}
|
||||
value := counts[callIndex]
|
||||
callIndex++
|
||||
return value, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
rateLimitRun = originalRun
|
||||
})
|
||||
|
||||
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
|
||||
|
||||
router := gin.New()
|
||||
router.Use(limiter.Limit("test", 1, time.Second))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
||||
}
|
||||
Reference in New Issue
Block a user