fix: normalize pending oauth email lookups
This commit is contained in:
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -12,12 +13,14 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -531,11 +534,9 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client,
|
||||
return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
|
||||
}
|
||||
|
||||
userEntity, err := client.User.Query().
|
||||
Where(dbuser.EmailEQ(email)).
|
||||
Only(ctx)
|
||||
userEntity, err := findUserByNormalizedEmail(ctx, client, email)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
if errors.Is(err, service.ErrUserNotFound) {
|
||||
return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
|
||||
}
|
||||
return 0, err
|
||||
@@ -543,6 +544,40 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client,
|
||||
return userEntity.ID, nil
|
||||
}
|
||||
|
||||
func userNormalizedEmailPredicate(email string) predicate.User {
|
||||
normalized := strings.TrimSpace(email)
|
||||
if normalized == "" {
|
||||
return dbuser.EmailEQ(email)
|
||||
}
|
||||
return predicate.User(func(s *entsql.Selector) {
|
||||
s.Where(entsql.ExprP(
|
||||
fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
|
||||
normalized,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) {
|
||||
if client == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
}
|
||||
|
||||
matches, err := client.User.Query().
|
||||
Where(userNormalizedEmailPredicate(email)).
|
||||
Order(dbent.Asc(dbuser.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
if len(matches) > 1 {
|
||||
return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
|
||||
}
|
||||
return matches[0], nil
|
||||
}
|
||||
|
||||
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
|
||||
if session == nil {
|
||||
return nil
|
||||
@@ -1102,8 +1137,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(req.Email))
|
||||
existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context())
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
|
||||
if err != nil && !errors.Is(err, service.ErrUserNotFound) {
|
||||
response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user