使用 Lua 脚本原子设置计数与过期，修复 TTL 缺失
支持 fail-open/fail-close，并对优惠码验证启用 fail-close
新增单元与集成测试覆盖关键分支

测试：go test ./...
101 lines
2.8 KiB
Go
101 lines
2.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/redis/go-redis/v9"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestWindowTTLMillis(t *testing.T) {
|
|
require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond))
|
|
require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond))
|
|
require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond))
|
|
}
|
|
|
|
func TestRateLimiterFailureModes(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rdb := redis.NewClient(&redis.Options{
|
|
Addr: "127.0.0.1:1",
|
|
DialTimeout: 50 * time.Millisecond,
|
|
ReadTimeout: 50 * time.Millisecond,
|
|
WriteTimeout: 50 * time.Millisecond,
|
|
})
|
|
t.Cleanup(func() {
|
|
_ = rdb.Close()
|
|
})
|
|
|
|
limiter := NewRateLimiter(rdb)
|
|
|
|
failOpenRouter := gin.New()
|
|
failOpenRouter.Use(limiter.Limit("test", 1, time.Second))
|
|
failOpenRouter.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "127.0.0.1:1234"
|
|
recorder := httptest.NewRecorder()
|
|
failOpenRouter.ServeHTTP(recorder, req)
|
|
require.Equal(t, http.StatusOK, recorder.Code)
|
|
|
|
failCloseRouter := gin.New()
|
|
failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{
|
|
FailureMode: RateLimitFailClose,
|
|
}))
|
|
failCloseRouter.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
|
})
|
|
|
|
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "127.0.0.1:1234"
|
|
recorder = httptest.NewRecorder()
|
|
failCloseRouter.ServeHTTP(recorder, req)
|
|
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
|
}
|
|
|
|
func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
originalRun := rateLimitRun
|
|
counts := []int64{1, 2}
|
|
callIndex := 0
|
|
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
|
if callIndex >= len(counts) {
|
|
return counts[len(counts)-1], nil
|
|
}
|
|
value := counts[callIndex]
|
|
callIndex++
|
|
return value, nil
|
|
}
|
|
t.Cleanup(func() {
|
|
rateLimitRun = originalRun
|
|
})
|
|
|
|
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
|
|
|
|
router := gin.New()
|
|
router.Use(limiter.Limit("test", 1, time.Second))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "127.0.0.1:1234"
|
|
recorder := httptest.NewRecorder()
|
|
router.ServeHTTP(recorder, req)
|
|
require.Equal(t, http.StatusOK, recorder.Code)
|
|
|
|
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "127.0.0.1:1234"
|
|
recorder = httptest.NewRecorder()
|
|
router.ServeHTTP(recorder, req)
|
|
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
|
}
|