feat: apply oauth first-bind defaults and pending bind 2fa
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
@@ -269,6 +270,62 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if session.PendingOAuthBind != nil {
|
||||||
|
pendingSvc, err := h.pendingIdentityService()
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingSession, err := pendingSvc.GetBrowserSession(
|
||||||
|
c.Request.Context(),
|
||||||
|
session.PendingOAuthBind.PendingSessionToken,
|
||||||
|
session.PendingOAuthBind.BrowserSessionKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := applyPendingOAuthBinding(
|
||||||
|
c.Request.Context(),
|
||||||
|
h.entClient(),
|
||||||
|
h.authService,
|
||||||
|
pendingSession,
|
||||||
|
decision,
|
||||||
|
&user.ID,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
); err != nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := pendingSvc.ConsumeBrowserSession(
|
||||||
|
c.Request.Context(),
|
||||||
|
pendingSession.SessionToken,
|
||||||
|
pendingSession.BrowserSessionKey,
|
||||||
|
); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
secureCookie := isRequestHTTPS(c)
|
||||||
|
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||||
|
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||||
|
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||||
|
|
||||||
|
user, err = h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Delete the login session (only after all checks pass)
|
// Delete the login session (only after all checks pass)
|
||||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||||
|
|
||||||
|
|||||||
@@ -436,7 +436,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
|
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -601,10 +601,12 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
|
|||||||
func applyPendingOAuthBinding(
|
func applyPendingOAuthBinding(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *dbent.Client,
|
client *dbent.Client,
|
||||||
|
authService *service.AuthService,
|
||||||
session *dbent.PendingAuthSession,
|
session *dbent.PendingAuthSession,
|
||||||
decision *dbent.IdentityAdoptionDecision,
|
decision *dbent.IdentityAdoptionDecision,
|
||||||
overrideUserID *int64,
|
overrideUserID *int64,
|
||||||
forceBind bool,
|
forceBind bool,
|
||||||
|
applyFirstBindDefaults bool,
|
||||||
) error {
|
) error {
|
||||||
if client == nil || session == nil {
|
if client == nil || session == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -638,16 +640,17 @@ func applyPendingOAuthBinding(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
txCtx := dbent.NewTxContext(ctx, tx)
|
||||||
|
|
||||||
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
|
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||||
if err := tx.Client().User.UpdateOneID(targetUserID).
|
if err := tx.Client().User.UpdateOneID(targetUserID).
|
||||||
SetUsername(adoptedDisplayName).
|
SetUsername(adoptedDisplayName).
|
||||||
Exec(ctx); err != nil {
|
Exec(txCtx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
|
identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -667,14 +670,20 @@ func applyPendingOAuthBinding(
|
|||||||
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||||
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
|
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
|
||||||
}
|
}
|
||||||
if _, err := updateIdentity.Save(ctx); err != nil {
|
if _, err := updateIdentity.Save(txCtx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
|
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
|
||||||
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
|
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
|
||||||
SetIdentityID(identity.ID).
|
SetIdentityID(identity.ID).
|
||||||
Save(ctx); err != nil {
|
Save(txCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if applyFirstBindDefaults && authService != nil {
|
||||||
|
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -685,11 +694,21 @@ func applyPendingOAuthBinding(
|
|||||||
func applyPendingOAuthAdoption(
|
func applyPendingOAuthAdoption(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *dbent.Client,
|
client *dbent.Client,
|
||||||
|
authService *service.AuthService,
|
||||||
session *dbent.PendingAuthSession,
|
session *dbent.PendingAuthSession,
|
||||||
decision *dbent.IdentityAdoptionDecision,
|
decision *dbent.IdentityAdoptionDecision,
|
||||||
overrideUserID *int64,
|
overrideUserID *int64,
|
||||||
) error {
|
) error {
|
||||||
return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false)
|
return applyPendingOAuthBinding(
|
||||||
|
ctx,
|
||||||
|
client,
|
||||||
|
authService,
|
||||||
|
session,
|
||||||
|
decision,
|
||||||
|
overrideUserID,
|
||||||
|
false,
|
||||||
|
strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
|
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
|
||||||
@@ -804,7 +823,26 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil {
|
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||||
|
tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession(
|
||||||
|
c.Request.Context(),
|
||||||
|
user.ID,
|
||||||
|
user.Email,
|
||||||
|
session.SessionToken,
|
||||||
|
session.BrowserSessionKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to create 2FA session")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, TotpLoginResponse{
|
||||||
|
Requires2FA: true,
|
||||||
|
TempToken: tempToken,
|
||||||
|
UserEmailMasked: service.MaskEmail(user.Email),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID, true, true); err != nil {
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -900,7 +938,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil {
|
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, session, decision, &user.ID, true, false); err != nil {
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -990,7 +1028,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil {
|
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, session.TargetUserID); err != nil {
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pquerna/otp/totp"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"entgo.io/ent/dialect"
|
"entgo.io/ent/dialect"
|
||||||
@@ -773,6 +774,316 @@ func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *test
|
|||||||
require.Nil(t, storedSession.ConsumedAt)
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) {
|
||||||
|
defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
|
||||||
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
|
settingValues: map[string]string{
|
||||||
|
service.SettingKeyAuthSourceDefaultOIDCBalance: "12.5",
|
||||||
|
service.SettingKeyAuthSourceDefaultOIDCConcurrency: "3",
|
||||||
|
service.SettingKeyAuthSourceDefaultOIDCSubscriptions: `[{"group_id":101,"validity_days":30}]`,
|
||||||
|
service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
|
||||||
|
},
|
||||||
|
defaultSubAssigner: defaultSubAssigner,
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
passwordHash, err := handler.authService.HashPassword("secret-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
existingUser, err := client.User.Create().
|
||||||
|
SetEmail("owner@example.com").
|
||||||
|
SetUsername("owner-user").
|
||||||
|
SetPasswordHash(passwordHash).
|
||||||
|
SetBalance(5).
|
||||||
|
SetConcurrency(2).
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
firstSession, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("first-bind-session-token").
|
||||||
|
SetIntent("adopt_existing_user_by_email").
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example").
|
||||||
|
SetProviderSubject("oidc-bind-first-123").
|
||||||
|
SetTargetUserID(existingUser.ID).
|
||||||
|
SetResolvedEmail(existingUser.Email).
|
||||||
|
SetBrowserSessionKey("first-bind-browser-session-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"suggested_display_name": "Bound OIDC User",
|
||||||
|
"suggested_avatar_url": "https://cdn.example/bound.png",
|
||||||
|
}).
|
||||||
|
SetRedirectTo("/profile").
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
firstBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
|
||||||
|
firstRecorder := httptest.NewRecorder()
|
||||||
|
firstGinCtx, _ := gin.CreateTestContext(firstRecorder)
|
||||||
|
firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", firstBody)
|
||||||
|
firstReq.Header.Set("Content-Type", "application/json")
|
||||||
|
firstReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(firstSession.SessionToken)})
|
||||||
|
firstReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("first-bind-browser-session-key")})
|
||||||
|
firstGinCtx.Request = firstReq
|
||||||
|
|
||||||
|
handler.BindOIDCOAuthLogin(firstGinCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, firstRecorder.Code)
|
||||||
|
|
||||||
|
storedUser, err := client.User.Get(ctx, existingUser.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 17.5, storedUser.Balance)
|
||||||
|
require.Equal(t, 5, storedUser.Concurrency)
|
||||||
|
require.Zero(t, storedUser.TotalRecharged)
|
||||||
|
require.Len(t, defaultSubAssigner.calls, 1)
|
||||||
|
require.Equal(t, int64(existingUser.ID), defaultSubAssigner.calls[0].UserID)
|
||||||
|
require.Equal(t, int64(101), defaultSubAssigner.calls[0].GroupID)
|
||||||
|
require.Equal(t, 30, defaultSubAssigner.calls[0].ValidityDays)
|
||||||
|
require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
|
||||||
|
|
||||||
|
secondSession, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("second-bind-session-token").
|
||||||
|
SetIntent("adopt_existing_user_by_email").
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example").
|
||||||
|
SetProviderSubject("oidc-bind-second-456").
|
||||||
|
SetTargetUserID(existingUser.ID).
|
||||||
|
SetResolvedEmail(existingUser.Email).
|
||||||
|
SetBrowserSessionKey("second-bind-browser-session-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"suggested_display_name": "Second OIDC User",
|
||||||
|
"suggested_avatar_url": "https://cdn.example/second.png",
|
||||||
|
}).
|
||||||
|
SetRedirectTo("/profile").
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
secondBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
|
||||||
|
secondRecorder := httptest.NewRecorder()
|
||||||
|
secondGinCtx, _ := gin.CreateTestContext(secondRecorder)
|
||||||
|
secondReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", secondBody)
|
||||||
|
secondReq.Header.Set("Content-Type", "application/json")
|
||||||
|
secondReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(secondSession.SessionToken)})
|
||||||
|
secondReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("second-bind-browser-session-key")})
|
||||||
|
secondGinCtx.Request = secondReq
|
||||||
|
|
||||||
|
handler.BindOIDCOAuthLogin(secondGinCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, secondRecorder.Code)
|
||||||
|
|
||||||
|
storedUser, err = client.User.Get(ctx, existingUser.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 17.5, storedUser.Balance)
|
||||||
|
require.Equal(t, 5, storedUser.Concurrency)
|
||||||
|
require.Zero(t, storedUser.TotalRecharged)
|
||||||
|
require.Len(t, defaultSubAssigner.calls, 1)
|
||||||
|
require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) {
|
||||||
|
totpCache := &oauthPendingFlowTotpCacheStub{}
|
||||||
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
|
settingValues: map[string]string{
|
||||||
|
service.SettingKeyTotpEnabled: "true",
|
||||||
|
},
|
||||||
|
totpCache: totpCache,
|
||||||
|
totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
passwordHash, err := handler.authService.HashPassword("secret-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
totpEnabledAt := time.Now().UTC().Add(-time.Hour)
|
||||||
|
secret := "JBSWY3DPEHPK3PXP"
|
||||||
|
|
||||||
|
existingUser, err := client.User.Create().
|
||||||
|
SetEmail("owner@example.com").
|
||||||
|
SetUsername("owner-user").
|
||||||
|
SetPasswordHash(passwordHash).
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
SetTotpEnabled(true).
|
||||||
|
SetTotpSecretEncrypted(secret).
|
||||||
|
SetTotpEnabledAt(totpEnabledAt).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("bind-login-2fa-session-token").
|
||||||
|
SetIntent("adopt_existing_user_by_email").
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example").
|
||||||
|
SetProviderSubject("oidc-bind-2fa-123").
|
||||||
|
SetTargetUserID(existingUser.ID).
|
||||||
|
SetResolvedEmail(existingUser.Email).
|
||||||
|
SetBrowserSessionKey("bind-login-2fa-browser-session-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"suggested_display_name": "Bound OIDC User",
|
||||||
|
"suggested_avatar_url": "https://cdn.example/bound.png",
|
||||||
|
}).
|
||||||
|
SetRedirectTo("/profile").
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-2fa-browser-session-key")})
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
handler.BindOIDCOAuthLogin(ginCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
data := decodeJSONResponseData(t, recorder)
|
||||||
|
require.Equal(t, true, data["requires_2fa"])
|
||||||
|
require.Equal(t, "o***r@example.com", data["user_email_masked"])
|
||||||
|
tempToken, ok := data["temp_token"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, tempToken)
|
||||||
|
|
||||||
|
loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, loginSession)
|
||||||
|
require.NotNil(t, loginSession.PendingOAuthBind)
|
||||||
|
require.Equal(t, session.SessionToken, loginSession.PendingOAuthBind.PendingSessionToken)
|
||||||
|
require.Equal(t, session.BrowserSessionKey, loginSession.PendingOAuthBind.BrowserSessionKey)
|
||||||
|
|
||||||
|
identityCount, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("oidc"),
|
||||||
|
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||||
|
authidentity.ProviderSubjectEQ("oidc-bind-2fa-123"),
|
||||||
|
).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, identityCount)
|
||||||
|
|
||||||
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
|
||||||
|
totpCache := &oauthPendingFlowTotpCacheStub{}
|
||||||
|
defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
|
||||||
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
|
settingValues: map[string]string{
|
||||||
|
service.SettingKeyTotpEnabled: "true",
|
||||||
|
service.SettingKeyAuthSourceDefaultOIDCBalance: "8",
|
||||||
|
service.SettingKeyAuthSourceDefaultOIDCConcurrency: "2",
|
||||||
|
service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
|
||||||
|
},
|
||||||
|
defaultSubAssigner: defaultSubAssigner,
|
||||||
|
totpCache: totpCache,
|
||||||
|
totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
passwordHash, err := handler.authService.HashPassword("secret-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
totpEnabledAt := time.Now().UTC().Add(-time.Hour)
|
||||||
|
secret := "JBSWY3DPEHPK3PXP"
|
||||||
|
|
||||||
|
existingUser, err := client.User.Create().
|
||||||
|
SetEmail("owner@example.com").
|
||||||
|
SetUsername("owner-user").
|
||||||
|
SetPasswordHash(passwordHash).
|
||||||
|
SetBalance(1.5).
|
||||||
|
SetConcurrency(4).
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
SetTotpEnabled(true).
|
||||||
|
SetTotpSecretEncrypted(secret).
|
||||||
|
SetTotpEnabledAt(totpEnabledAt).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("login-2fa-pending-session-token").
|
||||||
|
SetIntent("adopt_existing_user_by_email").
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example").
|
||||||
|
SetProviderSubject("oidc-login-2fa-123").
|
||||||
|
SetTargetUserID(existingUser.ID).
|
||||||
|
SetResolvedEmail(existingUser.Email).
|
||||||
|
SetBrowserSessionKey("login-2fa-browser-session-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"suggested_display_name": "Bound OIDC User",
|
||||||
|
"suggested_avatar_url": "https://cdn.example/bound.png",
|
||||||
|
}).
|
||||||
|
SetRedirectTo("/profile").
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = client.IdentityAdoptionDecision.Create().
|
||||||
|
SetPendingAuthSessionID(session.ID).
|
||||||
|
SetAdoptDisplayName(false).
|
||||||
|
SetAdoptAvatar(false).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tempToken, err := handler.totpService.CreatePendingOAuthBindLoginSession(
|
||||||
|
ctx,
|
||||||
|
existingUser.ID,
|
||||||
|
existingUser.Email,
|
||||||
|
session.SessionToken,
|
||||||
|
session.BrowserSessionKey,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
code, err := totp.GenerateCode(secret, time.Now().UTC())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"temp_token":"` + tempToken + `","totp_code":"` + code + `"}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/2fa", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue(session.BrowserSessionKey)})
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
handler.Login2FA(ginCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
payload := decodeJSONResponseData(t, recorder)
|
||||||
|
require.NotEmpty(t, payload["access_token"])
|
||||||
|
require.NotEmpty(t, payload["refresh_token"])
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("oidc"),
|
||||||
|
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||||
|
authidentity.ProviderSubjectEQ("oidc-login-2fa-123"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, existingUser.ID, identity.UserID)
|
||||||
|
|
||||||
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, storedSession.ConsumedAt)
|
||||||
|
|
||||||
|
loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, loginSession)
|
||||||
|
|
||||||
|
storedUser, err := client.User.Get(ctx, existingUser.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 9.5, storedUser.Balance)
|
||||||
|
require.Equal(t, 6, storedUser.Concurrency)
|
||||||
|
require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
|
||||||
|
require.Empty(t, defaultSubAssigner.calls)
|
||||||
|
}
|
||||||
|
|
||||||
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
|
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -805,6 +1116,27 @@ func newOAuthPendingFlowTestHandlerWithOptions(
|
|||||||
invitationEnabled bool,
|
invitationEnabled bool,
|
||||||
emailVerifyEnabled bool,
|
emailVerifyEnabled bool,
|
||||||
emailCache service.EmailCache,
|
emailCache service.EmailCache,
|
||||||
|
) (*AuthHandler, *dbent.Client) {
|
||||||
|
return newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
|
invitationEnabled: invitationEnabled,
|
||||||
|
emailVerifyEnabled: emailVerifyEnabled,
|
||||||
|
emailCache: emailCache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type oauthPendingFlowTestHandlerOptions struct {
|
||||||
|
invitationEnabled bool
|
||||||
|
emailVerifyEnabled bool
|
||||||
|
emailCache service.EmailCache
|
||||||
|
settingValues map[string]string
|
||||||
|
defaultSubAssigner service.DefaultSubscriptionAssigner
|
||||||
|
totpCache service.TotpCache
|
||||||
|
totpEncryptor service.SecretEncryptor
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOAuthPendingFlowTestHandlerWithDependencies(
|
||||||
|
t *testing.T,
|
||||||
|
options oauthPendingFlowTestHandlerOptions,
|
||||||
) (*AuthHandler, *dbent.Client) {
|
) (*AuthHandler, *dbent.Client) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -814,6 +1146,16 @@ func newOAuthPendingFlowTestHandlerWithOptions(
|
|||||||
|
|
||||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
_, err = db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
provider_type TEXT NOT NULL,
|
||||||
|
grant_reason TEXT NOT NULL DEFAULT 'first_bind',
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
UNIQUE(user_id, provider_type, grant_reason)
|
||||||
|
)`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||||
@@ -830,21 +1172,23 @@ func newOAuthPendingFlowTestHandlerWithOptions(
|
|||||||
UserConcurrency: 1,
|
UserConcurrency: 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{
|
settingValues := map[string]string{
|
||||||
values: map[string]string{
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
service.SettingKeyRegistrationEnabled: "true",
|
service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
|
||||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
|
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
|
||||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
|
}
|
||||||
},
|
for key, value := range options.settingValues {
|
||||||
}, cfg)
|
settingValues[key] = value
|
||||||
|
}
|
||||||
|
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
|
||||||
userRepo := &oauthPendingFlowUserRepo{client: client}
|
userRepo := &oauthPendingFlowUserRepo{client: client}
|
||||||
var emailService *service.EmailService
|
var emailService *service.EmailService
|
||||||
if emailCache != nil {
|
if options.emailCache != nil {
|
||||||
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
|
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
|
||||||
values: map[string]string{
|
values: map[string]string{
|
||||||
service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
|
service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
|
||||||
},
|
},
|
||||||
}, emailCache)
|
}, options.emailCache)
|
||||||
}
|
}
|
||||||
authSvc := service.NewAuthService(
|
authSvc := service.NewAuthService(
|
||||||
client,
|
client,
|
||||||
@@ -857,14 +1201,27 @@ func newOAuthPendingFlowTestHandlerWithOptions(
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
options.defaultSubAssigner,
|
||||||
)
|
)
|
||||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||||
|
var totpSvc *service.TotpService
|
||||||
|
if options.totpCache != nil || options.totpEncryptor != nil {
|
||||||
|
totpCache := options.totpCache
|
||||||
|
if totpCache == nil {
|
||||||
|
totpCache = &oauthPendingFlowTotpCacheStub{}
|
||||||
|
}
|
||||||
|
totpEncryptor := options.totpEncryptor
|
||||||
|
if totpEncryptor == nil {
|
||||||
|
totpEncryptor = oauthPendingFlowTotpEncryptorStub{}
|
||||||
|
}
|
||||||
|
totpSvc = service.NewTotpService(userRepo, totpEncryptor, totpCache, settingSvc, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
return &AuthHandler{
|
return &AuthHandler{
|
||||||
authService: authSvc,
|
authService: authSvc,
|
||||||
userService: userSvc,
|
userService: userSvc,
|
||||||
settingSvc: settingSvc,
|
settingSvc: settingSvc,
|
||||||
|
totpService: totpSvc,
|
||||||
}, client
|
}, client
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1049,6 +1406,32 @@ func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[strin
|
|||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func countProviderGrantRecords(
|
||||||
|
t *testing.T,
|
||||||
|
client *dbent.Client,
|
||||||
|
userID int64,
|
||||||
|
providerType string,
|
||||||
|
grantReason string,
|
||||||
|
) int {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var rows entsql.Rows
|
||||||
|
err := client.Driver().Query(
|
||||||
|
context.Background(),
|
||||||
|
`SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
|
||||||
|
[]any{userID, providerType, grantReason},
|
||||||
|
&rows,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
var count int
|
||||||
|
require.NoError(t, rows.Scan(&count))
|
||||||
|
require.False(t, rows.Next())
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
type oauthPendingFlowUserRepo struct {
|
type oauthPendingFlowUserRepo struct {
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
}
|
}
|
||||||
@@ -1063,6 +1446,10 @@ func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.Use
|
|||||||
SetBalance(user.Balance).
|
SetBalance(user.Balance).
|
||||||
SetConcurrency(user.Concurrency).
|
SetConcurrency(user.Concurrency).
|
||||||
SetStatus(user.Status).
|
SetStatus(user.Status).
|
||||||
|
SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted).
|
||||||
|
SetTotpEnabled(user.TotpEnabled).
|
||||||
|
SetNillableTotpEnabledAt(user.TotpEnabledAt).
|
||||||
|
SetTotalRecharged(user.TotalRecharged).
|
||||||
SetSignupSource(user.SignupSource).
|
SetSignupSource(user.SignupSource).
|
||||||
SetNillableLastLoginAt(user.LastLoginAt).
|
SetNillableLastLoginAt(user.LastLoginAt).
|
||||||
SetNillableLastActiveAt(user.LastActiveAt).
|
SetNillableLastActiveAt(user.LastActiveAt).
|
||||||
@@ -1112,6 +1499,10 @@ func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.Use
|
|||||||
SetBalance(user.Balance).
|
SetBalance(user.Balance).
|
||||||
SetConcurrency(user.Concurrency).
|
SetConcurrency(user.Concurrency).
|
||||||
SetStatus(user.Status).
|
SetStatus(user.Status).
|
||||||
|
SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted).
|
||||||
|
SetTotpEnabled(user.TotpEnabled).
|
||||||
|
SetNillableTotpEnabledAt(user.TotpEnabledAt).
|
||||||
|
SetTotalRecharged(user.TotalRecharged).
|
||||||
SetSignupSource(user.SignupSource).
|
SetSignupSource(user.SignupSource).
|
||||||
SetNillableLastLoginAt(user.LastLoginAt).
|
SetNillableLastLoginAt(user.LastLoginAt).
|
||||||
SetNillableLastActiveAt(user.LastActiveAt).
|
SetNillableLastActiveAt(user.LastActiveAt).
|
||||||
@@ -1203,16 +1594,29 @@ func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, u
|
|||||||
return records, nil
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error {
|
func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||||
panic("unexpected UpdateTotpSecret call")
|
update := r.client.User.UpdateOneID(userID)
|
||||||
|
if encryptedSecret == nil {
|
||||||
|
update = update.ClearTotpSecretEncrypted()
|
||||||
|
} else {
|
||||||
|
update = update.SetTotpSecretEncrypted(*encryptedSecret)
|
||||||
|
}
|
||||||
|
return update.Exec(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error {
|
func (r *oauthPendingFlowUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||||||
panic("unexpected EnableTotp call")
|
return r.client.User.UpdateOneID(userID).
|
||||||
|
SetTotpEnabled(true).
|
||||||
|
SetTotpEnabledAt(time.Now().UTC()).
|
||||||
|
Exec(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error {
|
func (r *oauthPendingFlowUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||||||
panic("unexpected DisableTotp call")
|
return r.client.User.UpdateOneID(userID).
|
||||||
|
SetTotpEnabled(false).
|
||||||
|
ClearTotpSecretEncrypted().
|
||||||
|
ClearTotpEnabledAt().
|
||||||
|
Exec(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
|
func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
|
||||||
@@ -1220,19 +1624,113 @@ func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &service.User{
|
return &service.User{
|
||||||
ID: entity.ID,
|
ID: entity.ID,
|
||||||
Email: entity.Email,
|
Email: entity.Email,
|
||||||
Username: entity.Username,
|
Username: entity.Username,
|
||||||
Notes: entity.Notes,
|
Notes: entity.Notes,
|
||||||
PasswordHash: entity.PasswordHash,
|
PasswordHash: entity.PasswordHash,
|
||||||
Role: entity.Role,
|
Role: entity.Role,
|
||||||
Balance: entity.Balance,
|
Balance: entity.Balance,
|
||||||
Concurrency: entity.Concurrency,
|
Concurrency: entity.Concurrency,
|
||||||
Status: entity.Status,
|
Status: entity.Status,
|
||||||
SignupSource: entity.SignupSource,
|
SignupSource: entity.SignupSource,
|
||||||
LastLoginAt: entity.LastLoginAt,
|
LastLoginAt: entity.LastLoginAt,
|
||||||
LastActiveAt: entity.LastActiveAt,
|
LastActiveAt: entity.LastActiveAt,
|
||||||
CreatedAt: entity.CreatedAt,
|
TotpSecretEncrypted: entity.TotpSecretEncrypted,
|
||||||
UpdatedAt: entity.UpdatedAt,
|
TotpEnabled: entity.TotpEnabled,
|
||||||
|
TotpEnabledAt: entity.TotpEnabledAt,
|
||||||
|
TotalRecharged: entity.TotalRecharged,
|
||||||
|
CreatedAt: entity.CreatedAt,
|
||||||
|
UpdatedAt: entity.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type oauthPendingFlowDefaultSubAssignerStub struct {
|
||||||
|
calls []service.AssignSubscriptionInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||||
|
_ context.Context,
|
||||||
|
input *service.AssignSubscriptionInput,
|
||||||
|
) (*service.UserSubscription, bool, error) {
|
||||||
|
if input != nil {
|
||||||
|
s.calls = append(s.calls, *input)
|
||||||
|
}
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type oauthPendingFlowTotpCacheStub struct {
|
||||||
|
setupSessions map[int64]*service.TotpSetupSession
|
||||||
|
loginSessions map[string]*service.TotpLoginSession
|
||||||
|
verifyAttempts map[int64]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) GetSetupSession(_ context.Context, userID int64) (*service.TotpSetupSession, error) {
|
||||||
|
if s == nil || s.setupSessions == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return s.setupSessions[userID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) SetSetupSession(_ context.Context, userID int64, session *service.TotpSetupSession, _ time.Duration) error {
|
||||||
|
if s.setupSessions == nil {
|
||||||
|
s.setupSessions = map[int64]*service.TotpSetupSession{}
|
||||||
|
}
|
||||||
|
s.setupSessions[userID] = session
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) DeleteSetupSession(_ context.Context, userID int64) error {
|
||||||
|
delete(s.setupSessions, userID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) GetLoginSession(_ context.Context, tempToken string) (*service.TotpLoginSession, error) {
|
||||||
|
if s == nil || s.loginSessions == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return s.loginSessions[tempToken], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) SetLoginSession(_ context.Context, tempToken string, session *service.TotpLoginSession, _ time.Duration) error {
|
||||||
|
if s.loginSessions == nil {
|
||||||
|
s.loginSessions = map[string]*service.TotpLoginSession{}
|
||||||
|
}
|
||||||
|
s.loginSessions[tempToken] = session
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) DeleteLoginSession(_ context.Context, tempToken string) error {
|
||||||
|
delete(s.loginSessions, tempToken)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) IncrementVerifyAttempts(_ context.Context, userID int64) (int, error) {
|
||||||
|
if s.verifyAttempts == nil {
|
||||||
|
s.verifyAttempts = map[int64]int{}
|
||||||
|
}
|
||||||
|
s.verifyAttempts[userID]++
|
||||||
|
return s.verifyAttempts[userID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) GetVerifyAttempts(_ context.Context, userID int64) (int, error) {
|
||||||
|
if s == nil || s.verifyAttempts == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return s.verifyAttempts[userID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthPendingFlowTotpCacheStub) ClearVerifyAttempts(_ context.Context, userID int64) error {
|
||||||
|
delete(s.verifyAttempts, userID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type oauthPendingFlowTotpEncryptorStub struct{}
|
||||||
|
|
||||||
|
func (oauthPendingFlowTotpEncryptorStub) Encrypt(plaintext string) (string, error) {
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oauthPendingFlowTotpEncryptorStub) Decrypt(ciphertext string) (string, error) {
|
||||||
|
return ciphertext, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
|
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
|
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
106
backend/internal/service/auth_oauth_first_bind.go
Normal file
106
backend/internal/service/auth_oauth_first_bind.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
|
||||||
|
entsql "entgo.io/ent/dialect/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap
|
||||||
|
// settings the first time a user binds a third-party identity. The grant is
|
||||||
|
// idempotent per user/provider pair.
|
||||||
|
func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind(
|
||||||
|
ctx context.Context,
|
||||||
|
userID int64,
|
||||||
|
providerType string,
|
||||||
|
) error {
|
||||||
|
if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbent.TxFromContext(ctx) != nil {
|
||||||
|
return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.entClient.Tx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("begin first bind defaults transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
txCtx := dbent.NewTxContext(ctx, tx)
|
||||||
|
if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
|
||||||
|
ctx context.Context,
|
||||||
|
userID int64,
|
||||||
|
providerType string,
|
||||||
|
) error {
|
||||||
|
defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load auth source defaults: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
providerDefaults, ok := authSourceSignupSettings(defaults, providerType)
|
||||||
|
if !ok || !providerDefaults.GrantOnFirstBind {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := s.entClient
|
||||||
|
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||||
|
client = tx.Client()
|
||||||
|
}
|
||||||
|
|
||||||
|
var result entsql.Result
|
||||||
|
if err := client.Driver().Exec(
|
||||||
|
ctx,
|
||||||
|
`INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
|
||||||
|
[]any{userID, strings.TrimSpace(providerType), "first_bind"},
|
||||||
|
&result,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("record first bind provider grant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read first bind provider grant result: %w", err)
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if providerDefaults.Balance != 0 {
|
||||||
|
if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil {
|
||||||
|
return fmt.Errorf("apply first bind balance default: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if providerDefaults.Concurrency != 0 {
|
||||||
|
if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil {
|
||||||
|
return fmt.Errorf("apply first bind concurrency default: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.defaultSubAssigner != nil {
|
||||||
|
for _, item := range providerDefaults.Subscriptions {
|
||||||
|
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
|
||||||
|
UserID: userID,
|
||||||
|
GroupID: item.GroupID,
|
||||||
|
ValidityDays: item.ValidityDays,
|
||||||
|
Notes: "auto assigned by first bind defaults",
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("apply first bind subscription default: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -58,9 +58,15 @@ type TotpSetupSession struct {
|
|||||||
|
|
||||||
// TotpLoginSession represents a pending 2FA login session
|
// TotpLoginSession represents a pending 2FA login session
|
||||||
type TotpLoginSession struct {
|
type TotpLoginSession struct {
|
||||||
UserID int64
|
UserID int64
|
||||||
Email string
|
Email string
|
||||||
TokenExpiry time.Time
|
TokenExpiry time.Time
|
||||||
|
PendingOAuthBind *PendingOAuthBindLoginSession `json:"pending_oauth_bind,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PendingOAuthBindLoginSession struct {
|
||||||
|
PendingSessionToken string `json:"pending_session_token,omitempty"`
|
||||||
|
BrowserSessionKey string `json:"browser_session_key,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TotpStatus represents the TOTP status for a user
|
// TotpStatus represents the TOTP status for a user
|
||||||
@@ -397,6 +403,30 @@ func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string)
|
|||||||
|
|
||||||
// CreateLoginSession creates a temporary login session for 2FA
|
// CreateLoginSession creates a temporary login session for 2FA
|
||||||
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
|
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
|
||||||
|
return s.createLoginSession(ctx, userID, email, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePendingOAuthBindLoginSession creates a temporary 2FA session that will
|
||||||
|
// finalize a pending OAuth bind after the TOTP code is verified.
|
||||||
|
func (s *TotpService) CreatePendingOAuthBindLoginSession(
|
||||||
|
ctx context.Context,
|
||||||
|
userID int64,
|
||||||
|
email string,
|
||||||
|
pendingSessionToken string,
|
||||||
|
browserSessionKey string,
|
||||||
|
) (string, error) {
|
||||||
|
return s.createLoginSession(ctx, userID, email, &PendingOAuthBindLoginSession{
|
||||||
|
PendingSessionToken: pendingSessionToken,
|
||||||
|
BrowserSessionKey: browserSessionKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TotpService) createLoginSession(
|
||||||
|
ctx context.Context,
|
||||||
|
userID int64,
|
||||||
|
email string,
|
||||||
|
pendingOAuthBind *PendingOAuthBindLoginSession,
|
||||||
|
) (string, error) {
|
||||||
// Generate a random temp token
|
// Generate a random temp token
|
||||||
tempToken, err := generateRandomToken(32)
|
tempToken, err := generateRandomToken(32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -404,9 +434,10 @@ func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, emai
|
|||||||
}
|
}
|
||||||
|
|
||||||
session := &TotpLoginSession{
|
session := &TotpLoginSession{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Email: email,
|
Email: email,
|
||||||
TokenExpiry: time.Now().Add(totpLoginTTL),
|
TokenExpiry: time.Now().Add(totpLoginTTL),
|
||||||
|
PendingOAuthBind: pendingOAuthBind,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
|
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user