feat: add pending oauth email onboarding flow

This commit is contained in:
IanShaw027
2026-04-20 19:30:09 +08:00
parent d47580a144
commit 6a75bd77e3
13 changed files with 1273 additions and 119 deletions

View File

@@ -243,6 +243,18 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
identityKey := service.PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: subject,
}
upstreamClaims := map[string]any{
"email": email,
"username": username,
"subject": subject,
"suggested_display_name": displayName,
"suggested_avatar_url": avatarURL,
}
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
if err != nil {
@@ -250,23 +262,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentBindCurrentUser,
Identity: service.PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: subject,
},
TargetUserID: &targetUserID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
"email": email,
"username": username,
"subject": subject,
"suggested_display_name": displayName,
"suggested_avatar_url": avatarURL,
},
Intent: oauthIntentBindCurrentUser,
Identity: identityKey,
TargetUserID: &targetUserID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
},
@@ -278,27 +280,60 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey)
if err != nil {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityKey,
TargetUserID: &user.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
"redirect": redirectTo,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: "login",
Identity: service.PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: subject,
},
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
"email": email,
"username": username,
"subject": subject,
"suggested_display_name": displayName,
"suggested_avatar_url": avatarURL,
},
Intent: "login",
Identity: identityKey,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"error": "invitation_required",
"redirect": redirectTo,
@@ -316,23 +351,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: "login",
Identity: service.PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: subject,
},
TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
"email": email,
"username": username,
"subject": subject,
"suggested_display_name": displayName,
"suggested_avatar_url": avatarURL,
},
Intent: "login",
Identity: identityKey,
TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,

View File

