127 lines
3.2 KiB
Go
127 lines
3.2 KiB
Go
//go:build unit
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestClientRequestID_GeneratesWhenMissing(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
r := gin.New()
|
|
r.Use(ClientRequestID())
|
|
r.GET("/t", func(c *gin.Context) {
|
|
v := c.Request.Context().Value(ctxkey.ClientRequestID)
|
|
require.NotNil(t, v)
|
|
id, ok := v.(string)
|
|
require.True(t, ok)
|
|
require.NotEmpty(t, id)
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
|
r.ServeHTTP(w, req)
|
|
require.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
func TestClientRequestID_PreservesExisting(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
r := gin.New()
|
|
r.Use(ClientRequestID())
|
|
r.GET("/t", func(c *gin.Context) {
|
|
id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
|
require.True(t, ok)
|
|
require.Equal(t, "keep", id)
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
|
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep"))
|
|
r.ServeHTTP(w, req)
|
|
require.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
func TestRequestBodyLimit_LimitsBody(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
r := gin.New()
|
|
r.Use(RequestBodyLimit(4))
|
|
r.POST("/t", func(c *gin.Context) {
|
|
_, err := io.ReadAll(c.Request.Body)
|
|
require.Error(t, err)
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345"))
|
|
r.ServeHTTP(w, req)
|
|
require.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
func TestForcePlatform_SetsContextAndGinValue(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
r := gin.New()
|
|
r.Use(ForcePlatform("anthropic"))
|
|
r.GET("/t", func(c *gin.Context) {
|
|
require.True(t, HasForcePlatform(c))
|
|
v, ok := GetForcePlatformFromContext(c)
|
|
require.True(t, ok)
|
|
require.Equal(t, "anthropic", v)
|
|
|
|
ctxV := c.Request.Context().Value(ctxkey.ForcePlatform)
|
|
require.Equal(t, "anthropic", ctxV)
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
|
r.ServeHTTP(w, req)
|
|
require.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
func TestAuthSubjectHelpers_RoundTrip(t *testing.T) {
|
|
c := &gin.Context{}
|
|
c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2})
|
|
c.Set(string(ContextKeyUserRole), "admin")
|
|
|
|
sub, ok := GetAuthSubjectFromContext(c)
|
|
require.True(t, ok)
|
|
require.Equal(t, int64(1), sub.UserID)
|
|
require.Equal(t, 2, sub.Concurrency)
|
|
|
|
role, ok := GetUserRoleFromContext(c)
|
|
require.True(t, ok)
|
|
require.Equal(t, "admin", role)
|
|
}
|
|
|
|
func TestAPIKeyAndSubscriptionFromContext(t *testing.T) {
|
|
c := &gin.Context{}
|
|
|
|
key := &service.APIKey{ID: 1}
|
|
c.Set(string(ContextKeyAPIKey), key)
|
|
gotKey, ok := GetAPIKeyFromContext(c)
|
|
require.True(t, ok)
|
|
require.Equal(t, int64(1), gotKey.ID)
|
|
|
|
sub := &service.UserSubscription{ID: 2}
|
|
c.Set(string(ContextKeySubscription), sub)
|
|
gotSub, ok := GetSubscriptionFromContext(c)
|
|
require.True(t, ok)
|
|
require.Equal(t, int64(2), gotSub.ID)
|
|
}
|