feat: add pending oauth email onboarding flow
This commit is contained in:
@@ -243,6 +243,18 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
if subject != "" {
|
||||
email = linuxDoSyntheticEmail(subject)
|
||||
}
|
||||
identityKey := service.PendingAuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: subject,
|
||||
}
|
||||
upstreamClaims := map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"suggested_display_name": displayName,
|
||||
"suggested_avatar_url": avatarURL,
|
||||
}
|
||||
if intent == oauthIntentBindCurrentUser {
|
||||
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
|
||||
if err != nil {
|
||||
@@ -250,23 +262,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentBindCurrentUser,
|
||||
Identity: service.PendingAuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: subject,
|
||||
},
|
||||
TargetUserID: &targetUserID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"suggested_display_name": displayName,
|
||||
"suggested_avatar_url": avatarURL,
|
||||
},
|
||||
Intent: oauthIntentBindCurrentUser,
|
||||
Identity: identityKey,
|
||||
TargetUserID: &targetUserID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"redirect": redirectTo,
|
||||
},
|
||||
@@ -278,27 +280,60 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if existingIdentityUser != nil {
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentLogin,
|
||||
Identity: identityKey,
|
||||
TargetUserID: &user.ID,
|
||||
ResolvedEmail: existingIdentityUser.Email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
"redirect": redirectTo,
|
||||
},
|
||||
}); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
|
||||
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
|
||||
if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: "login",
|
||||
Identity: service.PendingAuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: subject,
|
||||
},
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"suggested_display_name": displayName,
|
||||
"suggested_avatar_url": avatarURL,
|
||||
},
|
||||
Intent: "login",
|
||||
Identity: identityKey,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"error": "invitation_required",
|
||||
"redirect": redirectTo,
|
||||
@@ -316,23 +351,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: "login",
|
||||
Identity: service.PendingAuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: subject,
|
||||
},
|
||||
TargetUserID: &user.ID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"suggested_display_name": displayName,
|
||||
"suggested_avatar_url": avatarURL,
|
||||
},
|
||||
Intent: "login",
|
||||
Identity: identityKey,
|
||||
TargetUserID: &user.ID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
|
||||
@@ -46,6 +46,36 @@ type oauthAdoptionDecisionRequest struct {
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
|
||||
type bindPendingOAuthLoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
|
||||
type createPendingOAuthAccountRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
VerifyCode string `json:"verify_code,omitempty"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
InvitationCode string `json:"invitation_code,omitempty"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
|
||||
func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
|
||||
return oauthAdoptionDecisionRequest{
|
||||
AdoptDisplayName: r.AdoptDisplayName,
|
||||
AdoptAvatar: r.AdoptAvatar,
|
||||
}
|
||||
}
|
||||
|
||||
func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
|
||||
return oauthAdoptionDecisionRequest{
|
||||
AdoptDisplayName: r.AdoptDisplayName,
|
||||
AdoptAvatar: r.AdoptAvatar,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
|
||||
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
@@ -170,6 +200,36 @@ func readCompletionResponse(session map[string]any) (map[string]any, bool) {
|
||||
return result, true
|
||||
}
|
||||
|
||||
func clonePendingMap(values map[string]any) map[string]any {
|
||||
if len(values) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
cloned := make(map[string]any, len(values))
|
||||
for key, value := range values {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
|
||||
payload, _ := readCompletionResponse(session.LocalFlowState)
|
||||
merged := clonePendingMap(payload)
|
||||
if strings.TrimSpace(session.RedirectTo) != "" {
|
||||
if _, exists := merged["redirect"]; !exists {
|
||||
merged["redirect"] = session.RedirectTo
|
||||
}
|
||||
}
|
||||
for key, value := range overrides {
|
||||
if value == nil {
|
||||
delete(merged, key)
|
||||
continue
|
||||
}
|
||||
merged[key] = value
|
||||
}
|
||||
applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
|
||||
return merged
|
||||
}
|
||||
|
||||
func pendingSessionStringValue(values map[string]any, key string) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
@@ -264,6 +324,89 @@ func (h *AuthHandler) entClient() *dbent.Client {
|
||||
return h.authService.EntClient()
|
||||
}
|
||||
|
||||
func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
|
||||
if h == nil || h.settingSvc == nil {
|
||||
return false
|
||||
}
|
||||
defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
|
||||
if err != nil || defaults == nil {
|
||||
return false
|
||||
}
|
||||
return defaults.ForceEmailOnThirdPartySignup
|
||||
}
|
||||
|
||||
func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
|
||||
client := h.entClient()
|
||||
if client == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
}
|
||||
|
||||
record, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
|
||||
authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
|
||||
authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||
}
|
||||
|
||||
userEntity, err := client.User.Get(ctx, record.UserID)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
|
||||
}
|
||||
return userEntity, nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) createOAuthEmailRequiredPendingSession(
|
||||
c *gin.Context,
|
||||
identity service.PendingAuthIdentityKey,
|
||||
redirectTo string,
|
||||
browserSessionKey string,
|
||||
upstreamClaims map[string]any,
|
||||
) error {
|
||||
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentLogin,
|
||||
Identity: identity,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"redirect": redirectTo,
|
||||
"step": "email_required",
|
||||
"force_email_on_signup": true,
|
||||
"email_binding_required": true,
|
||||
"existing_account_bindable": true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
|
||||
func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
|
||||
func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
|
||||
func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
|
||||
|
||||
func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
|
||||
h.createPendingOAuthAccount(c, "linuxdo")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
|
||||
|
||||
func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
|
||||
h.createPendingOAuthAccount(c, "wechat")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
|
||||
h.createPendingOAuthAccount(c, "")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
|
||||
c *gin.Context,
|
||||
sessionID int64,
|
||||
@@ -313,6 +456,60 @@ func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
|
||||
c *gin.Context,
|
||||
sessionID int64,
|
||||
req oauthAdoptionDecisionRequest,
|
||||
) (*dbent.IdentityAdoptionDecision, error) {
|
||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if decision != nil {
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
svc, err := h.pendingIdentityService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: sessionID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
|
||||
}
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
func updatePendingOAuthSessionProgress(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
session *dbent.PendingAuthSession,
|
||||
intent string,
|
||||
resolvedEmail string,
|
||||
targetUserID *int64,
|
||||
completionResponse map[string]any,
|
||||
) (*dbent.PendingAuthSession, error) {
|
||||
if client == nil || session == nil {
|
||||
return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
|
||||
}
|
||||
|
||||
localFlowState := clonePendingMap(session.LocalFlowState)
|
||||
localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
|
||||
|
||||
update := client.PendingAuthSession.UpdateOneID(session.ID).
|
||||
SetIntent(strings.TrimSpace(intent)).
|
||||
SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
|
||||
SetLocalFlowState(localFlowState)
|
||||
if targetUserID != nil && *targetUserID > 0 {
|
||||
update = update.SetTargetUserID(*targetUserID)
|
||||
} else {
|
||||
update = update.ClearTargetUserID()
|
||||
}
|
||||
return update.Save(ctx)
|
||||
}
|
||||
|
||||
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
|
||||
if session == nil {
|
||||
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
|
||||
@@ -401,17 +598,18 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
|
||||
return decision.AdoptDisplayName || decision.AdoptAvatar
|
||||
}
|
||||
|
||||
func applyPendingOAuthAdoption(
|
||||
func applyPendingOAuthBinding(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
session *dbent.PendingAuthSession,
|
||||
decision *dbent.IdentityAdoptionDecision,
|
||||
overrideUserID *int64,
|
||||
forceBind bool,
|
||||
) error {
|
||||
if client == nil || session == nil || decision == nil {
|
||||
if client == nil || session == nil {
|
||||
return nil
|
||||
}
|
||||
if !shouldBindPendingOAuthIdentity(session, decision) {
|
||||
if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -427,11 +625,11 @@ func applyPendingOAuthAdoption(
|
||||
}
|
||||
|
||||
adoptedDisplayName := ""
|
||||
if decision.AdoptDisplayName {
|
||||
if decision != nil && decision.AdoptDisplayName {
|
||||
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
|
||||
}
|
||||
adoptedAvatarURL := ""
|
||||
if decision.AdoptAvatar {
|
||||
if decision != nil && decision.AdoptAvatar {
|
||||
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
|
||||
}
|
||||
|
||||
@@ -441,7 +639,7 @@ func applyPendingOAuthAdoption(
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
if decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||
if err := tx.Client().User.UpdateOneID(targetUserID).
|
||||
SetUsername(adoptedDisplayName).
|
||||
Exec(ctx); err != nil {
|
||||
@@ -458,10 +656,10 @@ func applyPendingOAuthAdoption(
|
||||
for key, value := range session.UpstreamIdentityClaims {
|
||||
metadata[key] = value
|
||||
}
|
||||
if decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||
metadata["display_name"] = adoptedDisplayName
|
||||
}
|
||||
if decision.AdoptAvatar && adoptedAvatarURL != "" {
|
||||
if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
|
||||
metadata["avatar_url"] = adoptedAvatarURL
|
||||
}
|
||||
|
||||
@@ -473,7 +671,7 @@ func applyPendingOAuthAdoption(
|
||||
return err
|
||||
}
|
||||
|
||||
if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
|
||||
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
|
||||
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
|
||||
SetIdentityID(identity.ID).
|
||||
Save(ctx); err != nil {
|
||||
@@ -484,6 +682,16 @@ func applyPendingOAuthAdoption(
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func applyPendingOAuthAdoption(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
session *dbent.PendingAuthSession,
|
||||
decision *dbent.IdentityAdoptionDecision,
|
||||
overrideUserID *int64,
|
||||
) error {
|
||||
return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false)
|
||||
}
|
||||
|
||||
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
|
||||
if len(payload) == 0 || len(upstream) == 0 {
|
||||
return
|
||||
@@ -507,6 +715,206 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream
|
||||
}
|
||||
}
|
||||
|
||||
func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
clearCookies := func() {
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
}
|
||||
|
||||
sessionToken, err := readOAuthPendingSessionCookie(c)
|
||||
if err != nil || strings.TrimSpace(sessionToken) == "" {
|
||||
clearCookies()
|
||||
return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
|
||||
}
|
||||
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
|
||||
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
|
||||
clearCookies()
|
||||
return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
|
||||
}
|
||||
|
||||
svc, err := h.pendingIdentityService()
|
||||
if err != nil {
|
||||
clearCookies()
|
||||
return nil, nil, clearCookies, err
|
||||
}
|
||||
|
||||
session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
|
||||
if err != nil {
|
||||
clearCookies()
|
||||
return nil, nil, clearCookies, err
|
||||
}
|
||||
|
||||
return svc, session, clearCookies, nil
|
||||
}
|
||||
|
||||
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
|
||||
payload := gin.H{
|
||||
"auth_result": "pending_session",
|
||||
"provider": strings.TrimSpace(session.ProviderType),
|
||||
"intent": strings.TrimSpace(session.Intent),
|
||||
}
|
||||
for key, value := range mergePendingCompletionResponse(session, nil) {
|
||||
payload[key] = value
|
||||
}
|
||||
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
|
||||
payload["email"] = email
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
|
||||
var req bindPendingOAuthLoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
|
||||
response.BadRequest(c, "Pending oauth session provider mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
|
||||
response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
|
||||
return
|
||||
}
|
||||
|
||||
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token pair")
|
||||
return
|
||||
}
|
||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
clearCookies()
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
|
||||
var req createPendingOAuthAccountRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
|
||||
response.BadRequest(c, "Pending oauth session provider mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
client := h.entClient()
|
||||
if client == nil {
|
||||
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(req.Email))
|
||||
existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context())
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
|
||||
return
|
||||
}
|
||||
if existingUser != nil {
|
||||
completionResponse := mergePendingCompletionResponse(session, map[string]any{
|
||||
"step": "bind_login_required",
|
||||
"email": email,
|
||||
})
|
||||
session, err = updatePendingOAuthSessionProgress(
|
||||
c.Request.Context(),
|
||||
client,
|
||||
session,
|
||||
"adopt_existing_user_by_email",
|
||||
email,
|
||||
&existingUser.ID,
|
||||
completionResponse,
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
|
||||
c.Request.Context(),
|
||||
email,
|
||||
req.Password,
|
||||
strings.TrimSpace(req.VerifyCode),
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
clearCookies()
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
}
|
||||
|
||||
// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
|
||||
// POST /api/v1/auth/oauth/pending/exchange
|
||||
func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
|
||||
@@ -509,9 +509,305 @@ func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecis
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810")
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("create-account-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-create-123").
|
||||
SetBrowserSessionKey("create-account-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
"suggested_display_name": "Fresh OIDC User",
|
||||
"suggested_avatar_url": "https://cdn.example/fresh.png",
|
||||
}).
|
||||
SetRedirectTo("/profile").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.CreateOIDCOAuthAccount(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||
require.NotEmpty(t, payload["access_token"])
|
||||
require.NotEmpty(t, payload["refresh_token"])
|
||||
require.Equal(t, "Bearer", payload["token_type"])
|
||||
|
||||
createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.StatusActive, createdUser.Status)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("oidc"),
|
||||
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||
authidentity.ProviderSubjectEQ("oidc-create-123"),
|
||||
).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, createdUser.ID, identity.UserID)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
||||
ctx := context.Background()
|
||||
|
||||
existingUser, err := client.User.Create().
|
||||
SetEmail("owner@example.com").
|
||||
SetUsername("owner-user").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("existing-email-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-existing-123").
|
||||
SetBrowserSessionKey("existing-email-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
"suggested_display_name": "Existing OIDC User",
|
||||
"suggested_avatar_url": "https://cdn.example/existing.png",
|
||||
}).
|
||||
SetRedirectTo("/dashboard").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.CreateOIDCOAuthAccount(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||
require.Equal(t, "pending_session", payload["auth_result"])
|
||||
require.Equal(t, "adopt_existing_user_by_email", payload["intent"])
|
||||
require.Equal(t, "oidc", payload["provider"])
|
||||
require.Equal(t, "/dashboard", payload["redirect"])
|
||||
require.Equal(t, true, payload["adoption_required"])
|
||||
require.Equal(t, "Existing OIDC User", payload["suggested_display_name"])
|
||||
require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"])
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
|
||||
require.NotNil(t, storedSession.TargetUserID)
|
||||
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
|
||||
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("oidc"),
|
||||
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||
authidentity.ProviderSubjectEQ("oidc-existing-123"),
|
||||
).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, identityCount)
|
||||
}
|
||||
|
||||
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
passwordHash, err := handler.authService.HashPassword("secret-123")
|
||||
require.NoError(t, err)
|
||||
|
||||
existingUser, err := client.User.Create().
|
||||
SetEmail("owner@example.com").
|
||||
SetUsername("owner-user").
|
||||
SetPasswordHash(passwordHash).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("bind-login-session-token").
|
||||
SetIntent("adopt_existing_user_by_email").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-bind-123").
|
||||
SetTargetUserID(existingUser.ID).
|
||||
SetResolvedEmail(existingUser.Email).
|
||||
SetBrowserSessionKey("bind-login-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
"suggested_display_name": "Bound OIDC User",
|
||||
"suggested_avatar_url": "https://cdn.example/bound.png",
|
||||
}).
|
||||
SetRedirectTo("/profile").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.BindOIDCOAuthLogin(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||
require.NotEmpty(t, payload["access_token"])
|
||||
require.NotEmpty(t, payload["refresh_token"])
|
||||
require.Equal(t, "Bearer", payload["token_type"])
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("oidc"),
|
||||
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||
authidentity.ProviderSubjectEQ("oidc-bind-123"),
|
||||
).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, existingUser.ID, identity.UserID)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
passwordHash, err := handler.authService.HashPassword("secret-123")
|
||||
require.NoError(t, err)
|
||||
|
||||
existingUser, err := client.User.Create().
|
||||
SetEmail("owner@example.com").
|
||||
SetUsername("owner-user").
|
||||
SetPasswordHash(passwordHash).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("bind-login-invalid-password-session-token").
|
||||
SetIntent("adopt_existing_user_by_email").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-bind-invalid-123").
|
||||
SetTargetUserID(existingUser.ID).
|
||||
SetResolvedEmail(existingUser.Email).
|
||||
SetBrowserSessionKey("bind-login-invalid-password-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
"suggested_display_name": "Bound OIDC User",
|
||||
"suggested_avatar_url": "https://cdn.example/bound.png",
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.BindOIDCOAuthLogin(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
payload := decodeJSONBody(t, recorder)
|
||||
require.Equal(t, "INVALID_CREDENTIALS", payload["reason"])
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("oidc"),
|
||||
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||
authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"),
|
||||
).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, identityCount)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil)
|
||||
}
|
||||
|
||||
func newOAuthPendingFlowTestHandlerWithEmailVerification(
|
||||
t *testing.T,
|
||||
invitationEnabled bool,
|
||||
email string,
|
||||
code string,
|
||||
) (*AuthHandler, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
cache := &oauthPendingFlowEmailCacheStub{
|
||||
verificationCodes: map[string]*service.VerificationCodeData{
|
||||
email: {
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||
},
|
||||
},
|
||||
}
|
||||
return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache)
|
||||
}
|
||||
|
||||
func newOAuthPendingFlowTestHandlerWithOptions(
|
||||
t *testing.T,
|
||||
invitationEnabled bool,
|
||||
emailVerifyEnabled bool,
|
||||
emailCache service.EmailCache,
|
||||
) (*AuthHandler, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
@@ -538,9 +834,18 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth
|
||||
values: map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
|
||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
|
||||
},
|
||||
}, cfg)
|
||||
userRepo := &oauthPendingFlowUserRepo{client: client}
|
||||
var emailService *service.EmailService
|
||||
if emailCache != nil {
|
||||
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
|
||||
},
|
||||
}, emailCache)
|
||||
}
|
||||
authSvc := service.NewAuthService(
|
||||
client,
|
||||
userRepo,
|
||||
@@ -548,7 +853,7 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth
|
||||
&oauthPendingFlowRefreshTokenCacheStub{},
|
||||
cfg,
|
||||
settingSvc,
|
||||
nil,
|
||||
emailService,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
@@ -622,6 +927,70 @@ func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error
|
||||
|
||||
type oauthPendingFlowRefreshTokenCacheStub struct{}
|
||||
|
||||
type oauthPendingFlowEmailCacheStub struct {
|
||||
verificationCodes map[string]*service.VerificationCodeData
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) {
|
||||
if s == nil || s.verificationCodes == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.verificationCodes[email], nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error {
|
||||
if s.verificationCodes == nil {
|
||||
s.verificationCodes = map[string]*service.VerificationCodeData{}
|
||||
}
|
||||
s.verificationCodes[email] = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error {
|
||||
delete(s.verificationCodes, email)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -342,6 +342,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
idClaims.Name,
|
||||
oidcFallbackUsername(subject),
|
||||
)
|
||||
identityRef := service.PendingAuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: issuer,
|
||||
ProviderSubject: subject,
|
||||
}
|
||||
upstreamClaims := map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"issuer": issuer,
|
||||
"email_verified": emailVerified != nil && *emailVerified,
|
||||
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
|
||||
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
|
||||
"suggested_avatar_url": userInfoClaims.AvatarURL,
|
||||
}
|
||||
if intent == oauthIntentBindCurrentUser {
|
||||
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
|
||||
if err != nil {
|
||||
@@ -349,26 +364,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentBindCurrentUser,
|
||||
Identity: service.PendingAuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: issuer,
|
||||
ProviderSubject: subject,
|
||||
},
|
||||
TargetUserID: &targetUserID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"issuer": issuer,
|
||||
"email_verified": emailVerified != nil && *emailVerified,
|
||||
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
|
||||
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
|
||||
"suggested_avatar_url": userInfoClaims.AvatarURL,
|
||||
},
|
||||
Intent: oauthIntentBindCurrentUser,
|
||||
Identity: identityRef,
|
||||
TargetUserID: &targetUserID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"redirect": redirectTo,
|
||||
},
|
||||
@@ -380,30 +382,60 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if existingIdentityUser != nil {
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: oauthIntentLogin,
|
||||
Identity: identityRef,
|
||||
TargetUserID: &user.ID,
|
||||
ResolvedEmail: existingIdentityUser.Email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
"redirect": redirectTo,
|
||||
},
|
||||
}); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
|
||||
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
|
||||
if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: "login",
|
||||
Identity: service.PendingAuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: issuer,
|
||||
ProviderSubject: subject,
|
||||
},
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"issuer": issuer,
|
||||
"email_verified": emailVerified != nil && *emailVerified,
|
||||
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
|
||||
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
|
||||
"suggested_avatar_url": userInfoClaims.AvatarURL,
|
||||
},
|
||||
Intent: "login",
|
||||
Identity: identityRef,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"error": "invitation_required",
|
||||
"redirect": redirectTo,
|
||||
@@ -420,26 +452,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: "login",
|
||||
Identity: service.PendingAuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: issuer,
|
||||
ProviderSubject: subject,
|
||||
},
|
||||
TargetUserID: &user.ID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"issuer": issuer,
|
||||
"email_verified": emailVerified != nil && *emailVerified,
|
||||
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
|
||||
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
|
||||
"suggested_avatar_url": userInfoClaims.AvatarURL,
|
||||
},
|
||||
Intent: "login",
|
||||
Identity: identityRef,
|
||||
TargetUserID: &user.ID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: map[string]any{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
|
||||
@@ -214,6 +214,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
|
||||
"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
|
||||
"suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
|
||||
}
|
||||
identityRef := service.PendingAuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: wechatOAuthProviderKey,
|
||||
ProviderSubject: providerSubject,
|
||||
}
|
||||
|
||||
normalizedIntent := normalizeWeChatOAuthIntent(intent)
|
||||
if normalizedIntent == wechatOAuthIntentBind {
|
||||
@@ -232,6 +237,34 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if existingIdentityUser != nil {
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
|
||||
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
|
||||
if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err, nil); err != nil {
|
||||
|
||||
@@ -167,6 +167,7 @@ type DefaultSubscriptionSetting struct {
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
|
||||
@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
|
||||
83
backend/internal/handler/setting_handler_public_test.go
Normal file
83
backend/internal/handler/setting_handler_public_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type settingHandlerPublicRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
|
||||
func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if value, ok := s.values[key]; ok {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
repo := &settingHandlerPublicRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyForceEmailOnThirdPartySignup: "true",
|
||||
},
|
||||
}
|
||||
h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
|
||||
|
||||
h.GetPublicSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.True(t, resp.Data.ForceEmailOnThirdPartySignup)
|
||||
}
|
||||
@@ -72,18 +72,54 @@ func RegisterAuthRoutes(
|
||||
}),
|
||||
h.Auth.ExchangePendingOAuthCompletion,
|
||||
)
|
||||
auth.POST("/oauth/pending/create-account",
|
||||
rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CreatePendingOAuthAccount,
|
||||
)
|
||||
auth.POST("/oauth/pending/bind-login",
|
||||
rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.BindPendingOAuthLogin,
|
||||
)
|
||||
auth.POST("/oauth/linuxdo/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteLinuxDoOAuthRegistration,
|
||||
)
|
||||
auth.POST("/oauth/linuxdo/bind-login",
|
||||
rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.BindLinuxDoOAuthLogin,
|
||||
)
|
||||
auth.POST("/oauth/linuxdo/create-account",
|
||||
rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CreateLinuxDoOAuthAccount,
|
||||
)
|
||||
auth.POST("/oauth/wechat/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteWeChatOAuthRegistration,
|
||||
)
|
||||
auth.POST("/oauth/wechat/bind-login",
|
||||
rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.BindWeChatOAuthLogin,
|
||||
)
|
||||
auth.POST("/oauth/wechat/create-account",
|
||||
rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CreateWeChatOAuthAccount,
|
||||
)
|
||||
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
|
||||
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
|
||||
auth.POST("/oauth/oidc/complete-registration",
|
||||
@@ -92,6 +128,18 @@ func RegisterAuthRoutes(
|
||||
}),
|
||||
h.Auth.CompleteOIDCOAuthRegistration,
|
||||
)
|
||||
auth.POST("/oauth/oidc/bind-login",
|
||||
rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.BindOIDCOAuthLogin,
|
||||
)
|
||||
auth.POST("/oauth/oidc/create-account",
|
||||
rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CreateOIDCOAuthAccount,
|
||||
)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
|
||||
151
backend/internal/service/auth_oauth_email_flow.go
Normal file
151
backend/internal/service/auth_oauth_email_flow.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// VerifyOAuthEmailCode verifies the locally entered email verification code for
|
||||
// third-party signup and binding flows. This is intentionally independent from
|
||||
// the global registration email verification toggle.
|
||||
func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error {
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
verifyCode = strings.TrimSpace(verifyCode)
|
||||
|
||||
if email == "" {
|
||||
return ErrEmailVerifyRequired
|
||||
}
|
||||
if verifyCode == "" {
|
||||
return ErrEmailVerifyRequired
|
||||
}
|
||||
if s == nil || s.emailService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
return s.emailService.VerifyCode(ctx, email, verifyCode)
|
||||
}
|
||||
|
||||
// RegisterOAuthEmailAccount creates a local account from a third-party first
|
||||
// login after the user has verified a local email address.
|
||||
func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
ctx context.Context,
|
||||
email string,
|
||||
password string,
|
||||
verifyCode string,
|
||||
invitationCode string,
|
||||
signupSource string,
|
||||
) (*TokenPair, *User, error) {
|
||||
if s == nil {
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
if isReservedEmail(email) {
|
||||
return nil, nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var invitationRedeemCode *RedeemCode
|
||||
if s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
if invitationCode == "" {
|
||||
return nil, nil, ErrInvitationCodeRequired
|
||||
}
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
}
|
||||
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
return nil, nil, ErrEmailExists
|
||||
}
|
||||
|
||||
hashedPassword, err := s.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
|
||||
if signupSource == "" {
|
||||
signupSource = "email"
|
||||
}
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
|
||||
user := &User{
|
||||
Email: email,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return nil, nil, ErrEmailExists
|
||||
}
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, true)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
}
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||
}
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// ValidatePasswordCredentials checks the local password without completing the
|
||||
// login flow. This is used by pending third-party account adoption flows before
|
||||
// the external identity has been bound.
|
||||
func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) {
|
||||
if s == nil {
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email)))
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
if !user.IsActive() {
|
||||
return nil, ErrUserNotActive
|
||||
}
|
||||
if !s.CheckPassword(password, user.PasswordHash) {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// RecordSuccessfulLogin updates last-login activity after a non-standard login
|
||||
// flow finishes with a real session.
|
||||
func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
|
||||
s.touchUserLogin(ctx, userID)
|
||||
}
|
||||
@@ -217,6 +217,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
keys := []string{
|
||||
SettingKeyRegistrationEnabled,
|
||||
SettingKeyEmailVerifyEnabled,
|
||||
SettingKeyForceEmailOnThirdPartySignup,
|
||||
SettingKeyRegistrationEmailSuffixWhitelist,
|
||||
SettingKeyPromoCodeEnabled,
|
||||
SettingKeyPasswordResetEnabled,
|
||||
@@ -294,6 +295,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
|
||||
RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: passwordResetEnabled,
|
||||
|
||||
@@ -77,3 +77,16 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T)
|
||||
require.Equal(t, 50, settings.TableDefaultPageSize)
|
||||
require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions)
|
||||
}
|
||||
|
||||
func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
|
||||
repo := &settingPublicRepoStub{
|
||||
values: map[string]string{
|
||||
SettingKeyForceEmailOnThirdPartySignup: "true",
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
settings, err := svc.GetPublicSettings(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.True(t, settings.ForceEmailOnThirdPartySignup)
|
||||
}
|
||||
|
||||
@@ -128,6 +128,7 @@ type DefaultSubscriptionSetting struct {
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
ForceEmailOnThirdPartySignup bool
|
||||
RegistrationEmailSuffixWhitelist []string
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
|
||||
Reference in New Issue
Block a user