package middleware import ( "context" "encoding/json" "errors" "net/http" "net/http/httptest" "testing" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) type fakeAPIKeyRepo struct { getByKey func(ctx context.Context, key string) (*service.APIKey, error) } func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { return nil, errors.New("not implemented") } func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { return 0, errors.New("not implemented") } func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { if f.getByKey == nil { return nil, errors.New("unexpected call") } return f.getByKey(ctx, key) } func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { return nil, errors.New("not implemented") } func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { return 0, errors.New("not implemented") } func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { return false, errors.New("not implemented") } func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) { return nil, errors.New("not implemented") } func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } type googleErrorResponse struct { Error struct { Code int `json:"code"` Message string `json:"message"` Status string `json:"status"` } `json:"error"` } func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService { return service.NewAPIKeyService( repo, nil, // userRepo (unused in GetByKey) nil, // groupRepo nil, // userSubRepo nil, // cache &config.Config{}, ) } func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { return nil, errors.New("should not be called") }, }) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusUnauthorized, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusUnauthorized, resp.Error.Code) require.Equal(t, "API key is required", resp.Error.Message) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { return nil, errors.New("should not be called") }, }) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil) rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusBadRequest, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusBadRequest, resp.Error.Code) require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message) require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) { gin.SetMode(gin.TestMode) group := &service.Group{ ID: 99, Name: "g1", Status: service.StatusActive, Platform: service.PlatformGemini, Hydrated: true, } user := &service.User{ ID: 7, Role: service.RoleUser, Status: service.StatusActive, Balance: 10, Concurrency: 3, } apiKey := &service.APIKey{ ID: 100, UserID: user.ID, Key: "test-key", Status: service.StatusActive, User: user, Group: group, } apiKey.GroupID = &group.ID apiKeyService := service.NewAPIKeyService( fakeAPIKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { if key != apiKey.Key { return nil, service.ErrAPIKeyNotFound } clone := *apiKey return &clone, nil }, }, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}, ) cfg := &config.Config{RunMode: config.RunModeSimple} r := gin.New() r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) r.GET("/v1beta/test", func(c *gin.Context) { groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group) if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID { c.JSON(http.StatusInternalServerError, gin.H{"ok": false}) return } c.JSON(http.StatusOK, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("x-api-key", apiKey.Key) rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) } func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { return &service.APIKey{ ID: 1, Key: key, Status: service.StatusActive, User: &service.User{ ID: 123, Status: service.StatusActive, }, }, nil }, }) cfg := &config.Config{RunMode: config.RunModeSimple} r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil) rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) } func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { return nil, service.ErrAPIKeyNotFound }, }) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("Authorization", "Bearer invalid") rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusUnauthorized, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusUnauthorized, resp.Error.Code) require.Equal(t, "Invalid API key", resp.Error.Message) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { return nil, errors.New("db down") }, }) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("Authorization", "Bearer any") rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusInternalServerError, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusInternalServerError, resp.Error.Code) require.Equal(t, "Failed to validate API key", resp.Error.Message) require.Equal(t, "INTERNAL", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { return &service.APIKey{ ID: 1, Key: key, Status: service.StatusDisabled, User: &service.User{ ID: 123, Status: service.StatusActive, }, }, nil }, }) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("Authorization", "Bearer disabled") rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusUnauthorized, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusUnauthorized, resp.Error.Code) require.Equal(t, "API key is disabled", resp.Error.Message) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) } func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { return &service.APIKey{ ID: 1, Key: key, Status: service.StatusActive, User: &service.User{ ID: 123, Status: service.StatusActive, Balance: 0, }, }, nil }, }) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req.Header.Set("Authorization", "Bearer ok") rec := httptest.NewRecorder() r.ServeHTTP(rec, req) require.Equal(t, http.StatusForbidden, rec.Code) var resp googleErrorResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, http.StatusForbidden, resp.Error.Code) require.Equal(t, "Insufficient account balance", resp.Error.Message) require.Equal(t, "PERMISSION_DENIED", resp.Error.Status) }