fix: 完善邮箱快捷登录注册流程

This commit is contained in:
lyen1688
2026-05-06 20:50:41 +08:00
committed by lyen1688
parent 81edaa8986
commit e69256a706
5 changed files with 417 additions and 63 deletions

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
@@ -168,10 +169,22 @@ func (h *AuthHandler) emailOAuthCallbackWithProfile(
UpstreamMetadata: profile.Metadata,
}
affiliateCode := h.emailOAuthAffiliateCode(c)
if shouldCreate, err := h.emailOAuthShouldCreatePendingRegistration(c.Request.Context(), input); err != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
return
} else if shouldCreate {
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode)
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if pendingErr := h.createEmailOAuthInvitationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
return
}
@@ -195,6 +208,35 @@ func (h *AuthHandler) emailOAuthCallbackWithProfile(
redirectWithFragment(c, frontendCallback, fragment)
}
func (h *AuthHandler) emailOAuthShouldCreatePendingRegistration(ctx context.Context, input service.EmailOAuthIdentityInput) (bool, error) {
client := h.entClient()
if client == nil {
return false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
identityUser, err := h.findOAuthIdentityUser(ctx, service.PendingAuthIdentityKey{
ProviderType: strings.TrimSpace(input.ProviderType),
ProviderKey: strings.TrimSpace(input.ProviderKey),
ProviderSubject: strings.TrimSpace(input.ProviderSubject),
})
if err != nil {
return false, err
}
email := strings.TrimSpace(strings.ToLower(input.Email))
if identityUser != nil {
if !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
return false, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
}
return false, nil
}
if _, err := findUserByNormalizedEmail(ctx, client, email); err != nil {
if errors.Is(err, service.ErrUserNotFound) {
return true, nil
}
return false, err
}
return false, nil
}
func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
if c == nil {
return ""
@@ -205,7 +247,7 @@ func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
return ""
}
func (h *AuthHandler) createEmailOAuthInvitationPendingSession(
func (h *AuthHandler) createEmailOAuthRegistrationPendingSession(
c *gin.Context,
provider string,
frontendCallback string,
@@ -247,14 +289,22 @@ func (h *AuthHandler) createEmailOAuthInvitationPendingSession(
}
}
invitationRequired := h != nil && h.settingSvc != nil && h.settingSvc.IsInvitationCodeEnabled(c.Request.Context())
pendingError := "registration_completion_required"
choiceReason := "registration_completion_required"
if invitationRequired {
pendingError = "invitation_required"
choiceReason = "invitation_required"
}
completionResponse := map[string]any{
"step": oauthPendingChoiceStep,
"error": "invitation_required",
"choice_reason": "invitation_required",
"error": pendingError,
"choice_reason": choiceReason,
"adoption_required": false,
"create_account_allowed": true,
"existing_account_bindable": false,
"force_email_on_signup": true,
"invitation_required": invitationRequired,
"email": email,
"resolved_email": email,
"provider": provider,
@@ -276,7 +326,8 @@ func (h *AuthHandler) createEmailOAuthInvitationPendingSession(
}
type completeEmailOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
AffCode string `json:"aff_code,omitempty"`
}
@@ -310,21 +361,12 @@ func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider st
affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")
}
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(
tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount(
c.Request.Context(),
service.EmailOAuthIdentityInput{
ProviderType: strings.TrimSpace(session.ProviderType),
ProviderKey: strings.TrimSpace(session.ProviderKey),
ProviderSubject: strings.TrimSpace(session.ProviderSubject),
Email: strings.TrimSpace(session.ResolvedEmail),
EmailVerified: true,
Username: pendingSessionStringValue(session.UpstreamIdentityClaims, "username"),
DisplayName: pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"),
AvatarURL: pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url"),
UpstreamMetadata: clonePendingMap(session.UpstreamIdentityClaims),
},
strings.TrimSpace(session.ResolvedEmail),
req.Password,
strings.TrimSpace(req.InvitationCode),
affiliateCode,
strings.TrimSpace(session.ProviderType),
)
if err != nil {
response.ErrorFrom(c, err)
@@ -342,13 +384,46 @@ func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider st
return
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
sessionForBinding := *session
sessionForBinding.UpstreamIdentityClaims = clonePendingMap(session.UpstreamIdentityClaims)
if strings.TrimSpace(req.InvitationCode) != "" {
sessionForBinding.UpstreamIdentityClaims["invitation_code"] = strings.TrimSpace(req.InvitationCode)
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{})
if err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, &sessionForBinding, decision, &user.ID, true, false); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
respondPendingOAuthBindingApplyError(c, err)
return
}
if err := h.authService.FinalizeOAuthEmailAccount(
txCtx,
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
affiliateCode,
); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, err)
return
}
if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil {
_ = tx.Rollback()
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := tx.Commit(); err != nil {
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
return
}
@@ -438,17 +513,17 @@ func parseGitHubOAuthProfile(ctx context.Context, cfg config.EmailOAuthProviderC
if subject == "" {
return nil, errors.New("github user id is missing")
}
email := strings.TrimSpace(gjson.Get(body, "email").String())
emailVerified := email != ""
if strings.TrimSpace(cfg.EmailsURL) != "" {
if verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, cfg.EmailsURL, token.AccessToken); err == nil && verifiedEmail != "" {
email = verifiedEmail
emailVerified = true
} else if email == "" && err != nil {
return nil, err
}
email := ""
emailsURL := strings.TrimSpace(cfg.EmailsURL)
if emailsURL == "" {
return nil, errors.New("github verified email is missing")
}
if email == "" || !emailVerified {
verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, emailsURL, token.AccessToken)
if err != nil {
return nil, err
}
email = verifiedEmail
if email == "" {
return nil, errors.New("github verified email is missing")
}
login := strings.TrimSpace(gjson.Get(body, "login").String())

View File

@@ -73,6 +73,7 @@ func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *t
require.True(t, ok)
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, "invitation_required", completion["error"])
require.Equal(t, true, completion["invitation_required"])
require.Equal(t, "fresh@example.com", completion["email"])
require.Equal(t, "fresh@example.com", completion["resolved_email"])
require.Equal(t, true, completion["create_account_allowed"])
@@ -129,7 +130,7 @@ func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T
_ = user
}
func TestEmailOAuthCallbackAutoRegistrationAppliesAffiliateCode(t *testing.T) {
func TestEmailOAuthCallbackCreatesPasswordRegistrationSessionForNewEmail(t *testing.T) {
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001})
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
@@ -161,11 +162,26 @@ func TestEmailOAuthCallbackAutoRegistrationAppliesAffiliateCode(t *testing.T) {
})
require.Equal(t, http.StatusFound, recorder.Code)
require.Contains(t, recorder.Header().Get("Location"), "access_token=")
user, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Only(ctx)
require.NotContains(t, recorder.Header().Get("Location"), "access_token=")
userCount, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Count(ctx)
require.NoError(t, err)
require.Equal(t, []int64{user.ID, user.ID}, affiliateRepo.ensureUserIDs)
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 1001}}, affiliateRepo.bindCalls)
require.Zero(t, userCount)
require.Empty(t, affiliateRepo.ensureUserIDs)
require.Empty(t, affiliateRepo.bindCalls)
session, err := client.PendingAuthSession.Query().Only(ctx)
require.NoError(t, err)
require.Equal(t, "aff-user@example.com", session.ResolvedEmail)
require.Equal(t, "AFF123", pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code"))
completion, ok := readCompletionResponse(session.LocalFlowState)
require.True(t, ok)
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, "registration_completion_required", completion["error"])
require.Equal(t, false, completion["invitation_required"])
require.Equal(t, true, completion["create_account_allowed"])
require.Equal(t, true, completion["force_email_on_signup"])
require.Equal(t, "aff-user@example.com", completion["resolved_email"])
}
func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) {
@@ -216,7 +232,7 @@ func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *te
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"invitation_code":"INVITE456"}`))
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"password":"secret-123","invitation_code":"INVITE456","email":"tampered@example.com"}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")})
@@ -227,6 +243,11 @@ func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *te
require.Equal(t, http.StatusOK, recorder.Code)
user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx)
require.NoError(t, err)
require.NotEmpty(t, user.PasswordHash)
require.NotEqual(t, "secret-123", user.PasswordHash)
tamperedCount, err := client.User.Query().Where(dbuser.EmailEQ("tampered@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, tamperedCount)
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls)
storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx)
require.NoError(t, err)
@@ -234,6 +255,66 @@ func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *te
require.Equal(t, user.ID, *storedInvitation.UsedBy)
}
func TestCompleteEmailOAuthRegistrationRequiresPassword(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("email-oauth-password-session-token").
SetIntent(oauthIntentLogin).
SetProviderType("github").
SetProviderKey("github").
SetProviderSubject("github-password-user").
SetResolvedEmail("password-required@example.com").
SetRedirectTo("/dashboard").
SetBrowserSessionKey("browser-password-key").
SetUpstreamIdentityClaims(map[string]any{
"email": "password-required@example.com",
"email_verified": true,
"username": "password-required",
"provider": "github",
"provider_key": "github",
"provider_subject": "github-password-user",
}).
SetLocalFlowState(map[string]any{
"step": oauthPendingChoiceStep,
"error": "registration_completion_required",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/github/complete-registration", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-password-key")})
c.Request = req
handler.completeEmailOAuthRegistration(c, "github")
require.Equal(t, http.StatusBadRequest, recorder.Code)
userCount, err := client.User.Query().Where(dbuser.EmailEQ("password-required@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
}
func TestParseGitHubOAuthProfileRejectsPublicEmailWhenEmailsEndpointFails(t *testing.T) {
emailServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "missing scope", http.StatusForbidden)
}))
t.Cleanup(emailServer.Close)
profile, err := parseGitHubOAuthProfile(context.Background(), config.EmailOAuthProviderConfig{
EmailsURL: emailServer.URL,
}, &emailOAuthTokenResponse{AccessToken: "token"}, `{"id":123,"login":"octo","email":"public@example.com"}`)
require.Error(t, err)
require.Nil(t, profile)
require.Contains(t, err.Error(), "github emails endpoint status 403")
}
type oauthEmailAffiliateBindCall struct {
userID int64
inviterID int64

View File

@@ -10,6 +10,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func normalizeOAuthSignupSource(signupSource string) string {
@@ -168,6 +169,87 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return tokenPair, user, nil
}
// RegisterVerifiedOAuthEmailAccount creates a local account from an OAuth
// provider that has already returned a verified email address.
func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
ctx context.Context,
email string,
password string,
invitationCode string,
signupSource string,
) (*TokenPair, *User, error) {
if s == nil {
return nil, nil, ErrServiceUnavailable
}
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, nil, ErrRegDisabled
}
email = strings.TrimSpace(strings.ToLower(email))
if email == "" || len(email) > 255 {
return nil, nil, ErrEmailVerifyRequired
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, nil, ErrEmailVerifyRequired
}
if isReservedEmail(email) {
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, nil, err
}
if strings.TrimSpace(password) == "" {
return nil, nil, infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
}
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
return nil, nil, err
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return nil, nil, ErrServiceUnavailable
}
if existsEmail {
return nil, nil, ErrEmailExists
}
hashedPassword, err := s.HashPassword(password)
if err != nil {
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource = normalizeOAuthSignupSource(signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
RPMLimit: defaultRPMLimit,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, ErrEmailExists) {
return nil, nil, ErrEmailExists
}
return nil, nil, ErrServiceUnavailable
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
// only after the pending OAuth flow has fully reached its last reversible step.
func (s *AuthService) FinalizeOAuthEmailAccount(