feat: complete email binding and pending oauth verification flows
This commit is contained in:
@@ -79,7 +79,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
totpCache := repository.NewTotpCache(redisClient)
|
||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
||||
userHandler := handler.NewUserHandler(userService, emailService, emailCache)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -27,7 +28,7 @@ import (
|
||||
const (
|
||||
oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
|
||||
oauthPendingBrowserCookieName = "oauth_pending_browser_session"
|
||||
oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending"
|
||||
oauthPendingSessionCookiePath = "/api/v1/auth/oauth"
|
||||
oauthPendingSessionCookieName = "oauth_pending_session"
|
||||
oauthPendingCookieMaxAgeSec = 10 * 60
|
||||
|
||||
@@ -66,6 +67,13 @@ type createPendingOAuthAccountRequest struct {
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
|
||||
type sendPendingOAuthVerifyCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token,omitempty"`
|
||||
PendingAuthToken string `json:"pending_auth_token,omitempty"`
|
||||
PendingOAuthToken string `json:"pending_oauth_token,omitempty"`
|
||||
}
|
||||
|
||||
func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
|
||||
return oauthAdoptionDecisionRequest{
|
||||
AdoptDisplayName: r.AdoptDisplayName,
|
||||
@@ -448,6 +456,43 @@ func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
|
||||
h.createPendingOAuthAccount(c, "")
|
||||
}
|
||||
|
||||
// SendPendingOAuthVerifyCode sends a verification code for a browser-bound
|
||||
// pending OAuth account-creation flow.
|
||||
// POST /api/v1/auth/oauth/pending/send-verify-code
|
||||
func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
|
||||
var req sendPendingOAuthVerifyCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, session, _, err := readPendingOAuthBrowserSession(c, h)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, SendVerifyCodeResponse{
|
||||
Message: "Verification code sent successfully",
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
|
||||
c *gin.Context,
|
||||
sessionID int64,
|
||||
@@ -1084,6 +1129,41 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
|
||||
return payload
|
||||
}
|
||||
|
||||
func (h *AuthHandler) transitionPendingOAuthAccountToBindLogin(
|
||||
c *gin.Context,
|
||||
client *dbent.Client,
|
||||
session *dbent.PendingAuthSession,
|
||||
email string,
|
||||
decision oauthAdoptionDecisionRequest,
|
||||
) (*dbent.PendingAuthSession, error) {
|
||||
existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
completionResponse := mergePendingCompletionResponse(session, map[string]any{
|
||||
"step": "bind_login_required",
|
||||
"email": email,
|
||||
})
|
||||
session, err = updatePendingOAuthSessionProgress(
|
||||
c.Request.Context(),
|
||||
client,
|
||||
session,
|
||||
"adopt_existing_user_by_email",
|
||||
email,
|
||||
&existingUser.ID,
|
||||
completionResponse,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
|
||||
}
|
||||
|
||||
if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decision); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
@@ -1199,29 +1279,11 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
return
|
||||
}
|
||||
if existingUser != nil {
|
||||
completionResponse := mergePendingCompletionResponse(session, map[string]any{
|
||||
"step": "bind_login_required",
|
||||
"email": email,
|
||||
})
|
||||
session, err = updatePendingOAuthSessionProgress(
|
||||
c.Request.Context(),
|
||||
client,
|
||||
session,
|
||||
"adopt_existing_user_by_email",
|
||||
email,
|
||||
&existingUser.ID,
|
||||
completionResponse,
|
||||
)
|
||||
session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, req.adoptionDecision())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
|
||||
return
|
||||
}
|
||||
@@ -1239,27 +1301,77 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrEmailExists) {
|
||||
session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, req.adoptionDecision())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
rollbackCreatedUser := func(originalErr error) bool {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return false
|
||||
}
|
||||
if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation(
|
||||
c.Request.Context(),
|
||||
user.ID,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
); rollbackErr != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer(
|
||||
"PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED",
|
||||
"failed to rollback pending oauth account creation",
|
||||
).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr)))
|
||||
return true
|
||||
}
|
||||
user = nil
|
||||
return false
|
||||
}
|
||||
|
||||
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
|
||||
if err != nil {
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
return
|
||||
}
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
|
||||
if err := h.authService.FinalizeOAuthEmailAccount(
|
||||
c.Request.Context(),
|
||||
user,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
); err != nil {
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
clearCookies()
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -61,6 +63,18 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
|
||||
require.Equal(t, true, payload["adoption_required"])
|
||||
}
|
||||
|
||||
func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
|
||||
|
||||
setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false)
|
||||
|
||||
cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
||||
require.NotNil(t, cookie)
|
||||
require.Equal(t, "/api/v1/auth/oauth", cookie.Path)
|
||||
}
|
||||
|
||||
func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
@@ -943,6 +957,81 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
|
||||
ctx := context.Background()
|
||||
|
||||
conflictOwner, err := client.User.Create().
|
||||
SetEmail("owner@example.com").
|
||||
SetUsername("owner-user").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.AuthIdentity.Create().
|
||||
SetUserID(conflictOwner.ID).
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-conflict-123").
|
||||
SetMetadata(map[string]any{
|
||||
"username": "owner-user",
|
||||
}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
invitation, err := client.RedeemCode.Create().
|
||||
SetCode("INVITE123").
|
||||
SetType(service.RedeemTypeInvitation).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("create-account-conflict-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-conflict-123").
|
||||
SetBrowserSessionKey("create-account-conflict-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
}).
|
||||
SetRedirectTo("/profile").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.CreateOIDCOAuthAccount(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.StatusUnused, storedInvitation.Status)
|
||||
require.Nil(t, storedInvitation.UsedBy)
|
||||
require.Nil(t, storedInvitation.UsedAt)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
@@ -1529,6 +1618,8 @@ type oauthPendingFlowTestHandlerOptions struct {
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner
|
||||
totpCache service.TotpCache
|
||||
totpEncryptor service.SecretEncryptor
|
||||
redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository
|
||||
userRepoOptions oauthPendingFlowUserRepoOptions
|
||||
}
|
||||
|
||||
func newOAuthPendingFlowTestHandlerWithDependencies(
|
||||
@@ -1590,7 +1681,17 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
settingValues[key] = value
|
||||
}
|
||||
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
|
||||
userRepo := &oauthPendingFlowUserRepo{client: client}
|
||||
userRepo := &oauthPendingFlowUserRepo{
|
||||
client: client,
|
||||
options: options.userRepoOptions,
|
||||
}
|
||||
redeemRepo := service.RedeemCodeRepository(nil)
|
||||
if options.redeemRepoFactory != nil {
|
||||
redeemRepo = options.redeemRepoFactory(client)
|
||||
}
|
||||
if redeemRepo == nil {
|
||||
redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client}
|
||||
}
|
||||
var emailService *service.EmailService
|
||||
if options.emailCache != nil {
|
||||
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
|
||||
@@ -1602,7 +1703,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
authSvc := service.NewAuthService(
|
||||
client,
|
||||
userRepo,
|
||||
nil,
|
||||
redeemRepo,
|
||||
&oauthPendingFlowRefreshTokenCacheStub{},
|
||||
cfg,
|
||||
settingSvc,
|
||||
@@ -1797,6 +1898,127 @@ func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context,
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type oauthPendingFlowRedeemCodeRepo struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error {
|
||||
panic("unexpected CreateBatch call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
|
||||
entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
notes := ""
|
||||
if entity.Notes != nil {
|
||||
notes = *entity.Notes
|
||||
}
|
||||
return &service.RedeemCode{
|
||||
ID: entity.ID,
|
||||
Code: entity.Code,
|
||||
Type: entity.Type,
|
||||
Value: entity.Value,
|
||||
Status: entity.Status,
|
||||
UsedBy: entity.UsedBy,
|
||||
UsedAt: entity.UsedAt,
|
||||
Notes: notes,
|
||||
CreatedAt: entity.CreatedAt,
|
||||
GroupID: entity.GroupID,
|
||||
ValidityDays: entity.ValidityDays,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
|
||||
if code == nil {
|
||||
return nil
|
||||
}
|
||||
update := r.client.RedeemCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetType(code.Type).
|
||||
SetValue(code.Value).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays)
|
||||
if code.UsedBy != nil {
|
||||
update = update.SetUsedBy(*code.UsedBy)
|
||||
} else {
|
||||
update = update.ClearUsedBy()
|
||||
}
|
||||
if code.UsedAt != nil {
|
||||
update = update.SetUsedAt(*code.UsedAt)
|
||||
} else {
|
||||
update = update.ClearUsedAt()
|
||||
}
|
||||
if code.GroupID != nil {
|
||||
update = update.SetGroupID(*code.GroupID)
|
||||
} else {
|
||||
update = update.ClearGroupID()
|
||||
}
|
||||
_, err := update.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
|
||||
affected, err := r.client.RedeemCode.Update().
|
||||
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetUsedBy(userID).
|
||||
SetUsedAt(time.Now().UTC()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrRedeemCodeUsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) {
|
||||
panic("unexpected ListByUser call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserPaginated call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
|
||||
panic("unexpected SumPositiveBalanceByUser call")
|
||||
}
|
||||
|
||||
type oauthPendingFlowFailingUseRedeemRepo struct {
|
||||
*oauthPendingFlowRedeemCodeRepo
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error {
|
||||
return errors.New("forced invitation use failure")
|
||||
}
|
||||
|
||||
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
@@ -1872,6 +2094,11 @@ func countProviderGrantRecords(
|
||||
|
||||
type oauthPendingFlowUserRepo struct {
|
||||
client *dbent.Client
|
||||
options oauthPendingFlowUserRepoOptions
|
||||
}
|
||||
|
||||
type oauthPendingFlowUserRepoOptions struct {
|
||||
rejectDeleteWhileAuthIdentityExists bool
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
|
||||
@@ -1953,6 +2180,15 @@ func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.Use
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
|
||||
if r.options.rejectDeleteWhileAuthIdentityExists {
|
||||
count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return errors.New("cannot delete user while auth identities still exist")
|
||||
}
|
||||
}
|
||||
return r.client.User.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,14 +15,21 @@ import (
|
||||
// UserHandler handles user-related requests
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
authService *service.AuthService
|
||||
emailService *service.EmailService
|
||||
emailCache service.EmailCache
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new UserHandler
|
||||
func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
|
||||
func NewUserHandler(
|
||||
userService *service.UserService,
|
||||
authService *service.AuthService,
|
||||
emailService *service.EmailService,
|
||||
emailCache service.EmailCache,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
}
|
||||
@@ -157,6 +164,16 @@ type StartIdentityBindingRequest struct {
|
||||
RedirectTo string `json:"redirect_to"`
|
||||
}
|
||||
|
||||
type BindEmailIdentityRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
VerifyCode string `json:"verify_code" binding:"required"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
type SendEmailBindingCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
|
||||
// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow.
|
||||
// POST /api/v1/user/auth-identities/bind/start
|
||||
func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
|
||||
@@ -183,6 +200,73 @@ func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// BindEmailIdentity verifies and binds a local email identity for the current user.
|
||||
// POST /api/v1/user/account-bindings/email
|
||||
func (h *UserHandler) BindEmailIdentity(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
if h.authService == nil {
|
||||
response.InternalError(c, "Auth service not configured")
|
||||
return
|
||||
}
|
||||
|
||||
var req BindEmailIdentityRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updatedUser, err := h.authService.BindEmailIdentity(
|
||||
c.Request.Context(),
|
||||
subject.UserID,
|
||||
req.Email,
|
||||
req.VerifyCode,
|
||||
req.Password,
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, profileResp)
|
||||
}
|
||||
|
||||
// SendEmailBindingCode sends a verification code for the current user's email binding flow.
|
||||
// POST /api/v1/user/account-bindings/email/send-code
|
||||
func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
if h.authService == nil {
|
||||
response.InternalError(c, "Auth service not configured")
|
||||
return
|
||||
}
|
||||
|
||||
var req SendEmailBindingCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Verification code sent successfully"})
|
||||
}
|
||||
|
||||
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
|
||||
type SendNotifyEmailCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -122,7 +123,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
|
||||
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -180,7 +181,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -262,7 +263,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -311,6 +312,116 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
|
||||
require.Equal(t, "linuxdo", usernameSource["source"])
|
||||
}
|
||||
|
||||
type userHandlerEmailCacheStub struct {
|
||||
data *service.VerificationCodeData
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
|
||||
return s.data, nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
repo := &userHandlerRepoStub{
|
||||
user: &service.User{
|
||||
ID: 11,
|
||||
Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain,
|
||||
Username: "legacy-user",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
emailCache := &userHandlerEmailCacheStub{
|
||||
data: &service.VerificationCodeData{
|
||||
Code: "123456",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "provider", Value: "email"}}
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
|
||||
|
||||
handler.BindEmailIdentity(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Email string `json:"email"`
|
||||
EmailBound bool `json:"email_bound"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "new@example.com", resp.Data.Email)
|
||||
require.True(t, resp.Data.EmailBound)
|
||||
}
|
||||
|
||||
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -323,7 +434,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
|
||||
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -74,6 +74,12 @@ func RegisterAuthRoutes(
|
||||
}),
|
||||
h.Auth.ExchangePendingOAuthCompletion,
|
||||
)
|
||||
auth.POST("/oauth/pending/send-verify-code",
|
||||
rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.SendPendingOAuthVerifyCode,
|
||||
)
|
||||
auth.POST("/oauth/pending/create-account",
|
||||
rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
|
||||
@@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/login/2fa",
|
||||
"/api/v1/auth/send-verify-code",
|
||||
"/api/v1/auth/oauth/pending/send-verify-code",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
|
||||
@@ -25,6 +25,8 @@ func RegisterUserRoutes(
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
|
||||
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
|
||||
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
|
||||
|
||||
// 通知邮箱管理
|
||||
|
||||
128
backend/internal/service/auth_email_binding.go
Normal file
128
backend/internal/service/auth_email_binding.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// BindEmailIdentity verifies and binds a local email/password identity to the current user.
|
||||
func (s *AuthService) BindEmailIdentity(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
email string,
|
||||
verifyCode string,
|
||||
password string,
|
||||
) (*User, error) {
|
||||
if s == nil {
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
normalizedEmail, err := normalizeEmailForIdentityBinding(email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isReservedEmail(normalizedEmail) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return nil, ErrPasswordRequired
|
||||
}
|
||||
if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentUser, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
|
||||
switch {
|
||||
case err == nil && existingUser != nil && existingUser.ID != userID:
|
||||
return nil, ErrEmailExists
|
||||
case err != nil && !errors.Is(err, ErrUserNotFound):
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
hashedPassword, err := s.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
|
||||
currentUser.Email = normalizedEmail
|
||||
currentUser.PasswordHash = hashedPassword
|
||||
if err := s.userRepo.Update(ctx, currentUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return nil, ErrEmailExists
|
||||
}
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
if firstRealEmailBind {
|
||||
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil {
|
||||
return nil, fmt.Errorf("apply email first bind defaults: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return currentUser, nil
|
||||
}
|
||||
|
||||
// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
|
||||
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
|
||||
if s == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
normalizedEmail, err := normalizeEmailForIdentityBinding(email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isReservedEmail(normalizedEmail) {
|
||||
return ErrEmailReserved
|
||||
}
|
||||
if s.emailService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
|
||||
switch {
|
||||
case err == nil && existingUser != nil && existingUser.ID != userID:
|
||||
return ErrEmailExists
|
||||
case err != nil && !errors.Is(err, ErrUserNotFound):
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
|
||||
}
|
||||
|
||||
func normalizeEmailForIdentityBinding(email string) (string, error) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalized == "" || len(normalized) > 255 {
|
||||
return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(normalized); err != nil {
|
||||
return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func hasBindableEmailIdentitySubject(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return normalized != "" && !isReservedEmail(normalized)
|
||||
}
|
||||
@@ -4,9 +4,71 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func normalizeOAuthSignupSource(signupSource string) string {
|
||||
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
|
||||
if signupSource == "" {
|
||||
return "email"
|
||||
}
|
||||
return signupSource
|
||||
}
|
||||
|
||||
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
|
||||
// account-creation flows without relying on the public registration gate.
|
||||
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
if email == "" {
|
||||
return nil, ErrEmailVerifyRequired
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return nil, ErrEmailVerifyRequired
|
||||
}
|
||||
if isReservedEmail(email) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
if s == nil || s.emailService == nil {
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SendVerifyCodeResult{
|
||||
Countdown: int(verifyCodeCooldown / time.Second),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
|
||||
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
return nil, nil
|
||||
}
|
||||
if s.redeemRepo == nil {
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
invitationCode = strings.TrimSpace(invitationCode)
|
||||
if invitationCode == "" {
|
||||
return nil, ErrInvitationCodeRequired
|
||||
}
|
||||
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
return nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
return nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
return redeemCode, nil
|
||||
}
|
||||
|
||||
// VerifyOAuthEmailCode verifies the locally entered email verification code for
|
||||
// third-party signup and binding flows. This is intentionally independent from
|
||||
// the global registration email verification toggle.
|
||||
@@ -54,19 +116,8 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var invitationRedeemCode *RedeemCode
|
||||
if s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
if invitationCode == "" {
|
||||
return nil, nil, ErrInvitationCodeRequired
|
||||
}
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
@@ -104,22 +155,91 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
}
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
|
||||
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||
}
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
|
||||
// only after the pending OAuth flow has fully reached its last reversible step.
|
||||
func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
ctx context.Context,
|
||||
user *User,
|
||||
invitationCode string,
|
||||
signupSource string,
|
||||
) error {
|
||||
if s == nil || user == nil || user.ID <= 0 {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
signupSource = normalizeOAuthSignupSource(signupSource)
|
||||
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
return ErrInvitationCodeInvalid
|
||||
}
|
||||
}
|
||||
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RollbackOAuthEmailAccountCreation removes a partially-created local account
|
||||
// and restores any invitation code already consumed by that account.
|
||||
func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error {
|
||||
if s == nil || s.userRepo == nil || userID <= 0 {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.userRepo.Delete(ctx, userID); err != nil {
|
||||
return fmt.Errorf("delete created oauth user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error {
|
||||
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||
return nil
|
||||
}
|
||||
if s.redeemRepo == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
invitationCode = strings.TrimSpace(invitationCode)
|
||||
if invitationCode == "" || userID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrRedeemCodeNotFound) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("load invitation code: %w", err)
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID {
|
||||
return nil
|
||||
}
|
||||
|
||||
redeemCode.Status = StatusUnused
|
||||
redeemCode.UsedBy = nil
|
||||
redeemCode.UsedAt = nil
|
||||
if err := s.redeemRepo.Update(ctx, redeemCode); err != nil {
|
||||
return fmt.Errorf("restore invitation code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePasswordCredentials checks the local password without completing the
|
||||
// login flow. This is used by pending third-party account adoption flows before
|
||||
// the external identity has been bound.
|
||||
|
||||
251
backend/internal/service/auth_oauth_email_flow_test.go
Normal file
251
backend/internal/service/auth_oauth_email_flow_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type redeemCodeRepoStub struct {
|
||||
codesByCode map[string]*RedeemCode
|
||||
useCalls []struct {
|
||||
id int64
|
||||
userID int64
|
||||
}
|
||||
updateCalls []*RedeemCode
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
|
||||
panic("unexpected CreateBatch call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
|
||||
if s.codesByCode == nil {
|
||||
return nil, ErrRedeemCodeNotFound
|
||||
}
|
||||
redeemCode, ok := s.codesByCode[code]
|
||||
if !ok {
|
||||
return nil, ErrRedeemCodeNotFound
|
||||
}
|
||||
cloned := *redeemCode
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
|
||||
if code == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *code
|
||||
s.updateCalls = append(s.updateCalls, &cloned)
|
||||
if s.codesByCode == nil {
|
||||
s.codesByCode = make(map[string]*RedeemCode)
|
||||
}
|
||||
s.codesByCode[cloned.Code] = &cloned
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
|
||||
for code, redeemCode := range s.codesByCode {
|
||||
if redeemCode.ID != id {
|
||||
continue
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
redeemCode.Status = StatusUsed
|
||||
redeemCode.UsedBy = &userID
|
||||
redeemCode.UsedAt = &now
|
||||
s.codesByCode[code] = redeemCode
|
||||
s.useCalls = append(s.useCalls, struct {
|
||||
id int64
|
||||
userID int64
|
||||
}{id: id, userID: userID})
|
||||
return nil
|
||||
}
|
||||
return ErrRedeemCodeNotFound
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
|
||||
panic("unexpected ListByUser call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserPaginated call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
|
||||
panic("unexpected SumPositiveBalanceByUser call")
|
||||
}
|
||||
|
||||
func newOAuthEmailFlowAuthService(
|
||||
userRepo UserRepository,
|
||||
redeemRepo RedeemCodeRepository,
|
||||
refreshTokenCache RefreshTokenCache,
|
||||
settings map[string]string,
|
||||
emailCache EmailCache,
|
||||
) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireHour: 1,
|
||||
AccessTokenExpireMinutes: 60,
|
||||
RefreshTokenExpireDays: 7,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 3.5,
|
||||
UserConcurrency: 2,
|
||||
},
|
||||
}
|
||||
|
||||
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
|
||||
|
||||
return NewAuthService(
|
||||
nil,
|
||||
userRepo,
|
||||
redeemRepo,
|
||||
refreshTokenCache,
|
||||
cfg,
|
||||
settingService,
|
||||
emailService,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
|
||||
userRepo := &userRepoStub{nextID: 42}
|
||||
redeemRepo := &redeemCodeRepoStub{
|
||||
codesByCode: map[string]*RedeemCode{
|
||||
"INVITE123": {
|
||||
ID: 7,
|
||||
Code: "INVITE123",
|
||||
Type: RedeemTypeInvitation,
|
||||
Status: StatusUnused,
|
||||
},
|
||||
},
|
||||
}
|
||||
emailCache := &emailCacheStub{
|
||||
data: &VerificationCodeData{
|
||||
Code: "246810",
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||
},
|
||||
}
|
||||
authService := newOAuthEmailFlowAuthService(
|
||||
userRepo,
|
||||
redeemRepo,
|
||||
nil,
|
||||
map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyInvitationCodeEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
},
|
||||
emailCache,
|
||||
)
|
||||
|
||||
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||
context.Background(),
|
||||
"fresh@example.com",
|
||||
"secret-123",
|
||||
"246810",
|
||||
"INVITE123",
|
||||
"oidc",
|
||||
)
|
||||
|
||||
require.Nil(t, tokenPair)
|
||||
require.Nil(t, user)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "generate token pair")
|
||||
require.Equal(t, []int64{42}, userRepo.deletedIDs)
|
||||
require.Len(t, userRepo.created, 1)
|
||||
require.Empty(t, redeemRepo.useCalls)
|
||||
require.Empty(t, redeemRepo.updateCalls)
|
||||
}
|
||||
|
||||
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
|
||||
userRepo := &userRepoStub{}
|
||||
redeemRepo := &redeemCodeRepoStub{
|
||||
codesByCode: map[string]*RedeemCode{
|
||||
"INVITE123": {
|
||||
ID: 7,
|
||||
Code: "INVITE123",
|
||||
Type: RedeemTypeInvitation,
|
||||
Status: StatusUsed,
|
||||
UsedBy: func() *int64 {
|
||||
v := int64(42)
|
||||
return &v
|
||||
}(),
|
||||
UsedAt: func() *time.Time {
|
||||
v := time.Now().UTC()
|
||||
return &v
|
||||
}(),
|
||||
},
|
||||
},
|
||||
}
|
||||
authService := newOAuthEmailFlowAuthService(
|
||||
userRepo,
|
||||
redeemRepo,
|
||||
&refreshTokenCacheStub{},
|
||||
map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyInvitationCodeEnabled: "true",
|
||||
},
|
||||
&emailCacheStub{},
|
||||
)
|
||||
|
||||
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{42}, userRepo.deletedIDs)
|
||||
require.Len(t, redeemRepo.updateCalls, 1)
|
||||
require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
|
||||
require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
|
||||
require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
|
||||
}
|
||||
|
||||
func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
|
||||
userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
|
||||
authService := newOAuthEmailFlowAuthService(
|
||||
userRepo,
|
||||
&redeemCodeRepoStub{},
|
||||
&refreshTokenCacheStub{},
|
||||
map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
},
|
||||
&emailCacheStub{},
|
||||
)
|
||||
|
||||
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "delete created oauth user")
|
||||
}
|
||||
316
backend/internal/service/auth_service_email_bind_test.go
Normal file
316
backend/internal/service/auth_service_email_bind_test.go
Normal file
@@ -0,0 +1,316 @@
|
||||
//go:build unit
|
||||
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type emailBindDefaultSubAssignerStub struct {
|
||||
calls []*service.AssignSubscriptionInput
|
||||
}
|
||||
|
||||
func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||
_ context.Context,
|
||||
input *service.AssignSubscriptionInput,
|
||||
) (*service.UserSubscription, bool, error) {
|
||||
cloned := *input
|
||||
s.calls = append(s.calls, &cloned)
|
||||
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
|
||||
}
|
||||
|
||||
func newAuthServiceForEmailBind(
|
||||
t *testing.T,
|
||||
settings map[string]string,
|
||||
emailCache service.EmailCache,
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner,
|
||||
) (*service.AuthService, service.UserRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
provider_type TEXT NOT NULL,
|
||||
grant_reason TEXT NOT NULL DEFAULT 'first_bind',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(user_id, provider_type, grant_reason)
|
||||
)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
repo := repository.NewUserRepository(client, db)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-bind-email-secret",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 3.5,
|
||||
UserConcurrency: 2,
|
||||
},
|
||||
}
|
||||
|
||||
settingRepo := &emailBindSettingRepoStub{values: settings}
|
||||
settingSvc := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
var emailSvc *service.EmailService
|
||||
if emailCache != nil {
|
||||
emailSvc = service.NewEmailService(settingRepo, emailCache)
|
||||
}
|
||||
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
|
||||
return svc, repo, client
|
||||
}
|
||||
|
||||
func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) {
|
||||
assigner := &emailBindDefaultSubAssignerStub{}
|
||||
cache := &emailBindCacheStub{
|
||||
data: &service.VerificationCodeData{
|
||||
Code: "123456",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
|
||||
},
|
||||
}
|
||||
svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
|
||||
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
|
||||
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||
}, cache, assigner)
|
||||
|
||||
ctx := context.Background()
|
||||
user, err := client.User.Create().
|
||||
SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain).
|
||||
SetUsername("legacy-user").
|
||||
SetPasswordHash("old-hash").
|
||||
SetBalance(2.5).
|
||||
SetConcurrency(1).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedUser)
|
||||
require.Equal(t, "newemail@example.com", updatedUser.Email)
|
||||
|
||||
storedUser, err := client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "newemail@example.com", storedUser.Email)
|
||||
require.Equal(t, 11.0, storedUser.Balance)
|
||||
require.Equal(t, 5, storedUser.Concurrency)
|
||||
require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.UserIDEQ(user.ID),
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ("newemail@example.com"),
|
||||
).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, identityCount)
|
||||
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, user.ID, assigner.calls[0].UserID)
|
||||
require.Equal(t, int64(11), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||
}
|
||||
|
||||
func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) {
|
||||
cache := &emailBindCacheStub{
|
||||
data: &service.VerificationCodeData{
|
||||
Code: "123456",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
|
||||
},
|
||||
}
|
||||
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
sourceUser, err := client.User.Create().
|
||||
SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain).
|
||||
SetUsername("source-user").
|
||||
SetPasswordHash("old-hash").
|
||||
SetBalance(1).
|
||||
SetConcurrency(1).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
_, err = client.User.Create().
|
||||
SetEmail("taken@example.com").
|
||||
SetUsername("taken-user").
|
||||
SetPasswordHash("hash").
|
||||
SetBalance(1).
|
||||
SetConcurrency(1).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password")
|
||||
require.ErrorIs(t, err, service.ErrEmailExists)
|
||||
require.Nil(t, updatedUser)
|
||||
|
||||
storedUser, err := client.User.Get(ctx, sourceUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email)
|
||||
require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
|
||||
}
|
||||
|
||||
func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
|
||||
cache := &emailBindCacheStub{
|
||||
data: &service.VerificationCodeData{
|
||||
Code: "123456",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
|
||||
},
|
||||
}
|
||||
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
user, err := client.User.Create().
|
||||
SetEmail("source-user@example.com").
|
||||
SetUsername("source-user").
|
||||
SetPasswordHash("old-hash").
|
||||
SetBalance(1).
|
||||
SetConcurrency(1).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password")
|
||||
require.ErrorIs(t, err, service.ErrEmailReserved)
|
||||
require.Nil(t, updatedUser)
|
||||
}
|
||||
|
||||
type emailBindSettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if v, ok := s.values[key]; ok {
|
||||
out[key] = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *emailBindSettingRepoStub) Delete(context.Context, string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
type emailBindCacheStub struct {
|
||||
data *service.VerificationCodeData
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.data, nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
Reference in New Issue
Block a user