feat: complete email binding and pending oauth verification flows
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -61,6 +63,18 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
|
||||
require.Equal(t, true, payload["adoption_required"])
|
||||
}
|
||||
|
||||
func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
|
||||
|
||||
setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false)
|
||||
|
||||
cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
||||
require.NotNil(t, cookie)
|
||||
require.Equal(t, "/api/v1/auth/oauth", cookie.Path)
|
||||
}
|
||||
|
||||
func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
@@ -943,6 +957,81 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
|
||||
ctx := context.Background()
|
||||
|
||||
conflictOwner, err := client.User.Create().
|
||||
SetEmail("owner@example.com").
|
||||
SetUsername("owner-user").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.AuthIdentity.Create().
|
||||
SetUserID(conflictOwner.ID).
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-conflict-123").
|
||||
SetMetadata(map[string]any{
|
||||
"username": "owner-user",
|
||||
}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
invitation, err := client.RedeemCode.Create().
|
||||
SetCode("INVITE123").
|
||||
SetType(service.RedeemTypeInvitation).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("create-account-conflict-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-conflict-123").
|
||||
SetBrowserSessionKey("create-account-conflict-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
}).
|
||||
SetRedirectTo("/profile").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.CreateOIDCOAuthAccount(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.StatusUnused, storedInvitation.Status)
|
||||
require.Nil(t, storedInvitation.UsedBy)
|
||||
require.Nil(t, storedInvitation.UsedAt)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
@@ -1529,6 +1618,8 @@ type oauthPendingFlowTestHandlerOptions struct {
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner
|
||||
totpCache service.TotpCache
|
||||
totpEncryptor service.SecretEncryptor
|
||||
redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository
|
||||
userRepoOptions oauthPendingFlowUserRepoOptions
|
||||
}
|
||||
|
||||
func newOAuthPendingFlowTestHandlerWithDependencies(
|
||||
@@ -1590,7 +1681,17 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
settingValues[key] = value
|
||||
}
|
||||
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
|
||||
userRepo := &oauthPendingFlowUserRepo{client: client}
|
||||
userRepo := &oauthPendingFlowUserRepo{
|
||||
client: client,
|
||||
options: options.userRepoOptions,
|
||||
}
|
||||
redeemRepo := service.RedeemCodeRepository(nil)
|
||||
if options.redeemRepoFactory != nil {
|
||||
redeemRepo = options.redeemRepoFactory(client)
|
||||
}
|
||||
if redeemRepo == nil {
|
||||
redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client}
|
||||
}
|
||||
var emailService *service.EmailService
|
||||
if options.emailCache != nil {
|
||||
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
|
||||
@@ -1602,7 +1703,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
authSvc := service.NewAuthService(
|
||||
client,
|
||||
userRepo,
|
||||
nil,
|
||||
redeemRepo,
|
||||
&oauthPendingFlowRefreshTokenCacheStub{},
|
||||
cfg,
|
||||
settingSvc,
|
||||
@@ -1797,6 +1898,127 @@ func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context,
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type oauthPendingFlowRedeemCodeRepo struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error {
|
||||
panic("unexpected CreateBatch call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
|
||||
entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
notes := ""
|
||||
if entity.Notes != nil {
|
||||
notes = *entity.Notes
|
||||
}
|
||||
return &service.RedeemCode{
|
||||
ID: entity.ID,
|
||||
Code: entity.Code,
|
||||
Type: entity.Type,
|
||||
Value: entity.Value,
|
||||
Status: entity.Status,
|
||||
UsedBy: entity.UsedBy,
|
||||
UsedAt: entity.UsedAt,
|
||||
Notes: notes,
|
||||
CreatedAt: entity.CreatedAt,
|
||||
GroupID: entity.GroupID,
|
||||
ValidityDays: entity.ValidityDays,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
|
||||
if code == nil {
|
||||
return nil
|
||||
}
|
||||
update := r.client.RedeemCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetType(code.Type).
|
||||
SetValue(code.Value).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays)
|
||||
if code.UsedBy != nil {
|
||||
update = update.SetUsedBy(*code.UsedBy)
|
||||
} else {
|
||||
update = update.ClearUsedBy()
|
||||
}
|
||||
if code.UsedAt != nil {
|
||||
update = update.SetUsedAt(*code.UsedAt)
|
||||
} else {
|
||||
update = update.ClearUsedAt()
|
||||
}
|
||||
if code.GroupID != nil {
|
||||
update = update.SetGroupID(*code.GroupID)
|
||||
} else {
|
||||
update = update.ClearGroupID()
|
||||
}
|
||||
_, err := update.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
|
||||
affected, err := r.client.RedeemCode.Update().
|
||||
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetUsedBy(userID).
|
||||
SetUsedAt(time.Now().UTC()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrRedeemCodeUsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) {
|
||||
panic("unexpected ListByUser call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserPaginated call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
|
||||
panic("unexpected SumPositiveBalanceByUser call")
|
||||
}
|
||||
|
||||
type oauthPendingFlowFailingUseRedeemRepo struct {
|
||||
*oauthPendingFlowRedeemCodeRepo
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error {
|
||||
return errors.New("forced invitation use failure")
|
||||
}
|
||||
|
||||
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
@@ -1872,6 +2094,11 @@ func countProviderGrantRecords(
|
||||
|
||||
type oauthPendingFlowUserRepo struct {
|
||||
client *dbent.Client
|
||||
options oauthPendingFlowUserRepoOptions
|
||||
}
|
||||
|
||||
type oauthPendingFlowUserRepoOptions struct {
|
||||
rejectDeleteWhileAuthIdentityExists bool
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
|
||||
@@ -1953,6 +2180,15 @@ func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.Use
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
|
||||
if r.options.rejectDeleteWhileAuthIdentityExists {
|
||||
count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return errors.New("cannot delete user while auth identities still exist")
|
||||
}
|
||||
}
|
||||
return r.client.User.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user