//go:build unit package middleware import ( "context" "errors" "net/http" "net/http/httptest" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) func TestSimpleModeBypassesQuotaCheck(t *testing.T) { gin.SetMode(gin.TestMode) limit := 1.0 group := &service.Group{ ID: 42, Name: "sub", Status: service.StatusActive, Hydrated: true, SubscriptionType: service.SubscriptionTypeSubscription, DailyLimitUSD: &limit, } user := &service.User{ ID: 7, Role: service.RoleUser, Status: service.StatusActive, Balance: 10, Concurrency: 3, } apiKey := &service.APIKey{ ID: 100, UserID: user.ID, Key: "test-key", Status: service.StatusActive, User: user, Group: group, } apiKey.GroupID = &group.ID apiKeyRepo := &stubApiKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { if key != apiKey.Key { return nil, service.ErrAPIKeyNotFound } clone := *apiKey return &clone, nil }, } t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeSimple} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/t", nil) req.Header.Set("x-api-key", apiKey.Key) router.ServeHTTP(w, req) require.Equal(t, http.StatusOK, w.Code) }) t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeStandard} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) now := time.Now() sub := &service.UserSubscription{ ID: 55, UserID: user.ID, GroupID: group.ID, Status: service.SubscriptionStatusActive, ExpiresAt: now.Add(24 * time.Hour), DailyWindowStart: &now, DailyUsageUSD: 10, } subscriptionRepo := &stubUserSubscriptionRepo{ getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { if userID != sub.UserID || groupID != sub.GroupID { return nil, service.ErrSubscriptionNotFound } clone := *sub return &clone, nil }, updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, } subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/t", nil) req.Header.Set("x-api-key", apiKey.Key) router.ServeHTTP(w, req) require.Equal(t, http.StatusTooManyRequests, w.Code) require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED") }) } func TestAPIKeyAuthSetsGroupContext(t *testing.T) { gin.SetMode(gin.TestMode) group := &service.Group{ ID: 101, Name: "g1", Status: service.StatusActive, Platform: service.PlatformAnthropic, Hydrated: true, } user := &service.User{ ID: 7, Role: service.RoleUser, Status: service.StatusActive, Balance: 10, Concurrency: 3, } apiKey := &service.APIKey{ ID: 100, UserID: user.ID, Key: "test-key", Status: service.StatusActive, User: user, Group: group, } apiKey.GroupID = &group.ID apiKeyRepo := &stubApiKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { if key != apiKey.Key { return nil, service.ErrAPIKeyNotFound } clone := *apiKey return &clone, nil }, } cfg := &config.Config{RunMode: config.RunModeSimple} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.GET("/t", func(c *gin.Context) { groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group) if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID { c.JSON(http.StatusInternalServerError, gin.H{"ok": false}) return } c.JSON(http.StatusOK, gin.H{"ok": true}) }) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/t", nil) req.Header.Set("x-api-key", apiKey.Key) router.ServeHTTP(w, req) require.Equal(t, http.StatusOK, w.Code) } func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { gin.SetMode(gin.TestMode) group := &service.Group{ ID: 101, Name: "g1", Status: service.StatusActive, Platform: service.PlatformAnthropic, Hydrated: true, } user := &service.User{ ID: 7, Role: service.RoleUser, Status: service.StatusActive, Balance: 10, Concurrency: 3, } apiKey := &service.APIKey{ ID: 100, UserID: user.ID, Key: "test-key", Status: service.StatusActive, User: user, Group: group, } apiKey.GroupID = &group.ID apiKeyRepo := &stubApiKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { if key != apiKey.Key { return nil, service.ErrAPIKeyNotFound } clone := *apiKey return &clone, nil }, } cfg := &config.Config{RunMode: config.RunModeSimple} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) invalidGroup := &service.Group{ ID: group.ID, Platform: group.Platform, Status: group.Status, } router.GET("/t", func(c *gin.Context) { groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group) if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup { c.JSON(http.StatusInternalServerError, gin.H{"ok": false}) return } c.JSON(http.StatusOK, gin.H{"ok": true}) }) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/t", nil) req.Header.Set("x-api-key", apiKey.Key) req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup)) router.ServeHTTP(w, req) require.Equal(t, http.StatusOK, w.Code) } func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) router.GET("/t", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) return router } type stubApiKeyRepo struct { getByKey func(ctx context.Context, key string) (*service.APIKey, error) } func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { return nil, errors.New("not implemented") } func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { return "", 0, errors.New("not implemented") } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { if r.getByKey != nil { return r.getByKey(ctx, key) } return nil, errors.New("not implemented") } func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { return r.GetByKey(ctx, key) } func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { return nil, errors.New("not implemented") } func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { return 0, errors.New("not implemented") } func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { return false, errors.New("not implemented") } func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (r *stubApiKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) { return nil, errors.New("not implemented") } func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { return nil, errors.New("not implemented") } func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { return nil, errors.New("not implemented") } type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error activateWindow func(ctx context.Context, id int64, start time.Time) error resetDaily func(ctx context.Context, id int64, start time.Time) error resetWeekly func(ctx context.Context, id int64, start time.Time) error resetMonthly func(ctx context.Context, id int64, start time.Time) error } func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { return nil, errors.New("not implemented") } func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { return nil, errors.New("not implemented") } func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { if r.getActive != nil { return r.getActive(ctx, userID, groupID) } return nil, errors.New("not implemented") } func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error { return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { return nil, errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { return nil, errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { return false, errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { if r.updateStatus != nil { return r.updateStatus(ctx, subscriptionID, status) } return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error { if r.activateWindow != nil { return r.activateWindow(ctx, id, start) } return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { if r.resetDaily != nil { return r.resetDaily(ctx, id, newWindowStart) } return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { if r.resetWeekly != nil { return r.resetWeekly(ctx, id, newWindowStart) } return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { if r.resetMonthly != nil { return r.resetMonthly(ctx, id, newWindowStart) } return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { return errors.New("not implemented") } func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { return 0, errors.New("not implemented") }