package handler

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	dbent "github.com/Wei-Shaw/sub2api/ent"
	"github.com/Wei-Shaw/sub2api/ent/authidentity"
	"github.com/Wei-Shaw/sub2api/ent/enttest"
	"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
	"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
	dbuser "github.com/Wei-Shaw/sub2api/ent/user"
	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/gin-gonic/gin"
	"github.com/pquerna/otp/totp"
	"github.com/stretchr/testify/require"

	"entgo.io/ent/dialect"
	entsql "entgo.io/ent/dialect/sql"
	_ "modernc.org/sqlite"
)

// TestApplySuggestedProfileToCompletionResponse verifies that the suggested
// display name and avatar URL from the upstream claims are copied into a
// completion payload that lacks them, and that the payload is flagged with
// adoption_required=true.
func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
	payload := map[string]any{
		"access_token": "token",
	}
	upstream := map[string]any{
		"suggested_display_name": "Alice",
		"suggested_avatar_url":   "https://cdn.example/avatar.png",
	}

	applySuggestedProfileToCompletionResponse(payload, upstream)

	require.Equal(t, "Alice", payload["suggested_display_name"])
	require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
	require.Equal(t, true, payload["adoption_required"])
}

// TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues
// verifies that a suggested_display_name already present in the payload is not
// overwritten by the upstream value, while a missing suggested_avatar_url is
// still filled in from upstream.
func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) {
	payload := map[string]any{
		"suggested_display_name": "Existing",
		"adoption_required":      false,
	}
	upstream := map[string]any{
		"suggested_display_name": "Alice",
		"suggested_avatar_url":   "https://cdn.example/avatar.png",
	}

	applySuggestedProfileToCompletionResponse(payload, upstream)

	require.Equal(t, "Existing", payload["suggested_display_name"])
	require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
	// Unlike the suggested_* keys, adoption_required is expected to end up true
	// even though the payload started with false.
	require.Equal(t, true, payload["adoption_required"])
}

// TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision
// drives a "login" pending session through two exchange calls:
//   - a preview (no request body) that returns the suggested profile without
//     changing the user or consuming the session;
//   - a finalize (adopt_display_name/adopt_avatar both true) that renames the
//     user, binds the auth identity with display_name/avatar_url metadata,
//     records the adoption decision, and consumes the session.
func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
	handler, client := newOAuthPendingFlowTestHandler(t, false)
	ctx := context.Background()

	userEntity, err := client.User.Create().
		SetEmail("linuxdo-123@linuxdo-connect.invalid").
		SetUsername("legacy-name").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("pending-session-token").
		SetIntent("login").
		SetProviderType("linuxdo").
		SetProviderKey("linuxdo").
		SetProviderSubject("123").
		SetTargetUserID(userEntity.ID).
		SetResolvedEmail(userEntity.Email).
		SetBrowserSessionKey("browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"username":               "linuxdo_user",
			"suggested_display_name": "Alice Example",
			"suggested_avatar_url":   "https://cdn.example/alice.png",
		}).
		SetLocalFlowState(map[string]any{
			oauthCompletionResponseKey: map[string]any{
				"access_token": "access-token",
				"redirect":     "/dashboard",
			},
		}).
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	// Preview call: no request body, only the pending-session and browser cookies.
	previewRecorder := httptest.NewRecorder()
	previewCtx, _ := gin.CreateTestContext(previewRecorder)
	previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
	previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
	previewCtx.Request = previewReq

	handler.ExchangePendingOAuthCompletion(previewCtx)

	require.Equal(t, http.StatusOK, previewRecorder.Code)
	previewData := decodeJSONResponseData(t, previewRecorder)
	require.Equal(t, "Alice Example", previewData["suggested_display_name"])
	require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
	require.Equal(t, true, previewData["adoption_required"])

	// The preview must be side-effect free: username unchanged, session not consumed.
	storedUser, err := client.User.Get(ctx, userEntity.ID)
	require.NoError(t, err)
	require.Equal(t, "legacy-name", storedUser.Username)

	previewSession, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.IDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.Nil(t, previewSession.ConsumedAt)

	// Finalize call: same cookies, now with an explicit adoption decision body.
	body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
	finalizeRecorder := httptest.NewRecorder()
	finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
	finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
	finalizeReq.Header.Set("Content-Type", "application/json")
	finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
	finalizeCtx.Request = finalizeReq

	handler.ExchangePendingOAuthCompletion(finalizeCtx)

	require.Equal(t, http.StatusOK, finalizeRecorder.Code)

	// Adopting the display name is expected to overwrite the stored username.
	storedUser, err = client.User.Get(ctx, userEntity.ID)
	require.NoError(t, err)
	require.Equal(t, "Alice Example", storedUser.Username)

	identity, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("linuxdo"),
			authidentity.ProviderKeyEQ("linuxdo"),
			authidentity.ProviderSubjectEQ("123"),
		).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, userEntity.ID, identity.UserID)
	require.Equal(t, "Alice Example", identity.Metadata["display_name"])
	require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])

	decision, err := client.IdentityAdoptionDecision.Query().
		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.NotNil(t, decision.IdentityID)
	require.Equal(t, identity.ID, *decision.IdentityID)
	require.True(t, decision.AdoptDisplayName)
	require.True(t, decision.AdoptAvatar)

	consumed, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.IDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.NotNil(t, consumed.ConsumedAt)
}

// TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption
// covers the "bind_current_user" intent:
//   - the preview returns the suggested profile without creating an identity or
//     consuming the session;
//   - finalizing with both adoption flags false binds the identity to the target
//     user while keeping the username.
//
// After finalizing, the identity metadata stores only the suggested_* values,
// not display_name/avatar_url.
func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) {
	handler, client := newOAuthPendingFlowTestHandler(t, false)
	ctx := context.Background()

	userEntity, err := client.User.Create().
		SetEmail("bind-target@example.com").
		SetUsername("legacy-name").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("bind-pending-session-token").
		SetIntent("bind_current_user").
		SetProviderType("linuxdo").
		SetProviderKey("linuxdo").
		SetProviderSubject("bind-123").
		SetTargetUserID(userEntity.ID).
		SetResolvedEmail(userEntity.Email).
		SetBrowserSessionKey("bind-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"username":               "linuxdo_user",
			"suggested_display_name": "Bound Example",
			"suggested_avatar_url":   "https://cdn.example/bound.png",
		}).
		SetLocalFlowState(map[string]any{
			oauthCompletionResponseKey: map[string]any{
				"access_token": "access-token",
				"redirect":     "/settings/profile",
			},
		}).
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	// Preview call: no request body.
	previewRecorder := httptest.NewRecorder()
	previewCtx, _ := gin.CreateTestContext(previewRecorder)
	previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
	previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")})
	previewCtx.Request = previewReq

	handler.ExchangePendingOAuthCompletion(previewCtx)

	require.Equal(t, http.StatusOK, previewRecorder.Code)
	previewData := decodeJSONResponseData(t, previewRecorder)
	require.Equal(t, "Bound Example", previewData["suggested_display_name"])
	require.Equal(t, "https://cdn.example/bound.png", previewData["suggested_avatar_url"])
	require.Equal(t, true, previewData["adoption_required"])

	// The preview must not create the identity yet.
	identityCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("linuxdo"),
			authidentity.ProviderKeyEQ("linuxdo"),
			authidentity.ProviderSubjectEQ("bind-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, identityCount)

	previewSession, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.IDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.Nil(t, previewSession.ConsumedAt)

	// Finalize call: decline both adoptions.
	body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
	finalizeRecorder := httptest.NewRecorder()
	finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
	finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
	finalizeReq.Header.Set("Content-Type", "application/json")
	finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")})
	finalizeCtx.Request = finalizeReq

	handler.ExchangePendingOAuthCompletion(finalizeCtx)

	require.Equal(t, http.StatusOK, finalizeRecorder.Code)

	storedUser, err := client.User.Get(ctx, userEntity.ID)
	require.NoError(t, err)
	require.Equal(t, "legacy-name", storedUser.Username)

	identity, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("linuxdo"),
			authidentity.ProviderKeyEQ("linuxdo"),
			authidentity.ProviderSubjectEQ("bind-123"),
		).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, userEntity.ID, identity.UserID)
	require.Equal(t, "Bound Example", identity.Metadata["suggested_display_name"])
	require.Equal(t, "https://cdn.example/bound.png", identity.Metadata["suggested_avatar_url"])
	// Declined adoption: the adopted-profile keys must be absent, not just empty.
	_, hasDisplayName := identity.Metadata["display_name"]
	require.False(t, hasDisplayName)
	_, hasAvatarURL := identity.Metadata["avatar_url"]
	require.False(t, hasAvatarURL)

	decision, err := client.IdentityAdoptionDecision.Query().
		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.NotNil(t, decision.IdentityID)
	require.Equal(t, identity.ID, *decision.IdentityID)
	require.False(t, decision.AdoptDisplayName)
	require.False(t, decision.AdoptAvatar)

	consumed, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.IDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.NotNil(t, consumed.ConsumedAt)
}

// TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict verifies
// that a bind_current_user finalize fails when the provider subject is already
// bound to a different user. The test checks that:
//   - the handler responds 500 with reason PENDING_AUTH_ADOPTION_APPLY_FAILED;
//   - the existing identity keeps its original owner;
//   - an adoption decision is still persisted, with a nil IdentityID;
//   - the pending session is left unconsumed.
func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testing.T) {
	handler, client := newOAuthPendingFlowTestHandler(t, false)
	ctx := context.Background()

	targetUser, err := client.User.Create().
		SetEmail("bind-conflict-target@example.com").
		SetUsername("target-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	ownerUser, err := client.User.Create().
		SetEmail("bind-conflict-owner@example.com").
		SetUsername("owner-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	// Pre-existing identity for the same provider subject, owned by ownerUser.
	existingIdentity, err := client.AuthIdentity.Create().
		SetUserID(ownerUser.ID).
		SetProviderType("linuxdo").
		SetProviderKey("linuxdo").
		SetProviderSubject("conflict-123").
		SetMetadata(map[string]any{"username": "owner-user"}).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("bind-conflict-session-token").
		SetIntent("bind_current_user").
		SetProviderType("linuxdo").
		SetProviderKey("linuxdo").
		SetProviderSubject("conflict-123").
		SetTargetUserID(targetUser.ID).
		SetResolvedEmail(targetUser.Email).
		SetBrowserSessionKey("bind-conflict-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"suggested_display_name": "Conflict Example",
			"suggested_avatar_url":   "https://cdn.example/conflict.png",
		}).
		SetLocalFlowState(map[string]any{
			oauthCompletionResponseKey: map[string]any{
				"access_token": "access-token",
			},
		}).
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
	req.Header.Set("Content-Type", "application/json")
	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-conflict-browser-session-key")})
	ginCtx.Request = req

	handler.ExchangePendingOAuthCompletion(ginCtx)

	require.Equal(t, http.StatusInternalServerError, recorder.Code)
	payload := decodeJSONBody(t, recorder)
	require.Equal(t, "PENDING_AUTH_ADOPTION_APPLY_FAILED", payload["reason"])

	// Ownership of the existing identity must not have moved to targetUser.
	identity, err := client.AuthIdentity.Get(ctx, existingIdentity.ID)
	require.NoError(t, err)
	require.Equal(t, ownerUser.ID, identity.UserID)

	decision, err := client.IdentityAdoptionDecision.Query().
		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.Nil(t, decision.IdentityID)
	require.False(t, decision.AdoptDisplayName)
	require.False(t, decision.AdoptAvatar)

	storedSession, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.IDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.Nil(t, storedSession.ConsumedAt)
}

// TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity verifies
// that finalizing a "login" pending session with both adoption flags false:
//   - returns 200;
//   - creates no auth identity for the subject;
//   - persists a decision with nil IdentityID;
//   - consumes the session.
func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *testing.T) {
	handler, client := newOAuthPendingFlowTestHandler(t, false)
	ctx := context.Background()

	userEntity, err := client.User.Create().
		SetEmail("login-false@example.com").
		SetUsername("legacy-name").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("login-false-session-token").
		SetIntent("login").
		SetProviderType("linuxdo").
		SetProviderKey("linuxdo").
		SetProviderSubject("login-false-123").
		SetTargetUserID(userEntity.ID).
		SetResolvedEmail(userEntity.Email).
		SetBrowserSessionKey("login-false-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"suggested_display_name": "Login Example",
			"suggested_avatar_url":   "https://cdn.example/login.png",
		}).
		SetLocalFlowState(map[string]any{
			oauthCompletionResponseKey: map[string]any{
				"access_token": "access-token",
			},
		}).
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
	req.Header.Set("Content-Type", "application/json")
	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-false-browser-session-key")})
	ginCtx.Request = req

	handler.ExchangePendingOAuthCompletion(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)

	identityCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("linuxdo"),
			authidentity.ProviderKeyEQ("linuxdo"),
			authidentity.ProviderSubjectEQ("login-false-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, identityCount)

	decision, err := client.IdentityAdoptionDecision.Query().
		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.Nil(t, decision.IdentityID)
	require.False(t, decision.AdoptDisplayName)
	require.False(t, decision.AdoptAvatar)

	storedSession, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.IDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.NotNil(t, storedSession.ConsumedAt)
}

// TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding
// verifies the case where the stored completion response carries
// error=invitation_required. The exchange:
//   - still returns 200, with that error and adoption_required=true;
//   - persists a decision with nil IdentityID;
//   - creates no identity;
//   - leaves the session unconsumed.
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
	// NOTE(review): the `true` argument presumably turns on the invitation
	// requirement in the test handler — confirm against newOAuthPendingFlowTestHandler.
	handler, client := newOAuthPendingFlowTestHandler(t, true)
	ctx := context.Background()

	// No target user: this session represents a login that could not complete.
	session, err := client.PendingAuthSession.Create().
		SetSessionToken("invitation-required-session-token").
		SetIntent("login").
		SetProviderType("linuxdo").
		SetProviderKey("linuxdo").
		SetProviderSubject("invitation-123").
		SetBrowserSessionKey("invitation-required-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"suggested_display_name": "Invite Example",
			"suggested_avatar_url":   "https://cdn.example/invite.png",
		}).
		SetLocalFlowState(map[string]any{
			oauthCompletionResponseKey: map[string]any{
				"error": "invitation_required",
			},
		}).
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
	req.Header.Set("Content-Type", "application/json")
	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("invitation-required-browser-session-key")})
	ginCtx.Request = req

	handler.ExchangePendingOAuthCompletion(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	data := decodeJSONResponseData(t, recorder)
	require.Equal(t, "invitation_required", data["error"])
	require.Equal(t, true, data["adoption_required"])

	identityCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("linuxdo"),
			authidentity.ProviderKeyEQ("linuxdo"),
			authidentity.ProviderSubjectEQ("invitation-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, identityCount)

	decision, err := client.IdentityAdoptionDecision.Query().
		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.Nil(t, decision.IdentityID)
	require.False(t, decision.AdoptDisplayName)
	require.False(t, decision.AdoptAvatar)

	storedSession, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.IDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.Nil(t, storedSession.ConsumedAt)
}

// TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession verifies
// that CreateOIDCOAuthAccount, given a new email with a matching verify code:
//   - returns bearer access/refresh tokens at the top level of the JSON body;
//   - creates an active user;
//   - binds the OIDC identity to that user;
//   - consumes the pending session.
func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) {
	// NOTE(review): the email/code pair presumably seeds the email-verification
	// stub so "246810" is accepted for fresh@example.com — confirm in the helper.
	handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810")
	ctx := context.Background()

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("create-account-session-token").
		SetIntent("login").
		SetProviderType("oidc").
		SetProviderKey("https://issuer.example").
		SetProviderSubject("oidc-create-123").
		SetBrowserSessionKey("create-account-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"username":               "oidc_user",
			"suggested_display_name": "Fresh OIDC User",
			"suggested_avatar_url":   "https://cdn.example/fresh.png",
		}).
		SetRedirectTo("/profile").
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
	req.Header.Set("Content-Type", "application/json")
	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")})
	ginCtx.Request = req

	handler.CreateOIDCOAuthAccount(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	// Tokens are read from the top-level body, not via decodeJSONResponseData.
	var payload map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
	require.NotEmpty(t, payload["access_token"])
	require.NotEmpty(t, payload["refresh_token"])
	require.Equal(t, "Bearer", payload["token_type"])

	createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx)
	require.NoError(t, err)
	require.Equal(t, service.StatusActive, createdUser.Status)

	identity, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("oidc"),
			authidentity.ProviderKeyEQ("https://issuer.example"),
			authidentity.ProviderSubjectEQ("oidc-create-123"),
		).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, createdUser.ID, identity.UserID)

	storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
	require.NoError(t, err)
	require.NotNil(t, storedSession.ConsumedAt)
}

// TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState
// verifies what happens when CreateOIDCOAuthAccount is called with an email
// that already belongs to a user:
//   - it returns a "pending_session" result with intent
//     adopt_existing_user_by_email plus the suggested profile;
//   - it rewrites the stored session to target that user;
//   - it creates no identity and does not consume the session.
func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) {
	handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
	ctx := context.Background()

	existingUser, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("existing-email-session-token").
		SetIntent("login").
		SetProviderType("oidc").
		SetProviderKey("https://issuer.example").
		SetProviderSubject("oidc-existing-123").
		SetBrowserSessionKey("existing-email-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"username":               "oidc_user",
			"suggested_display_name": "Existing OIDC User",
			"suggested_avatar_url":   "https://cdn.example/existing.png",
		}).
		SetRedirectTo("/dashboard").
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
	req.Header.Set("Content-Type", "application/json")
	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")})
	ginCtx.Request = req

	handler.CreateOIDCOAuthAccount(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var payload map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
	require.Equal(t, "pending_session", payload["auth_result"])
	require.Equal(t, "adopt_existing_user_by_email", payload["intent"])
	require.Equal(t, "oidc", payload["provider"])
	require.Equal(t, "/dashboard", payload["redirect"])
	require.Equal(t, true, payload["adoption_required"])
	require.Equal(t, "Existing OIDC User", payload["suggested_display_name"])
	require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"])

	// The pending session is re-pointed at the existing user, not consumed.
	storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
	require.NoError(t, err)
	require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
	require.NotNil(t, storedSession.TargetUserID)
	require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
	require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
	require.Nil(t, storedSession.ConsumedAt)

	identityCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("oidc"),
			authidentity.ProviderKeyEQ("https://issuer.example"),
			authidentity.ProviderSubjectEQ("oidc-existing-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, identityCount)
}

// TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession verifies that
// BindOIDCOAuthLogin, given the correct password for the session's target user:
//   - returns bearer tokens;
//   - binds the OIDC identity to that user;
//   - consumes the pending session.
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
	handler, client := newOAuthPendingFlowTestHandler(t, false)
	ctx := context.Background()

	// Hash via the handler's auth service so the password check in the handler matches.
	passwordHash, err := handler.authService.HashPassword("secret-123")
	require.NoError(t, err)
	existingUser, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner-user").
		SetPasswordHash(passwordHash).
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("bind-login-session-token").
		SetIntent("adopt_existing_user_by_email").
		SetProviderType("oidc").
		SetProviderKey("https://issuer.example").
		SetProviderSubject("oidc-bind-123").
		SetTargetUserID(existingUser.ID).
		SetResolvedEmail(existingUser.Email).
		SetBrowserSessionKey("bind-login-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"username":               "oidc_user",
			"suggested_display_name": "Bound OIDC User",
			"suggested_avatar_url":   "https://cdn.example/bound.png",
		}).
		SetRedirectTo("/profile").
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
	req.Header.Set("Content-Type", "application/json")
	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")})
	ginCtx.Request = req

	handler.BindOIDCOAuthLogin(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	var payload map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
	require.NotEmpty(t, payload["access_token"])
	require.NotEmpty(t, payload["refresh_token"])
	require.Equal(t, "Bearer", payload["token_type"])

	identity, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("oidc"),
			authidentity.ProviderKeyEQ("https://issuer.example"),
			authidentity.ProviderSubjectEQ("oidc-bind-123"),
		).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, existingUser.ID, identity.UserID)

	storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
	require.NoError(t, err)
	require.NotNil(t, storedSession.ConsumedAt)
}

// TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession verifies
// that a wrong password produces 401 with reason INVALID_CREDENTIALS. It also
// checks that no identity is created and the pending session stays usable
// (unconsumed).
func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
	handler, client := newOAuthPendingFlowTestHandler(t, false)
	ctx := context.Background()

	passwordHash, err := handler.authService.HashPassword("secret-123")
	require.NoError(t, err)
	existingUser, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner-user").
		SetPasswordHash(passwordHash).
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("bind-login-invalid-password-session-token").
		SetIntent("adopt_existing_user_by_email").
		SetProviderType("oidc").
		SetProviderKey("https://issuer.example").
		SetProviderSubject("oidc-bind-invalid-123").
		SetTargetUserID(existingUser.ID).
		SetResolvedEmail(existingUser.Email).
		SetBrowserSessionKey("bind-login-invalid-password-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"username":               "oidc_user",
			"suggested_display_name": "Bound OIDC User",
			"suggested_avatar_url":   "https://cdn.example/bound.png",
		}).
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
	req.Header.Set("Content-Type", "application/json")
	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")})
	ginCtx.Request = req

	handler.BindOIDCOAuthLogin(ginCtx)

	require.Equal(t, http.StatusUnauthorized, recorder.Code)
	payload := decodeJSONBody(t, recorder)
	require.Equal(t, "INVALID_CREDENTIALS", payload["reason"])

	identityCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("oidc"),
			authidentity.ProviderKeyEQ("https://issuer.example"),
			authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, identityCount)

	storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
	require.NoError(t, err)
	require.Nil(t, storedSession.ConsumedAt)
}

// TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce verifies that the OIDC
// "grant on first bind" defaults are applied exactly once per user. The
// defaults are balance +12.5, concurrency +3, and one subscription (group 101,
// 30 days). The test performs two successful bind-logins of different OIDC
// subjects to the same user:
//   - the first applies the grant;
//   - the second leaves balance, concurrency, subscription calls and the
//     first_bind grant-record count unchanged.
func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) {
	defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
	handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
		settingValues: map[string]string{
			service.SettingKeyAuthSourceDefaultOIDCBalance:          "12.5",
			service.SettingKeyAuthSourceDefaultOIDCConcurrency:      "3",
			service.SettingKeyAuthSourceDefaultOIDCSubscriptions:    `[{"group_id":101,"validity_days":30}]`,
			service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
		},
		defaultSubAssigner: defaultSubAssigner,
	})
	ctx := context.Background()

	passwordHash, err := handler.authService.HashPassword("secret-123")
	require.NoError(t, err)
	// Starting balance 5 and concurrency 2, so the expected 17.5 / 5 below are
	// these values plus the configured defaults.
	existingUser, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner-user").
		SetPasswordHash(passwordHash).
		SetBalance(5).
		SetConcurrency(2).
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	firstSession, err := client.PendingAuthSession.Create().
		SetSessionToken("first-bind-session-token").
		SetIntent("adopt_existing_user_by_email").
		SetProviderType("oidc").
		SetProviderKey("https://issuer.example").
		SetProviderSubject("oidc-bind-first-123").
		SetTargetUserID(existingUser.ID).
		SetResolvedEmail(existingUser.Email).
		SetBrowserSessionKey("first-bind-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"suggested_display_name": "Bound OIDC User",
			"suggested_avatar_url":   "https://cdn.example/bound.png",
		}).
		SetRedirectTo("/profile").
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	firstBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
	firstRecorder := httptest.NewRecorder()
	firstGinCtx, _ := gin.CreateTestContext(firstRecorder)
	firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", firstBody)
	firstReq.Header.Set("Content-Type", "application/json")
	firstReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(firstSession.SessionToken)})
	firstReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("first-bind-browser-session-key")})
	firstGinCtx.Request = firstReq

	handler.BindOIDCOAuthLogin(firstGinCtx)

	require.Equal(t, http.StatusOK, firstRecorder.Code)

	// First bind: grant applied once.
	storedUser, err := client.User.Get(ctx, existingUser.ID)
	require.NoError(t, err)
	require.Equal(t, 17.5, storedUser.Balance)
	require.Equal(t, 5, storedUser.Concurrency)
	require.Zero(t, storedUser.TotalRecharged)
	require.Len(t, defaultSubAssigner.calls, 1)
	require.Equal(t, int64(existingUser.ID), defaultSubAssigner.calls[0].UserID)
	require.Equal(t, int64(101), defaultSubAssigner.calls[0].GroupID)
	require.Equal(t, 30, defaultSubAssigner.calls[0].ValidityDays)
	require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))

	// Second bind of a different OIDC subject to the same user.
	secondSession, err := client.PendingAuthSession.Create().
		SetSessionToken("second-bind-session-token").
		SetIntent("adopt_existing_user_by_email").
		SetProviderType("oidc").
		SetProviderKey("https://issuer.example").
		SetProviderSubject("oidc-bind-second-456").
		SetTargetUserID(existingUser.ID).
		SetResolvedEmail(existingUser.Email).
		SetBrowserSessionKey("second-bind-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"suggested_display_name": "Second OIDC User",
			"suggested_avatar_url":   "https://cdn.example/second.png",
		}).
		SetRedirectTo("/profile").
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	secondBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
	secondRecorder := httptest.NewRecorder()
	secondGinCtx, _ := gin.CreateTestContext(secondRecorder)
	secondReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", secondBody)
	secondReq.Header.Set("Content-Type", "application/json")
	secondReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(secondSession.SessionToken)})
	secondReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("second-bind-browser-session-key")})
	secondGinCtx.Request = secondReq

	handler.BindOIDCOAuthLogin(secondGinCtx)

	require.Equal(t, http.StatusOK, secondRecorder.Code)

	// Second bind: everything identical to after the first bind (no re-grant).
	storedUser, err = client.User.Get(ctx, existingUser.ID)
	require.NoError(t, err)
	require.Equal(t, 17.5, storedUser.Balance)
	require.Equal(t, 5, storedUser.Concurrency)
	require.Zero(t, storedUser.TotalRecharged)
	require.Len(t, defaultSubAssigner.calls, 1)
	require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
}

// TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp verifies behavior
// when the target user has TOTP enabled. A correct password does not complete
// the bind. Instead the handler:
//   - returns a 2FA challenge (requires_2fa, masked email, temp_token);
//   - stores a login session in the TOTP cache that references this pending
//     session token and browser key;
//   - creates no identity and leaves the pending session unconsumed.
func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) {
	totpCache := &oauthPendingFlowTotpCacheStub{}
	handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
		settingValues: map[string]string{
			service.SettingKeyTotpEnabled: "true",
		},
		totpCache:     totpCache,
		totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
	})
	ctx := context.Background()

	passwordHash, err := handler.authService.HashPassword("secret-123")
	require.NoError(t, err)
	totpEnabledAt := time.Now().UTC().Add(-time.Hour)
	// NOTE(review): the raw base32 secret is stored as the "encrypted" value;
	// this presumably relies on oauthPendingFlowTotpEncryptorStub being a
	// pass-through — confirm against the stub.
	secret := "JBSWY3DPEHPK3PXP"
	existingUser, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner-user").
		SetPasswordHash(passwordHash).
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		SetTotpEnabled(true).
		SetTotpSecretEncrypted(secret).
		SetTotpEnabledAt(totpEnabledAt).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("bind-login-2fa-session-token").
		SetIntent("adopt_existing_user_by_email").
		SetProviderType("oidc").
		SetProviderKey("https://issuer.example").
		SetProviderSubject("oidc-bind-2fa-123").
		SetTargetUserID(existingUser.ID).
		SetResolvedEmail(existingUser.Email).
		SetBrowserSessionKey("bind-login-2fa-browser-session-key").
		SetUpstreamIdentityClaims(map[string]any{
			"suggested_display_name": "Bound OIDC User",
			"suggested_avatar_url":   "https://cdn.example/bound.png",
		}).
		SetRedirectTo("/profile").
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
	req.Header.Set("Content-Type", "application/json")
	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-2fa-browser-session-key")})
	ginCtx.Request = req

	handler.BindOIDCOAuthLogin(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)
	data := decodeJSONResponseData(t, recorder)
	require.Equal(t, true, data["requires_2fa"])
	require.Equal(t, "o***r@example.com", data["user_email_masked"])
	tempToken, ok := data["temp_token"].(string)
	require.True(t, ok)
	require.NotEmpty(t, tempToken)

	// The temp token must resolve to a cached login session that remembers
	// which pending OAuth bind to resume after the 2FA step.
	loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
	require.NoError(t, err)
	require.NotNil(t, loginSession)
	require.NotNil(t, loginSession.PendingOAuthBind)
	require.Equal(t, session.SessionToken, loginSession.PendingOAuthBind.PendingSessionToken)
	require.Equal(t, session.BrowserSessionKey, loginSession.PendingOAuthBind.BrowserSessionKey)

	identityCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("oidc"),
			authidentity.ProviderKeyEQ("https://issuer.example"),
			authidentity.ProviderSubjectEQ("oidc-bind-2fa-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, identityCount)

	storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
	require.NoError(t, err)
	require.Nil(t, storedSession.ConsumedAt)
}

