package middleware

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/require"
)

// TestWindowTTLMillis pins the millisecond value windowTTLMillis is expected
// to return for a window duration: a sub-millisecond window yields 1, and
// fractional milliseconds are dropped (1.5ms -> 1, 2.5ms -> 2).
func TestWindowTTLMillis(t *testing.T) {
	require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond))
	require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond))
	require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond))
}

// TestRateLimiterFailureModes checks the HTTP status produced by the limiter
// middleware when it is backed by the Redis client configured below: the
// default Limit middleware is expected to answer 200, while the same limiter
// configured with RateLimitFailClose is expected to answer 429.
func TestRateLimiterFailureModes(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// NOTE(review): 127.0.0.1:1 is presumably chosen because no Redis server
	// listens there, so every limiter call fails; the short timeouts bound how
	// long each attempt may take. Confirm against the limiter implementation.
	rdb := redis.NewClient(&redis.Options{
		Addr:         "127.0.0.1:1",
		DialTimeout:  50 * time.Millisecond,
		ReadTimeout:  50 * time.Millisecond,
		WriteTimeout: 50 * time.Millisecond,
	})
	// Best-effort teardown: a Close error has no bearing on the assertions.
	t.Cleanup(func() { _ = rdb.Close() })

	limiter := NewRateLimiter(rdb)

	okHandler := func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	}

	// get issues GET /test from a fixed remote address and returns the
	// response status code.
	get := func(h http.Handler) int {
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		req.RemoteAddr = "127.0.0.1:1234"
		recorder := httptest.NewRecorder()
		h.ServeHTTP(recorder, req)
		return recorder.Code
	}

	failOpenRouter := gin.New()
	failOpenRouter.Use(limiter.Limit("test", 1, time.Second))
	failOpenRouter.GET("/test", okHandler)
	require.Equal(t, http.StatusOK, get(failOpenRouter))

	failCloseRouter := gin.New()
	failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{
		FailureMode: RateLimitFailClose,
	}))
	failCloseRouter.GET("/test", okHandler)
	require.Equal(t, http.StatusTooManyRequests, get(failCloseRouter))
}

// TestRateLimiterSuccessAndLimit replaces the package-level rateLimitRun hook
// with a stub that returns the counts 1 and then 2, and checks that with
// Limit("test", 1, time.Second) the first request gets 200 and the second
// gets 429.
//
// The test mutates package-level state (rateLimitRun), so it must not be run
// in parallel with other tests that go through the same hook.
func TestRateLimiterSuccessAndLimit(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Arm the restore before overriding the hook so the original is put back
	// however the test exits.
	originalRun := rateLimitRun
	t.Cleanup(func() { rateLimitRun = originalRun })

	counts := []int64{1, 2}
	callIndex := 0
	rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
		// Once the scripted counts are exhausted, keep returning the last one.
		if callIndex >= len(counts) {
			return counts[len(counts)-1], false, nil
		}
		value := counts[callIndex]
		callIndex++
		return value, false, nil
	}

	// The stub above never touches the client, but the client still owns
	// resources, so release it when the test ends.
	rdb := redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"})
	// Best-effort teardown: a Close error has no bearing on the assertions.
	t.Cleanup(func() { _ = rdb.Close() })

	limiter := NewRateLimiter(rdb)

	router := gin.New()
	router.Use(limiter.Limit("test", 1, time.Second))
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	// get issues GET /test from a fixed remote address and returns the
	// response status code.
	get := func() int {
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		req.RemoteAddr = "127.0.0.1:1234"
		recorder := httptest.NewRecorder()
		router.ServeHTTP(recorder, req)
		return recorder.Code
	}

	require.Equal(t, http.StatusOK, get())
	require.Equal(t, http.StatusTooManyRequests, get())
}