@@ -46,6 +46,36 @@ type oauthAdoptionDecisionRequest struct {
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
type bindPendingOAuthLoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
type createPendingOAuthAccountRequest struct {
Email string `json:"email" binding:"required,email"`
VerifyCode string `json:"verify_code,omitempty"`
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
return oauthAdoptionDecisionRequest{
AdoptDisplayName: r.AdoptDisplayName,
AdoptAvatar: r.AdoptAvatar,
}
}
func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
return oauthAdoptionDecisionRequest{
AdoptDisplayName: r.AdoptDisplayName,
AdoptAvatar: r.AdoptAvatar,
}
}
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
@@ -170,6 +200,36 @@ func readCompletionResponse(session map[string]any) (map[string]any, bool) {
return result, true
}
func clonePendingMap(values map[string]any) map[string]any {
if len(values) == 0 {
return map[string]any{}
}
cloned := make(map[string]any, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
payload, _ := readCompletionResponse(session.LocalFlowState)
merged := clonePendingMap(payload)
if strings.TrimSpace(session.RedirectTo) != "" {
if _, exists := merged["redirect"]; !exists {
merged["redirect"] = session.RedirectTo
}
}
for key, value := range overrides {
if value == nil {
delete(merged, key)
continue
}
merged[key] = value
}
applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
return merged
}
func pendingSessionStringValue(values map[string]any, key string) string {
if len(values) == 0 {
return ""
@@ -264,6 +324,89 @@ func (h *AuthHandler) entClient() *dbent.Client {
return h.authService.EntClient()
}
func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
if h == nil || h.settingSvc == nil {
return false
}
defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
if err != nil || defaults == nil {
return false
}
return defaults.ForceEmailOnThirdPartySignup
}
func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
client := h.entClient()
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
record, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
userEntity, err := client.User.Get(ctx, record.UserID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
}
return userEntity, nil
}
func (h *AuthHandler) createOAuthEmailRequiredPendingSession(
c *gin.Context,
identity service.PendingAuthIdentityKey,
redirectTo string,
browserSessionKey string,
upstreamClaims map[string]any,
) error {
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identity,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
"step": "email_required",
"force_email_on_signup": true,
"email_binding_required": true,
"existing_account_bindable": true,
},
})
}
func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "linuxdo")
}
func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "wechat")
}
func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "")
}
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
@@ -313,6 +456,60 @@ func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
return decision, nil
}
func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
req oauthAdoptionDecisionRequest,
) (*dbent.IdentityAdoptionDecision, error) {
decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
if err != nil {
return nil, err
}
if decision != nil {
return decision, nil
}
svc, err := h.pendingIdentityService()
if err != nil {
return nil, err
}
decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: sessionID,
})
if err != nil {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
}
return decision, nil
}
func updatePendingOAuthSessionProgress(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
intent string,
resolvedEmail string,
targetUserID *int64,
completionResponse map[string]any,
) (*dbent.PendingAuthSession, error) {
if client == nil || session == nil {
return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
}
localFlowState := clonePendingMap(session.LocalFlowState)
localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
update := client.PendingAuthSession.UpdateOneID(session.ID).
SetIntent(strings.TrimSpace(intent)).
SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
SetLocalFlowState(localFlowState)
if targetUserID != nil && *targetUserID > 0 {
update = update.SetTargetUserID(*targetUserID)
} else {
update = update.ClearTargetUserID()
}
return update.Save(ctx)
}
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
if session == nil {
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
@@ -401,17 +598,18 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
return decision.AdoptDisplayName || decision.AdoptAvatar
}
func applyPendingOAuthAdoption(
func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
forceBind bool,
) error {
if client == nil || session == nil || decision == nil {
if client == nil || session == nil {
return nil
}
if !shouldBindPendingOAuthIdentity(session, decision) {
if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
return nil
}
@@ -427,11 +625,11 @@ func applyPendingOAuthAdoption(
}
adoptedDisplayName := ""
if decision.AdoptDisplayName {
if decision != nil && decision.AdoptDisplayName {
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
}
adoptedAvatarURL := ""
if decision.AdoptAvatar {
if decision != nil && decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
@@ -441,7 +639,7 @@ func applyPendingOAuthAdoption(
}
defer func() { _ = tx.Rollback() }()
if decision.AdoptDisplayName && adoptedDisplayName != "" {
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
Exec(ctx); err != nil {
@@ -458,10 +656,10 @@ func applyPendingOAuthAdoption(
for key, value := range session.UpstreamIdentityClaims {
metadata[key] = value
}
if decision.AdoptDisplayName && adoptedDisplayName != "" {
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
if decision.AdoptAvatar && adoptedAvatarURL != "" {
if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
metadata["avatar_url"] = adoptedAvatarURL
}
@@ -473,7 +671,7 @@ func applyPendingOAuthAdoption(
return err
}
if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
Save(ctx); err != nil {
@@ -484,6 +682,16 @@ func applyPendingOAuthAdoption(
return tx.Commit()
}
func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
) error {
return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false)
}
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 {
return
@@ -507,6 +715,206 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream
}
}
func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
secureCookie := isRequestHTTPS(c)
clearCookies := func() {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
clearCookies()
return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
clearCookies()
return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
}
svc, err := h.pendingIdentityService()
if err != nil {
clearCookies()
return nil, nil, clearCookies, err
}
session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
if err != nil {
clearCookies()
return nil, nil, clearCookies, err
}
return svc, session, clearCookies, nil
}
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
payload := gin.H{
"auth_result": "pending_session",
"provider": strings.TrimSpace(session.ProviderType),
"intent": strings.TrimSpace(session.Intent),
}
for key, value := range mergePendingCompletionResponse(session, nil) {
payload[key] = value
}
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
payload["email"] = email
}
return payload
}
func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
var req bindPendingOAuthLoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
}
user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
response.InternalError(c, "Failed to generate token pair")
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
var req createPendingOAuthAccountRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
}
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
email := strings.TrimSpace(strings.ToLower(req.Email))
existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context())
if err != nil && !dbent.IsNotFound(err) {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
return
}
if existingUser != nil {
completionResponse := mergePendingCompletionResponse(session, map[string]any{
"step": "bind_login_required",
"email": email,
})
session, err = updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
"adopt_existing_user_by_email",
email,
&existingUser.ID,
completionResponse,
)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err))
return
}
if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
}
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
c.Request.Context(),
email,
req.Password,
strings.TrimSpace(req.VerifyCode),
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
)
if err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
// POST /api/v1/auth/oauth/pending/exchange
func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {

View File

@@ -509,9 +509,305 @@ func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecis
require.Nil(t, storedSession.ConsumedAt)
}
func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810")
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("create-account-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-create-123").
SetBrowserSessionKey("create-account-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"suggested_display_name": "Fresh OIDC User",
"suggested_avatar_url": "https://cdn.example/fresh.png",
}).
SetRedirectTo("/profile").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")})
ginCtx.Request = req
handler.CreateOIDCOAuthAccount(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
require.NotEmpty(t, payload["access_token"])
require.NotEmpty(t, payload["refresh_token"])
require.Equal(t, "Bearer", payload["token_type"])
createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx)
require.NoError(t, err)
require.Equal(t, service.StatusActive, createdUser.Status)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example"),
authidentity.ProviderSubjectEQ("oidc-create-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, createdUser.ID, identity.UserID)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
}
func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("existing-email-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-existing-123").
SetBrowserSessionKey("existing-email-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"suggested_display_name": "Existing OIDC User",
"suggested_avatar_url": "https://cdn.example/existing.png",
}).
SetRedirectTo("/dashboard").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")})
ginCtx.Request = req
handler.CreateOIDCOAuthAccount(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
require.Equal(t, "pending_session", payload["auth_result"])
require.Equal(t, "adopt_existing_user_by_email", payload["intent"])
require.Equal(t, "oidc", payload["provider"])
require.Equal(t, "/dashboard", payload["redirect"])
require.Equal(t, true, payload["adoption_required"])
require.Equal(t, "Existing OIDC User", payload["suggested_display_name"])
require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"])
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
require.Nil(t, storedSession.ConsumedAt)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example"),
authidentity.ProviderSubjectEQ("oidc-existing-123"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, identityCount)
}
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
passwordHash, err := handler.authService.HashPassword("secret-123")
require.NoError(t, err)
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash(passwordHash).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("bind-login-session-token").
SetIntent("adopt_existing_user_by_email").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-bind-123").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("bind-login-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"suggested_display_name": "Bound OIDC User",
"suggested_avatar_url": "https://cdn.example/bound.png",
}).
SetRedirectTo("/profile").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")})
ginCtx.Request = req
handler.BindOIDCOAuthLogin(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
require.NotEmpty(t, payload["access_token"])
require.NotEmpty(t, payload["refresh_token"])
require.Equal(t, "Bearer", payload["token_type"])
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example"),
authidentity.ProviderSubjectEQ("oidc-bind-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, existingUser.ID, identity.UserID)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
}
func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
passwordHash, err := handler.authService.HashPassword("secret-123")
require.NoError(t, err)
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash(passwordHash).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("bind-login-invalid-password-session-token").
SetIntent("adopt_existing_user_by_email").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-bind-invalid-123").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("bind-login-invalid-password-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"suggested_display_name": "Bound OIDC User",
"suggested_avatar_url": "https://cdn.example/bound.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")})
ginCtx.Request = req
handler.BindOIDCOAuthLogin(ginCtx)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
payload := decodeJSONBody(t, recorder)
require.Equal(t, "INVALID_CREDENTIALS", payload["reason"])
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example"),
authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, identityCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil)
}
func newOAuthPendingFlowTestHandlerWithEmailVerification(
t *testing.T,
invitationEnabled bool,
email string,
code string,
) (*AuthHandler, *dbent.Client) {
t.Helper()
cache := &oauthPendingFlowEmailCacheStub{
verificationCodes: map[string]*service.VerificationCodeData{
email: {
Code: code,
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
},
}
return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache)
}
func newOAuthPendingFlowTestHandlerWithOptions(
t *testing.T,
invitationEnabled bool,
emailVerifyEnabled bool,
emailCache service.EmailCache,
) (*AuthHandler, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
@@ -538,9 +834,18 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
},
}, cfg)
userRepo := &oauthPendingFlowUserRepo{client: client}
var emailService *service.EmailService
if emailCache != nil {
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
values: map[string]string{
service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
},
}, emailCache)
}
authSvc := service.NewAuthService(
client,
userRepo,
@@ -548,7 +853,7 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth
&oauthPendingFlowRefreshTokenCacheStub{},
cfg,
settingSvc,
nil,
emailService,
nil,
nil,
nil,
@@ -622,6 +927,70 @@ func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error
type oauthPendingFlowRefreshTokenCacheStub struct{}
type oauthPendingFlowEmailCacheStub struct {
verificationCodes map[string]*service.VerificationCodeData
}
func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) {
if s == nil || s.verificationCodes == nil {
return nil, nil
}
return s.verificationCodes[email], nil
}
func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error {
if s.verificationCodes == nil {
s.verificationCodes = map[string]*service.VerificationCodeData{}
}
s.verificationCodes[email] = data
return nil
}
func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error {
delete(s.verificationCodes, email)
return nil
}
func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
return nil, nil
}
func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
return nil
}
func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
return nil, nil
}
func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
return nil
}
func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
return nil
}
func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
return false
}
func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
return nil
}
func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}
func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}