// TestLogin2FACompletesPendingOAuthBindAndConsumesSession sets up:
//   - a TOTP-enabled user with starting balance 1.5 and concurrency 4;
//   - OIDC first-bind defaults (balance 8, concurrency 2, grant on first bind);
//   - an adopt_existing_user_by_email pending session targeting that user.
func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
	totpCache := &oauthPendingFlowTotpCacheStub{}
	defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
	handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
		settingValues: map[string]string{
			service.SettingKeyTotpEnabled:                           "true",
			service.SettingKeyAuthSourceDefaultOIDCBalance:          "8",
			service.SettingKeyAuthSourceDefaultOIDCConcurrency:      "2",
			service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
		},
		defaultSubAssigner: defaultSubAssigner,
		totpCache:          totpCache,
		totpEncryptor:      oauthPendingFlowTotpEncryptorStub{},
	})
	ctx := context.Background()

	passwordHash, err := handler.authService.HashPassword("secret-123")
	require.NoError(t, err)
	totpEnabledAt := time.Now().UTC().Add(-time.Hour)
	secret := "JBSWY3DPEHPK3PXP"
	existingUser, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner-user").
		SetPasswordHash(passwordHash).
		SetBalance(1.5).
		SetConcurrency(4).
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		SetTotpEnabled(true).
		SetTotpSecretEncrypted(secret).
		SetTotpEnabledAt(totpEnabledAt).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("login-2fa-pending-session-token").
		SetIntent("adopt_existing_user_by_email").
		SetProviderType("oidc").
		SetProviderKey("https://issuer.example").
		SetProviderSubject("oidc-login-2fa-123").
		SetTargetUserID(existingUser.ID).
		SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("login-2fa-browser-session-key"). SetUpstreamIdentityClaims(map[string]any{ "suggested_display_name": "Bound OIDC User", "suggested_avatar_url": "https://cdn.example/bound.png", }). SetRedirectTo("/profile"). SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). Save(ctx) require.NoError(t, err) _, err = client.IdentityAdoptionDecision.Create(). SetPendingAuthSessionID(session.ID). SetAdoptDisplayName(false). SetAdoptAvatar(false). Save(ctx) require.NoError(t, err) tempToken, err := handler.totpService.CreatePendingOAuthBindLoginSession( ctx, existingUser.ID, existingUser.Email, session.SessionToken, session.BrowserSessionKey, ) require.NoError(t, err) code, err := totp.GenerateCode(secret, time.Now().UTC()) require.NoError(t, err) body := bytes.NewBufferString(`{"temp_token":"` + tempToken + `","totp_code":"` + code + `"}`) recorder := httptest.NewRecorder() ginCtx, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/2fa", body) req.Header.Set("Content-Type", "application/json") req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue(session.BrowserSessionKey)}) ginCtx.Request = req handler.Login2FA(ginCtx) require.Equal(t, http.StatusOK, recorder.Code) payload := decodeJSONResponseData(t, recorder) require.NotEmpty(t, payload["access_token"]) require.NotEmpty(t, payload["refresh_token"]) identity, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("oidc"), authidentity.ProviderKeyEQ("https://issuer.example"), authidentity.ProviderSubjectEQ("oidc-login-2fa-123"), ). 
