diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 835e5fd8..c4ecb8fa 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -243,6 +243,18 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if subject != "" { email = linuxDoSyntheticEmail(subject) } + identityKey := service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName) if err != nil { @@ -250,23 +262,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: oauthIntentBindCurrentUser, - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - TargetUserID: &targetUserID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: oauthIntentBindCurrentUser, + Identity: identityKey, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "redirect": redirectTo, }, @@ -278,27 +280,60 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityKey, + TargetUserID: &user.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: "login", + Identity: identityKey, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "error": "invitation_required", "redirect": redirectTo, @@ -316,23 +351,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - TargetUserID: &user.ID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: "login", + Identity: identityKey, + TargetUserID: &user.ID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, "refresh_token": tokenPair.RefreshToken, diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 2d6c3714..99b9b406 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -46,6 +46,36 @@ type oauthAdoptionDecisionRequest struct { AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } +type bindPendingOAuthLoginRequest struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type createPendingOAuthAccountRequest struct { + Email string `json:"email" binding:"required,email"` + VerifyCode string `json:"verify_code,omitempty"` + Password string `json:"password" binding:"required,min=6"` + InvitationCode string `json:"invitation_code,omitempty"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + +func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { if h == nil || h.authService == nil || h.authService.EntClient() == nil { return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") @@ -170,6 +200,36 @@ func readCompletionResponse(session map[string]any) (map[string]any, bool) { return result, true } +func clonePendingMap(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any { + payload, _ := readCompletionResponse(session.LocalFlowState) + merged := clonePendingMap(payload) + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := merged["redirect"]; !exists { + merged["redirect"] = session.RedirectTo + } + } + for key, value := range overrides { + if value == nil { + delete(merged, key) + continue + } + merged[key] = value + } + applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims) + return merged +} + func pendingSessionStringValue(values map[string]any, key string) string { if len(values) == 0 { return "" @@ -264,6 +324,89 @@ func (h *AuthHandler) entClient() *dbent.Client { return h.authService.EntClient() } +func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool { + if h == nil || h.settingSvc == nil { + return false + } + defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx) + if err != nil || defaults == nil { + return false + } + return defaults.ForceEmailOnThirdPartySignup +} + +func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + record, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + + userEntity, err := client.User.Get(ctx, record.UserID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) createOAuthEmailRequiredPendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, +) error { + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + "step": "email_required", + "force_email_on_signup": true, + "email_binding_required": true, + "existing_account_bindable": true, + }, + }) +} + +func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") } +func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") } +func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") } +func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") } + +func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "linuxdo") +} + +func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") } + +func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "wechat") +} + +func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "") +} + func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( c *gin.Context, sessionID int64, @@ -313,6 +456,60 @@ func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( return decision, nil } +func (h *AuthHandler) ensurePendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req) + if err != nil { + return nil, err + } + if decision != nil { + return decision, nil + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + }) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func updatePendingOAuthSessionProgress( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + intent string, + resolvedEmail string, + targetUserID *int64, + completionResponse map[string]any, +) (*dbent.PendingAuthSession, error) { + if client == nil || session == nil { + return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + + localFlowState := clonePendingMap(session.LocalFlowState) + localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse) + + update := client.PendingAuthSession.UpdateOneID(session.ID). + SetIntent(strings.TrimSpace(intent)). + SetResolvedEmail(strings.TrimSpace(resolvedEmail)). + SetLocalFlowState(localFlowState) + if targetUserID != nil && *targetUserID > 0 { + update = update.SetTargetUserID(*targetUserID) + } else { + update = update.ClearTargetUserID() + } + return update.Save(ctx) +} + func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) { if session == nil { return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") @@ -401,17 +598,18 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision return decision.AdoptDisplayName || decision.AdoptAvatar } -func applyPendingOAuthAdoption( +func applyPendingOAuthBinding( ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, + forceBind bool, ) error { - if client == nil || session == nil || decision == nil { + if client == nil || session == nil { return nil } - if !shouldBindPendingOAuthIdentity(session, decision) { + if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) { return nil } @@ -427,11 +625,11 @@ func applyPendingOAuthAdoption( } adoptedDisplayName := "" - if decision.AdoptDisplayName { + if decision != nil && decision.AdoptDisplayName { adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name")) } adoptedAvatarURL := "" - if decision.AdoptAvatar { + if decision != nil && decision.AdoptAvatar { adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") } @@ -441,7 +639,7 @@ func applyPendingOAuthAdoption( } defer func() { _ = tx.Rollback() }() - if decision.AdoptDisplayName && adoptedDisplayName != "" { + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { if err := tx.Client().User.UpdateOneID(targetUserID). SetUsername(adoptedDisplayName). Exec(ctx); err != nil { @@ -458,10 +656,10 @@ func applyPendingOAuthAdoption( for key, value := range session.UpstreamIdentityClaims { metadata[key] = value } - if decision.AdoptDisplayName && adoptedDisplayName != "" { + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { metadata["display_name"] = adoptedDisplayName } - if decision.AdoptAvatar && adoptedAvatarURL != "" { + if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" { metadata["avatar_url"] = adoptedAvatarURL } @@ -473,7 +671,7 @@ func applyPendingOAuthAdoption( return err } - if decision.IdentityID == nil || *decision.IdentityID != identity.ID { + if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) { if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). SetIdentityID(identity.ID). Save(ctx); err != nil { @@ -484,6 +682,16 @@ func applyPendingOAuthAdoption( return tx.Commit() } +func applyPendingOAuthAdoption( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, +) error { + return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false) +} + func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { if len(payload) == 0 || len(upstream) == 0 { return @@ -507,6 +715,206 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream } } +func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + return svc, session, clearCookies, nil +} + +func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H { + payload := gin.H{ + "auth_result": "pending_session", + "provider": strings.TrimSpace(session.ProviderType), + "intent": strings.TrimSpace(session.Intent), + } + for key, value := range mergePendingCompletionResponse(session, nil) { + payload[key] = value + } + if email := strings.TrimSpace(session.ResolvedEmail); email != "" { + payload["email"] = email + } + return payload +} + +func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) { + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { + var req bindPendingOAuthLoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password) + if err != nil { + response.ErrorFrom(c, err) + return + } + if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID { + response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user")) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") + if err != nil { + response.InternalError(c, "Failed to generate token pair") + return + } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + +func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) { + var req createPendingOAuthAccountRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + + email := strings.TrimSpace(strings.ToLower(req.Email)) + existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context()) + if err != nil && !dbent.IsNotFound(err) { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")) + return + } + if existingUser != nil { + completionResponse := mergePendingCompletionResponse(session, map[string]any{ + "step": "bind_login_required", + "email": email, + }) + session, err = updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + "adopt_existing_user_by_email", + email, + &existingUser.ID, + completionResponse, + ) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)) + return + } + + if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil { + response.ErrorFrom(c, err) + return + } + + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } + + tokenPair, user, err := h.authService.RegisterOAuthEmailAccount( + c.Request.Context(), + email, + req.Password, + strings.TrimSpace(req.VerifyCode), + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + // ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload. // POST /api/v1/auth/oauth/pending/exchange func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 3afb4fb7..80338b8a 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -509,9 +509,305 @@ func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecis require.Nil(t, storedSession.ConsumedAt) } +func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810") + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-create-123"). + SetBrowserSessionKey("create-account-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Fresh OIDC User", + "suggested_avatar_url": "https://cdn.example/fresh.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx) + require.NoError(t, err) + require.Equal(t, service.StatusActive, createdUser.Status) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-create-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, createdUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-123"). + SetBrowserSessionKey("existing-email-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "pending_session", payload["auth_result"]) + require.Equal(t, "adopt_existing_user_by_email", payload["intent"]) + require.Equal(t, "oidc", payload["provider"]) + require.Equal(t, "/dashboard", payload["redirect"]) + require.Equal(t, true, payload["adoption_required"]) + require.Equal(t, "Existing OIDC User", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) + require.Nil(t, storedSession.ConsumedAt) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-existing-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) +} + +func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, existingUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-invalid-password-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-invalid-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-invalid-password-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "INVALID_CREDENTIALS", payload["reason"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil) +} + +func newOAuthPendingFlowTestHandlerWithEmailVerification( + t *testing.T, + invitationEnabled bool, + email string, + code string, +) (*AuthHandler, *dbent.Client) { + t.Helper() + + cache := &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + email: { + Code: code, + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + } + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache) +} + +func newOAuthPendingFlowTestHandlerWithOptions( + t *testing.T, + invitationEnabled bool, + emailVerifyEnabled bool, + emailCache service.EmailCache, +) (*AuthHandler, *dbent.Client) { + t.Helper() + db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") require.NoError(t, err) t.Cleanup(func() { _ = db.Close() }) @@ -538,9 +834,18 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth values: map[string]string{ service.SettingKeyRegistrationEnabled: "true", service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled), }, }, cfg) userRepo := &oauthPendingFlowUserRepo{client: client} + var emailService *service.EmailService + if emailCache != nil { + emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ + values: map[string]string{ + service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled), + }, + }, emailCache) + } authSvc := service.NewAuthService( client, userRepo, @@ -548,7 +853,7 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth &oauthPendingFlowRefreshTokenCacheStub{}, cfg, settingSvc, - nil, + emailService, nil, nil, nil, @@ -622,6 +927,70 @@ func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error type oauthPendingFlowRefreshTokenCacheStub struct{} +type oauthPendingFlowEmailCacheStub struct { + verificationCodes map[string]*service.VerificationCodeData +} + +func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) { + if s == nil || s.verificationCodes == nil { + return nil, nil + } + return s.verificationCodes[email], nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error { + if s.verificationCodes == nil { + s.verificationCodes = map[string]*service.VerificationCodeData{} + } + s.verificationCodes[email] = data + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error { + delete(s.verificationCodes, email) + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { return nil } diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 0f79759e..909d6379 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -342,6 +342,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { idClaims.Name, oidcFallbackUsername(subject), ) + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: issuer, + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName) if err != nil { @@ -349,26 +364,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: oauthIntentBindCurrentUser, - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - TargetUserID: &targetUserID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: oauthIntentBindCurrentUser, + Identity: identityRef, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "redirect": redirectTo, }, @@ -380,30 +382,60 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityRef, + TargetUserID: &user.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: "login", + Identity: identityRef, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "error": "invitation_required", "redirect": redirectTo, @@ -420,26 +452,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - TargetUserID: &user.ID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: "login", + Identity: identityRef, + TargetUserID: &user.ID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, "refresh_token": tokenPair.RefreshToken, diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 45ac6cad..6d37c799 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -214,6 +214,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { "suggested_display_name": strings.TrimSpace(userInfo.Nickname), "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), } + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + } normalizedIntent := normalizeWeChatOAuthIntent(intent) if normalizedIntent == wechatOAuthIntentBind { @@ -232,6 +237,34 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err, nil); err != nil { diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index f44b3e3b..637e317b 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -167,6 +167,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool `json:"registration_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"` + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` PromoCodeEnabled bool `json:"promo_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"` diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index c7bc3e2a..9925f066 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { response.Success(c, dto.PublicSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go new file mode 100644 index 00000000..114c7245 --- /dev/null +++ b/backend/internal/handler/setting_handler_public_test.go @@ -0,0 +1,83 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerPublicRepoStub struct { + values map[string]string +} + +func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil) + + h.GetPublicSettings(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.ForceEmailOnThirdPartySignup) +} diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 7a34834d..1f28e9c3 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -72,18 +72,54 @@ func RegisterAuthRoutes( }), h.Auth.ExchangePendingOAuthCompletion, ) + auth.POST("/oauth/pending/create-account", + rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreatePendingOAuthAccount, + ) + auth.POST("/oauth/pending/bind-login", + rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindPendingOAuthLogin, + ) auth.POST("/oauth/linuxdo/complete-registration", rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.POST("/oauth/linuxdo/bind-login", + rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindLinuxDoOAuthLogin, + ) + auth.POST("/oauth/linuxdo/create-account", + rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateLinuxDoOAuthAccount, + ) auth.POST("/oauth/wechat/complete-registration", rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteWeChatOAuthRegistration, ) + auth.POST("/oauth/wechat/bind-login", + rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindWeChatOAuthLogin, + ) + auth.POST("/oauth/wechat/create-account", + rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateWeChatOAuthAccount, + ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", @@ -92,6 +128,18 @@ func RegisterAuthRoutes( }), h.Auth.CompleteOIDCOAuthRegistration, ) + auth.POST("/oauth/oidc/bind-login", + rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindOIDCOAuthLogin, + ) + auth.POST("/oauth/oidc/create-account", + rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateOIDCOAuthAccount, + ) } // 公开设置(无需认证) diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go new file mode 100644 index 00000000..ca3403d4 --- /dev/null +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -0,0 +1,151 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" +) + +// VerifyOAuthEmailCode verifies the locally entered email verification code for +// third-party signup and binding flows. This is intentionally independent from +// the global registration email verification toggle. +func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error { + email = strings.TrimSpace(strings.ToLower(email)) + verifyCode = strings.TrimSpace(verifyCode) + + if email == "" { + return ErrEmailVerifyRequired + } + if verifyCode == "" { + return ErrEmailVerifyRequired + } + if s == nil || s.emailService == nil { + return ErrServiceUnavailable + } + return s.emailService.VerifyCode(ctx, email, verifyCode) +} + +// RegisterOAuthEmailAccount creates a local account from a third-party first +// login after the user has verified a local email address. +func (s *AuthService) RegisterOAuthEmailAccount( + ctx context.Context, + email string, + password string, + verifyCode string, + invitationCode string, + signupSource string, +) (*TokenPair, *User, error) { + if s == nil { + return nil, nil, ErrServiceUnavailable + } + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + email = strings.TrimSpace(strings.ToLower(email)) + if isReservedEmail(email) { + return nil, nil, ErrEmailReserved + } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, nil, err + } + if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil { + return nil, nil, err + } + + var invitationRedeemCode *RedeemCode + if s.settingService.IsInvitationCodeEnabled(ctx) { + if invitationCode == "" { + return nil, nil, ErrInvitationCodeRequired + } + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, nil, ErrInvitationCodeInvalid + } + invitationRedeemCode = redeemCode + } + + existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) + if err != nil { + return nil, nil, ErrServiceUnavailable + } + if existsEmail { + return nil, nil, ErrEmailExists + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + signupSource = strings.TrimSpace(strings.ToLower(signupSource)) + if signupSource == "" { + signupSource = "email" + } + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + + user := &User{ + Email: email, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, user); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, nil, ErrEmailExists + } + return nil, nil, ErrServiceUnavailable + } + + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + +// ValidatePasswordCredentials checks the local password without completing the +// login flow. This is used by pending third-party account adoption flows before +// the external identity has been bound. +func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) { + if s == nil { + return nil, ErrServiceUnavailable + } + + user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email))) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, ErrInvalidCredentials + } + return nil, ErrServiceUnavailable + } + if !user.IsActive() { + return nil, ErrUserNotActive + } + if !s.CheckPassword(password, user.PasswordHash) { + return nil, ErrInvalidCredentials + } + return user, nil +} + +// RecordSuccessfulLogin updates last-login activity after a non-standard login +// flow finishes with a real session. +func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) { + s.touchUserLogin(ctx, userID) +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index de555478..a2644fcd 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -217,6 +217,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings keys := []string{ SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, + SettingKeyForceEmailOnThirdPartySignup, SettingKeyRegistrationEmailSuffixWhitelist, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, @@ -294,6 +295,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PasswordResetEnabled: passwordResetEnabled, diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 5cf1e860..bb97c2aa 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -77,3 +77,16 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) require.Equal(t, 50, settings.TableDefaultPageSize) require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions) } + +func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.ForceEmailOnThirdPartySignup) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index e991ebef..72db4e31 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -128,6 +128,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool + ForceEmailOnThirdPartySignup bool RegistrationEmailSuffixWhitelist []string PromoCodeEnabled bool PasswordResetEnabled bool