View File

@@ -342,6 +342,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
idClaims.Name,
oidcFallbackUsername(subject),
)
identityRef := service.PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: issuer,
ProviderSubject: subject,
}
upstreamClaims := map[string]any{
"email": email,
"username": username,
"subject": subject,
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
}
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
if err != nil {
@@ -349,26 +364,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentBindCurrentUser,
Identity: service.PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: issuer,
ProviderSubject: subject,
},
TargetUserID: &targetUserID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
"email": email,
"username": username,
"subject": subject,
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
},
Intent: oauthIntentBindCurrentUser,
Identity: identityRef,
TargetUserID: &targetUserID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
},
@@ -380,30 +382,60 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
if err != nil {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityRef,
TargetUserID: &user.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
"redirect": redirectTo,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: "login",
Identity: service.PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: issuer,
ProviderSubject: subject,
},
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
"email": email,
"username": username,
"subject": subject,
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
},
Intent: "login",
Identity: identityRef,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"error": "invitation_required",
"redirect": redirectTo,
@@ -420,26 +452,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: "login",
Identity: service.PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: issuer,
ProviderSubject: subject,
},
TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
"email": email,
"username": username,
"subject": subject,
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
},
Intent: "login",
Identity: identityRef,
TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,

View File

@@ -214,6 +214,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
"suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
}
identityRef := service.PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: wechatOAuthProviderKey,
ProviderSubject: providerSubject,
}
normalizedIntent := normalizeWeChatOAuthIntent(intent)
if normalizedIntent == wechatOAuthIntentBind {
@@ -232,6 +237,34 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
return
}
existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
if err != nil {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err, nil); err != nil {

View File

@@ -167,6 +167,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`

View File

@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,

View File

@@ -0,0 +1,83 @@
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type settingHandlerPublicRepoStub struct {
values map[string]string
}
func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerPublicRepoStub{
values: map[string]string{
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version")
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
h.GetPublicSettings(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.True(t, resp.Data.ForceEmailOnThirdPartySignup)
}

View File

@@ -72,18 +72,54 @@ func RegisterAuthRoutes(
}),
h.Auth.ExchangePendingOAuthCompletion,
)
auth.POST("/oauth/pending/create-account",
rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreatePendingOAuthAccount,
)
auth.POST("/oauth/pending/bind-login",
rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindPendingOAuthLogin,
)
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
auth.POST("/oauth/linuxdo/bind-login",
rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindLinuxDoOAuthLogin,
)
auth.POST("/oauth/linuxdo/create-account",
rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreateLinuxDoOAuthAccount,
)
auth.POST("/oauth/wechat/complete-registration",
rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteWeChatOAuthRegistration,
)
auth.POST("/oauth/wechat/bind-login",
rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindWeChatOAuthLogin,
)
auth.POST("/oauth/wechat/create-account",
rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreateWeChatOAuthAccount,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
@@ -92,6 +128,18 @@ func RegisterAuthRoutes(
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
auth.POST("/oauth/oidc/bind-login",
rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindOIDCOAuthLogin,
)
auth.POST("/oauth/oidc/create-account",
rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreateOIDCOAuthAccount,
)
}
// 公开设置(无需认证)

View File

@@ -0,0 +1,151 @@
package service
import (
"context"
"errors"
"fmt"
"strings"
)
// VerifyOAuthEmailCode verifies the locally entered email verification code for
// third-party signup and binding flows. This is intentionally independent from
// the global registration email verification toggle.
func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error {
email = strings.TrimSpace(strings.ToLower(email))
verifyCode = strings.TrimSpace(verifyCode)
if email == "" {
return ErrEmailVerifyRequired
}
if verifyCode == "" {
return ErrEmailVerifyRequired
}
if s == nil || s.emailService == nil {
return ErrServiceUnavailable
}
return s.emailService.VerifyCode(ctx, email, verifyCode)
}
// RegisterOAuthEmailAccount creates a local account from a third-party first
// login after the user has verified a local email address.
func (s *AuthService) RegisterOAuthEmailAccount(
ctx context.Context,
email string,
password string,
verifyCode string,
invitationCode string,
signupSource string,
) (*TokenPair, *User, error) {
if s == nil {
return nil, nil, ErrServiceUnavailable
}
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, nil, ErrRegDisabled
}
email = strings.TrimSpace(strings.ToLower(email))
if isReservedEmail(email) {
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, nil, err
}
if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
return nil, nil, err
}
var invitationRedeemCode *RedeemCode
if s.settingService.IsInvitationCodeEnabled(ctx) {
if invitationCode == "" {
return nil, nil, ErrInvitationCodeRequired
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
return nil, nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return nil, nil, ErrServiceUnavailable
}
if existsEmail {
return nil, nil, ErrEmailExists
}
hashedPassword, err := s.HashPassword(password)
if err != nil {
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
signupSource = "email"
}
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, ErrEmailExists) {
return nil, nil, ErrEmailExists
}
return nil, nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, signupSource, true)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// ValidatePasswordCredentials checks the local password without completing the
// login flow. This is used by pending third-party account adoption flows before
// the external identity has been bound.
func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) {
if s == nil {
return nil, ErrServiceUnavailable
}
user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email)))
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return nil, ErrInvalidCredentials
}
return nil, ErrServiceUnavailable
}
if !user.IsActive() {
return nil, ErrUserNotActive
}
if !s.CheckPassword(password, user.PasswordHash) {
return nil, ErrInvalidCredentials
}
return user, nil
}
// RecordSuccessfulLogin updates last-login activity after a non-standard login
// flow finishes with a real session.
func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
s.touchUserLogin(ctx, userID)
}

View File

@@ -217,6 +217,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys := []string{
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
SettingKeyForceEmailOnThirdPartySignup,
SettingKeyRegistrationEmailSuffixWhitelist,
SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled,
@@ -294,6 +295,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: passwordResetEnabled,

View File

@@ -77,3 +77,16 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T)
require.Equal(t, 50, settings.TableDefaultPageSize)
require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions)
}
func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
repo := &settingPublicRepoStub{
values: map[string]string{
SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetPublicSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.ForceEmailOnThirdPartySignup)
}

View File

@@ -128,6 +128,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
ForceEmailOnThirdPartySignup bool
RegistrationEmailSuffixWhitelist []string
PromoCodeEnabled bool
PasswordResetEnabled bool