Only(ctx) require.NoError(t, err) require.Equal(t, existingUser.ID, identity.UserID) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) require.NotNil(t, storedSession.ConsumedAt) loginSession, err := totpCache.GetLoginSession(ctx, tempToken) require.NoError(t, err) require.Nil(t, loginSession) storedUser, err := client.User.Get(ctx, existingUser.ID) require.NoError(t, err) require.Equal(t, 9.5, storedUser.Balance) require.Equal(t, 6, storedUser.Concurrency) require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) require.Empty(t, defaultSubAssigner.calls) } func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil) } func newOAuthPendingFlowTestHandlerWithEmailVerification( t *testing.T, invitationEnabled bool, email string, code string, ) (*AuthHandler, *dbent.Client) { t.Helper() cache := &oauthPendingFlowEmailCacheStub{ verificationCodes: map[string]*service.VerificationCodeData{ email: { Code: code, Attempts: 0, CreatedAt: time.Now().UTC(), ExpiresAt: time.Now().UTC().Add(15 * time.Minute), }, }, } return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache) } func newOAuthPendingFlowTestHandlerWithOptions( t *testing.T, invitationEnabled bool, emailVerifyEnabled bool, emailCache service.EmailCache, ) (*AuthHandler, *dbent.Client) { return newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ invitationEnabled: invitationEnabled, emailVerifyEnabled: emailVerifyEnabled, emailCache: emailCache, }) } type oauthPendingFlowTestHandlerOptions struct { invitationEnabled bool emailVerifyEnabled bool emailCache service.EmailCache settingValues map[string]string defaultSubAssigner service.DefaultSubscriptionAssigner totpCache service.TotpCache totpEncryptor service.SecretEncryptor } func 
newOAuthPendingFlowTestHandlerWithDependencies( t *testing.T, options oauthPendingFlowTestHandlerOptions, ) (*AuthHandler, *dbent.Client) { t.Helper() db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") require.NoError(t, err) t.Cleanup(func() { _ = db.Close() }) _, err = db.Exec("PRAGMA foreign_keys = ON") require.NoError(t, err) _, err = db.Exec(` CREATE TABLE IF NOT EXISTS user_provider_default_grants ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, provider_type TEXT NOT NULL, grant_reason TEXT NOT NULL DEFAULT 'first_bind', created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, UNIQUE(user_id, provider_type, grant_reason) )`) require.NoError(t, err) drv := entsql.OpenDB(dialect.SQLite, db) client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) cfg := &config.Config{ JWT: config.JWTConfig{ Secret: "test-secret", ExpireHour: 1, AccessTokenExpireMinutes: 60, RefreshTokenExpireDays: 7, }, Default: config.DefaultConfig{ UserBalance: 0, UserConcurrency: 1, }, } settingValues := map[string]string{ service.SettingKeyRegistrationEnabled: "true", service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled), service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), } for key, value := range options.settingValues { settingValues[key] = value } settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg) userRepo := &oauthPendingFlowUserRepo{client: client} var emailService *service.EmailService if options.emailCache != nil { emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ values: map[string]string{ service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), }, }, options.emailCache) } authSvc := service.NewAuthService( client, userRepo, nil, &oauthPendingFlowRefreshTokenCacheStub{}, cfg, settingSvc, emailService, nil, nil, nil, options.defaultSubAssigner, ) 
userSvc := service.NewUserService(userRepo, nil, nil, nil) var totpSvc *service.TotpService if options.totpCache != nil || options.totpEncryptor != nil { totpCache := options.totpCache if totpCache == nil { totpCache = &oauthPendingFlowTotpCacheStub{} } totpEncryptor := options.totpEncryptor if totpEncryptor == nil { totpEncryptor = oauthPendingFlowTotpEncryptorStub{} } totpSvc = service.NewTotpService(userRepo, totpEncryptor, totpCache, settingSvc, nil, nil) } return &AuthHandler{ authService: authSvc, userService: userSvc, settingSvc: settingSvc, totpService: totpSvc, }, client } func boolSettingValue(v bool) string { if v { return "true" } return "false" } func boolPtr(v bool) *bool { return &v } type oauthPendingFlowSettingRepoStub struct { values map[string]string } func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { return nil, service.ErrSettingNotFound } func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { value, ok := s.values[key] if !ok { return "", service.ErrSettingNotFound } return value, nil } func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error { return nil } func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { result := make(map[string]string, len(keys)) for _, key := range keys { if value, ok := s.values[key]; ok { result[key] = value } } return result, nil } func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil } func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) { result := make(map[string]string, len(s.values)) for key, value := range s.values { result[key] = value } return result, nil } func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error { return nil } type oauthPendingFlowRefreshTokenCacheStub struct{} type 
oauthPendingFlowEmailCacheStub struct { verificationCodes map[string]*service.VerificationCodeData } func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) { if s == nil || s.verificationCodes == nil { return nil, nil } return s.verificationCodes[email], nil } func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error { if s.verificationCodes == nil { s.verificationCodes = map[string]*service.VerificationCodeData{} } s.verificationCodes[email] = data return nil } func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error { delete(s.verificationCodes, email) return nil } func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { return nil, nil } func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { return nil } func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { return nil } func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { return nil, nil } func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { return nil } func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { return nil } func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { return false } func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { return nil } func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { return 0, 
nil } func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { return 0, nil } func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { return nil } func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { return nil, service.ErrRefreshTokenNotFound } func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { return nil } func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { return nil } func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { return nil } func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { return nil } func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { return nil } func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { return nil, nil } func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { return nil, nil } func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { return false, nil } func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { t.Helper() var envelope struct { Data map[string]any `json:"data"` } require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope)) return envelope.Data } func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { t.Helper() var payload map[string]any require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) return payload } func countProviderGrantRecords( t *testing.T, client *dbent.Client, 
userID int64, providerType string, grantReason string, ) int { t.Helper() var rows entsql.Rows err := client.Driver().Query( context.Background(), `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`, []any{userID, providerType, grantReason}, &rows, ) require.NoError(t, err) defer rows.Close() require.True(t, rows.Next()) var count int require.NoError(t, rows.Scan(&count)) require.False(t, rows.Next()) return count } type oauthPendingFlowUserRepo struct { client *dbent.Client } func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error { entity, err := r.client.User.Create(). SetEmail(user.Email). SetUsername(user.Username). SetNotes(user.Notes). SetPasswordHash(user.PasswordHash). SetRole(user.Role). SetBalance(user.Balance). SetConcurrency(user.Concurrency). SetStatus(user.Status). SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted). SetTotpEnabled(user.TotpEnabled). SetNillableTotpEnabledAt(user.TotpEnabledAt). SetTotalRecharged(user.TotalRecharged). SetSignupSource(user.SignupSource). SetNillableLastLoginAt(user.LastLoginAt). SetNillableLastActiveAt(user.LastActiveAt). 
Save(ctx) if err != nil { return err } user.ID = entity.ID user.CreatedAt = entity.CreatedAt user.UpdatedAt = entity.UpdatedAt return nil } func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { entity, err := r.client.User.Get(ctx, id) if err != nil { if dbent.IsNotFound(err) { return nil, service.ErrUserNotFound } return nil, err } return oauthPendingFlowServiceUser(entity), nil } func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) if err != nil { if dbent.IsNotFound(err) { return nil, service.ErrUserNotFound } return nil, err } return oauthPendingFlowServiceUser(entity), nil } func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) { panic("unexpected GetFirstAdmin call") } func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error { entity, err := r.client.User.UpdateOneID(user.ID). SetEmail(user.Email). SetUsername(user.Username). SetNotes(user.Notes). SetPasswordHash(user.PasswordHash). SetRole(user.Role). SetBalance(user.Balance). SetConcurrency(user.Concurrency). SetStatus(user.Status). SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted). SetTotpEnabled(user.TotpEnabled). SetNillableTotpEnabledAt(user.TotpEnabledAt). SetTotalRecharged(user.TotalRecharged). SetSignupSource(user.SignupSource). SetNillableLastLoginAt(user.LastLoginAt). SetNillableLastActiveAt(user.LastActiveAt). 
Save(ctx) if err != nil { return err } user.UpdatedAt = entity.UpdatedAt return nil } func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { return r.client.User.DeleteOneID(id).Exec(ctx) } func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { return nil, nil } func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { panic("unexpected UpsertUserAvatar call") } func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error { return nil } func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { panic("unexpected List call") } func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error { panic("unexpected UpdateBalance call") } func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error { panic("unexpected DeductBalance call") } func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected UpdateConcurrency call") } func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx) return count > 0, err } func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { panic("unexpected RemoveGroupFromAllowedGroups call") } func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { panic("unexpected AddGroupToAllowedGroups call") } func (r *oauthPendingFlowUserRepo) 
RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { panic("unexpected RemoveGroupFromUserAllowedGroups call") } func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { identities, err := r.client.AuthIdentity.Query(). Where(authidentity.UserIDEQ(userID)). All(ctx) if err != nil { return nil, err } records := make([]service.UserAuthIdentityRecord, 0, len(identities)) for _, identity := range identities { if identity == nil { continue } records = append(records, service.UserAuthIdentityRecord{ ProviderType: identity.ProviderType, ProviderKey: identity.ProviderKey, ProviderSubject: identity.ProviderSubject, VerifiedAt: identity.VerifiedAt, Issuer: identity.Issuer, Metadata: identity.Metadata, CreatedAt: identity.CreatedAt, UpdatedAt: identity.UpdatedAt, }) } return records, nil } func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { update := r.client.User.UpdateOneID(userID) if encryptedSecret == nil { update = update.ClearTotpSecretEncrypted() } else { update = update.SetTotpSecretEncrypted(*encryptedSecret) } return update.Exec(ctx) } func (r *oauthPendingFlowUserRepo) EnableTotp(ctx context.Context, userID int64) error { return r.client.User.UpdateOneID(userID). SetTotpEnabled(true). SetTotpEnabledAt(time.Now().UTC()). Exec(ctx) } func (r *oauthPendingFlowUserRepo) DisableTotp(ctx context.Context, userID int64) error { return r.client.User.UpdateOneID(userID). SetTotpEnabled(false). ClearTotpSecretEncrypted(). ClearTotpEnabledAt(). 
Exec(ctx) } func oauthPendingFlowServiceUser(entity *dbent.User) *service.User { if entity == nil { return nil } return &service.User{ ID: entity.ID, Email: entity.Email, Username: entity.Username, Notes: entity.Notes, PasswordHash: entity.PasswordHash, Role: entity.Role, Balance: entity.Balance, Concurrency: entity.Concurrency, Status: entity.Status, SignupSource: entity.SignupSource, LastLoginAt: entity.LastLoginAt, LastActiveAt: entity.LastActiveAt, TotpSecretEncrypted: entity.TotpSecretEncrypted, TotpEnabled: entity.TotpEnabled, TotpEnabledAt: entity.TotpEnabledAt, TotalRecharged: entity.TotalRecharged, CreatedAt: entity.CreatedAt, UpdatedAt: entity.UpdatedAt, } } type oauthPendingFlowDefaultSubAssignerStub struct { calls []service.AssignSubscriptionInput } func (s *oauthPendingFlowDefaultSubAssignerStub) AssignOrExtendSubscription( _ context.Context, input *service.AssignSubscriptionInput, ) (*service.UserSubscription, bool, error) { if input != nil { s.calls = append(s.calls, *input) } return nil, false, nil } type oauthPendingFlowTotpCacheStub struct { setupSessions map[int64]*service.TotpSetupSession loginSessions map[string]*service.TotpLoginSession verifyAttempts map[int64]int } func (s *oauthPendingFlowTotpCacheStub) GetSetupSession(_ context.Context, userID int64) (*service.TotpSetupSession, error) { if s == nil || s.setupSessions == nil { return nil, nil } return s.setupSessions[userID], nil } func (s *oauthPendingFlowTotpCacheStub) SetSetupSession(_ context.Context, userID int64, session *service.TotpSetupSession, _ time.Duration) error { if s.setupSessions == nil { s.setupSessions = map[int64]*service.TotpSetupSession{} } s.setupSessions[userID] = session return nil } func (s *oauthPendingFlowTotpCacheStub) DeleteSetupSession(_ context.Context, userID int64) error { delete(s.setupSessions, userID) return nil } func (s *oauthPendingFlowTotpCacheStub) GetLoginSession(_ context.Context, tempToken string) (*service.TotpLoginSession, error) { if s == nil 
|| s.loginSessions == nil { return nil, nil } return s.loginSessions[tempToken], nil } func (s *oauthPendingFlowTotpCacheStub) SetLoginSession(_ context.Context, tempToken string, session *service.TotpLoginSession, _ time.Duration) error { if s.loginSessions == nil { s.loginSessions = map[string]*service.TotpLoginSession{} } s.loginSessions[tempToken] = session return nil } func (s *oauthPendingFlowTotpCacheStub) DeleteLoginSession(_ context.Context, tempToken string) error { delete(s.loginSessions, tempToken) return nil } func (s *oauthPendingFlowTotpCacheStub) IncrementVerifyAttempts(_ context.Context, userID int64) (int, error) { if s.verifyAttempts == nil { s.verifyAttempts = map[int64]int{} } s.verifyAttempts[userID]++ return s.verifyAttempts[userID], nil } func (s *oauthPendingFlowTotpCacheStub) GetVerifyAttempts(_ context.Context, userID int64) (int, error) { if s == nil || s.verifyAttempts == nil { return 0, nil } return s.verifyAttempts[userID], nil } func (s *oauthPendingFlowTotpCacheStub) ClearVerifyAttempts(_ context.Context, userID int64) error { delete(s.verifyAttempts, userID) return nil } type oauthPendingFlowTotpEncryptorStub struct{} func (oauthPendingFlowTotpEncryptorStub) Encrypt(plaintext string) (string, error) { return plaintext, nil } func (oauthPendingFlowTotpEncryptorStub) Decrypt(ciphertext string) (string, error) { return ciphertext, nil }