package middleware import ( "context" "encoding/json" "errors" "net/http" "net/http/httptest" "testing" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) type fakeApiKeyRepo struct { getByKey func(ctx context.Context, key string) (*service.ApiKey, error) } func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { return errors.New("not implemented") } func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { return nil, errors.New("not implemented") } func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { return 0, errors.New("not implemented") } func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { if f.getByKey == nil { return nil, errors.New("unexpected call") } return f.getByKey(ctx, key) } func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { return errors.New("not implemented") } func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { return nil, errors.New("not implemented") } func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { return 0, errors.New("not implemented") } func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { return false, errors.New("not implemented") } func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { return nil, errors.New("not implemented") } func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } type googleErrorResponse struct { Error struct { Code int `json:"code"` Message string `json:"message"` Status string `json:"status"` } `json:"error"` } func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService { return service.NewApiKeyService( repo, nil, // userRepo (unused in GetByKey) nil, // groupRepo nil, // userSubRepo nil, // cache &config.Config{}, ) } func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { return nil, errors.New("should not be called") }, }) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusUnauthorized, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusUnauthorized, resp.Error.Code) require.Equal(t, "API key is required", resp.Error.Message) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { return nil, service.ErrApiKeyNotFound }, }) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("Authorization", "Bearer invalid") rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusUnauthorized, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusUnauthorized, resp.Error.Code) require.Equal(t, "Invalid API key", resp.Error.Message) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { return nil, errors.New("db down") }, }) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("Authorization", "Bearer any") rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusInternalServerError, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusInternalServerError, resp.Error.Code) require.Equal(t, "Failed to validate API key", resp.Error.Message) require.Equal(t, "INTERNAL", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { return &service.ApiKey{ ID: 1, Key: key, Status: service.StatusDisabled, User: &service.User{ ID: 123, Status: service.StatusActive, }, }, nil }, }) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("Authorization", "Bearer disabled") rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusUnauthorized, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusUnauthorized, resp.Error.Code) require.Equal(t, "API key is disabled", resp.Error.Message) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { return &service.ApiKey{ ID: 1, Key: key, Status: service.StatusActive, User: &service.User{ ID: 123, Status: service.StatusActive, Balance: 0, }, }, nil }, }) r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("Authorization", "Bearer ok") rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusForbidden, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusForbidden, resp.Error.Code) require.Equal(t, "Insufficient account balance", resp.Error.Message) require.Equal(t, "PERMISSION_DENIED", resp.Error.Status) }