From 31d0183d45b3c5d2a6692daebec58625a393fe38 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:51:57 +0800 Subject: [PATCH] fix: normalize repository email lookups --- backend/internal/repository/user_repo.go | 30 +++++++- .../user_repo_email_lookup_unit_test.go | 69 +++++++++++++++++++ 2 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 backend/internal/repository/user_repo_email_lookup_unit_test.go diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 0c607ecc..b5efd19d 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -12,6 +12,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" dbgroup "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/predicate" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -104,10 +105,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, } func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) { - m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + matches, err := r.client.User.Query(). + Where(userEmailLookupPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) if err != nil { - return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) + return nil, err } + if len(matches) == 0 { + return nil, service.ErrUserNotFound + } + if len(matches) > 1 { + return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email)) + } + m := matches[0] out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) @@ -469,7 +480,20 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { - return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) + return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) +} + +func userEmailLookupPredicate(email string) predicate.User { + normalized := strings.TrimSpace(email) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.ExprP( + fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)), + normalized, + )) + }) } func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go new file mode 100644 index 00000000..d42ce9ac --- /dev/null +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -0,0 +1,69 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return newUserRepositoryWithSQL(client, db), client +} + +func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + got, err := repo.GetByEmail(ctx, "legacy@example.com") + require.NoError(t, err) + require.Equal(t, " Legacy@Example.com ", got.Email) +} + +func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ") + require.NoError(t, err) + require.True(t, exists) +}