feat: add pending oauth email onboarding flow

This commit is contained in:
IanShaw027
2026-04-20 19:30:09 +08:00
parent d47580a144
commit 6a75bd77e3
13 changed files with 1273 additions and 119 deletions

View File

@@ -46,6 +46,36 @@ type oauthAdoptionDecisionRequest struct {
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
type bindPendingOAuthLoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
type createPendingOAuthAccountRequest struct {
Email string `json:"email" binding:"required,email"`
VerifyCode string `json:"verify_code,omitempty"`
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
return oauthAdoptionDecisionRequest{
AdoptDisplayName: r.AdoptDisplayName,
AdoptAvatar: r.AdoptAvatar,
}
}
func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
return oauthAdoptionDecisionRequest{
AdoptDisplayName: r.AdoptDisplayName,
AdoptAvatar: r.AdoptAvatar,
}
}
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
@@ -170,6 +200,36 @@ func readCompletionResponse(session map[string]any) (map[string]any, bool) {
return result, true
}
func clonePendingMap(values map[string]any) map[string]any {
if len(values) == 0 {
return map[string]any{}
}
cloned := make(map[string]any, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
payload, _ := readCompletionResponse(session.LocalFlowState)
merged := clonePendingMap(payload)
if strings.TrimSpace(session.RedirectTo) != "" {
if _, exists := merged["redirect"]; !exists {
merged["redirect"] = session.RedirectTo
}
}
for key, value := range overrides {
if value == nil {
delete(merged, key)
continue
}
merged[key] = value
}
applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
return merged
}
func pendingSessionStringValue(values map[string]any, key string) string {
if len(values) == 0 {
return ""
@@ -264,6 +324,89 @@ func (h *AuthHandler) entClient() *dbent.Client {
return h.authService.EntClient()
}
func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
if h == nil || h.settingSvc == nil {
return false
}
defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
if err != nil || defaults == nil {
return false
}
return defaults.ForceEmailOnThirdPartySignup
}
func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
client := h.entClient()
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
record, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
userEntity, err := client.User.Get(ctx, record.UserID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
}
return userEntity, nil
}
func (h *AuthHandler) createOAuthEmailRequiredPendingSession(
c *gin.Context,
identity service.PendingAuthIdentityKey,
redirectTo string,
browserSessionKey string,
upstreamClaims map[string]any,
) error {
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identity,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
"step": "email_required",
"force_email_on_signup": true,
"email_binding_required": true,
"existing_account_bindable": true,
},
})
}
func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "linuxdo")
}
func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "wechat")
}
func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "")
}
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
@@ -313,6 +456,60 @@ func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
return decision, nil
}
func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
req oauthAdoptionDecisionRequest,
) (*dbent.IdentityAdoptionDecision, error) {
decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
if err != nil {
return nil, err
}
if decision != nil {
return decision, nil
}
svc, err := h.pendingIdentityService()
if err != nil {
return nil, err
}
decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: sessionID,
})
if err != nil {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
}
return decision, nil
}
func updatePendingOAuthSessionProgress(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
intent string,
resolvedEmail string,
targetUserID *int64,
completionResponse map[string]any,
) (*dbent.PendingAuthSession, error) {
if client == nil || session == nil {
return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
}
localFlowState := clonePendingMap(session.LocalFlowState)
localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
update := client.PendingAuthSession.UpdateOneID(session.ID).
SetIntent(strings.TrimSpace(intent)).
SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
SetLocalFlowState(localFlowState)
if targetUserID != nil && *targetUserID > 0 {
update = update.SetTargetUserID(*targetUserID)
} else {
update = update.ClearTargetUserID()
}
return update.Save(ctx)
}
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
if session == nil {
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
@@ -401,17 +598,18 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
return decision.AdoptDisplayName || decision.AdoptAvatar
}
func applyPendingOAuthAdoption(
func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
forceBind bool,
) error {
if client == nil || session == nil || decision == nil {
if client == nil || session == nil {
return nil
}
if !shouldBindPendingOAuthIdentity(session, decision) {
if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
return nil
}
@@ -427,11 +625,11 @@ func applyPendingOAuthAdoption(
}
adoptedDisplayName := ""
if decision.AdoptDisplayName {
if decision != nil && decision.AdoptDisplayName {
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
}
adoptedAvatarURL := ""
if decision.AdoptAvatar {
if decision != nil && decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
@@ -441,7 +639,7 @@ func applyPendingOAuthAdoption(
}
defer func() { _ = tx.Rollback() }()
if decision.AdoptDisplayName && adoptedDisplayName != "" {
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
Exec(ctx); err != nil {
@@ -458,10 +656,10 @@ func applyPendingOAuthAdoption(
for key, value := range session.UpstreamIdentityClaims {
metadata[key] = value
}
if decision.AdoptDisplayName && adoptedDisplayName != "" {
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
if decision.AdoptAvatar && adoptedAvatarURL != "" {
if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
metadata["avatar_url"] = adoptedAvatarURL
}
@@ -473,7 +671,7 @@ func applyPendingOAuthAdoption(
return err
}
if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
Save(ctx); err != nil {
@@ -484,6 +682,16 @@ func applyPendingOAuthAdoption(
return tx.Commit()
}
func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
) error {
return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false)
}
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 {
return
@@ -507,6 +715,206 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream
}
}
func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
secureCookie := isRequestHTTPS(c)
clearCookies := func() {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
clearCookies()
return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
clearCookies()
return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
}
svc, err := h.pendingIdentityService()
if err != nil {
clearCookies()
return nil, nil, clearCookies, err
}
session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
if err != nil {
clearCookies()
return nil, nil, clearCookies, err
}
return svc, session, clearCookies, nil
}
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
payload := gin.H{
"auth_result": "pending_session",
"provider": strings.TrimSpace(session.ProviderType),
"intent": strings.TrimSpace(session.Intent),
}
for key, value := range mergePendingCompletionResponse(session, nil) {
payload[key] = value
}
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
payload["email"] = email
}
return payload
}
func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
var req bindPendingOAuthLoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
}
user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
response.InternalError(c, "Failed to generate token pair")
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
var req createPendingOAuthAccountRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
}
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
email := strings.TrimSpace(strings.ToLower(req.Email))
existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context())
if err != nil && !dbent.IsNotFound(err) {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
return
}
if existingUser != nil {
completionResponse := mergePendingCompletionResponse(session, map[string]any{
"step": "bind_login_required",
"email": email,
})
session, err = updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
"adopt_existing_user_by_email",
email,
&existingUser.ID,
completionResponse,
)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err))
return
}
if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
}
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
c.Request.Context(),
email,
req.Password,
strings.TrimSpace(req.VerifyCode),
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
)
if err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
// POST /api/v1/auth/oauth/pending/exchange
func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {