fix: normalize repository email lookups

This commit is contained in:
IanShaw027
2026-04-20 21:51:57 +08:00
parent b309822199
commit 31d0183d45
2 changed files with 96 additions and 3 deletions

View File

@@ -12,6 +12,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -104,10 +105,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
matches, err := r.client.User.Query().
Where(userEmailLookupPredicate(email)).
Order(dbent.Asc(dbuser.FieldID)).
All(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
return nil, err
}
if len(matches) == 0 {
return nil, service.ErrUserNotFound
}
if len(matches) > 1 {
return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
}
m := matches[0]
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
@@ -469,7 +480,20 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
}
func userEmailLookupPredicate(email string) predicate.User {
normalized := strings.TrimSpace(email)
if normalized == "" {
return dbuser.EmailEQ(email)
}
return predicate.User(func(s *entsql.Selector) {
s.Where(entsql.ExprP(
fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
normalized,
))
})
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {