fix: normalize repository email lookups
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
@@ -104,10 +105,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
|
||||
matches, err := r.client.User.Query().
|
||||
Where(userEmailLookupPredicate(email)).
|
||||
Order(dbent.Asc(dbuser.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
return nil, err
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
if len(matches) > 1 {
|
||||
return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
|
||||
}
|
||||
m := matches[0]
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
@@ -469,7 +480,20 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
func userEmailLookupPredicate(email string) predicate.User {
|
||||
normalized := strings.TrimSpace(email)
|
||||
if normalized == "" {
|
||||
return dbuser.EmailEQ(email)
|
||||
}
|
||||
return predicate.User(func(s *entsql.Selector) {
|
||||
s.Where(entsql.ExprP(
|
||||
fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
|
||||
normalized,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||
|
||||
Reference in New Issue
Block a user