fix: normalize repository email lookups
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
|||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
@@ -104,10 +105,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||||
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
|
matches, err := r.client.User.Query().
|
||||||
|
Where(userEmailLookupPredicate(email)).
|
||||||
|
Order(dbent.Asc(dbuser.FieldID)).
|
||||||
|
All(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil, service.ErrUserNotFound
|
||||||
|
}
|
||||||
|
if len(matches) > 1 {
|
||||||
|
return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
|
||||||
|
}
|
||||||
|
m := matches[0]
|
||||||
|
|
||||||
out := userEntityToService(m)
|
out := userEntityToService(m)
|
||||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||||
@@ -469,7 +480,20 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func userEmailLookupPredicate(email string) predicate.User {
|
||||||
|
normalized := strings.TrimSpace(email)
|
||||||
|
if normalized == "" {
|
||||||
|
return dbuser.EmailEQ(email)
|
||||||
|
}
|
||||||
|
return predicate.User(func(s *entsql.Selector) {
|
||||||
|
s.Where(entsql.ExprP(
|
||||||
|
fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
|
||||||
|
normalized,
|
||||||
|
))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect"
|
||||||
|
entsql "entgo.io/ent/dialect/sql"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||||
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||||
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
|
||||||
|
return newUserRepositoryWithSQL(client, db), client
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
|
||||||
|
repo, _ := newUserEntRepo(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err := repo.Create(ctx, &service.User{
|
||||||
|
Email: " Legacy@Example.com ",
|
||||||
|
Username: "legacy-user",
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := repo.GetByEmail(ctx, "legacy@example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, " Legacy@Example.com ", got.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
|
||||||
|
repo, _ := newUserEntRepo(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err := repo.Create(ctx, &service.User{
|
||||||
|
Email: " Legacy@Example.com ",
|
||||||
|
Username: "legacy-user",
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, exists)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user