fix: finalize oauth identity bindings
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
@@ -309,6 +310,14 @@ func cloneOAuthMetadata(values map[string]any) map[string]any {
|
|||||||
return cloned
|
return cloned
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any {
|
||||||
|
merged := cloneOAuthMetadata(base)
|
||||||
|
for key, value := range overlay {
|
||||||
|
merged[key] = value
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeAdoptedOAuthDisplayName(value string) string {
|
func normalizeAdoptedOAuthDisplayName(value string) string {
|
||||||
value = strings.TrimSpace(value)
|
value = strings.TrimSpace(value)
|
||||||
if len([]rune(value)) > 100 {
|
if len([]rune(value)) > 100 {
|
||||||
@@ -558,6 +567,10 @@ func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
|
func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
|
||||||
|
if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") {
|
||||||
|
return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID)
|
||||||
|
}
|
||||||
|
|
||||||
client := tx.Client()
|
client := tx.Client()
|
||||||
identity, err := client.AuthIdentity.Query().
|
identity, err := client.AuthIdentity.Query().
|
||||||
Where(
|
Where(
|
||||||
@@ -588,14 +601,149 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio
|
|||||||
return create.Save(ctx)
|
return create.Save(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
|
||||||
|
client := tx.Client()
|
||||||
|
providerType := strings.TrimSpace(session.ProviderType)
|
||||||
|
providerKey := strings.TrimSpace(session.ProviderKey)
|
||||||
|
providerSubject := strings.TrimSpace(session.ProviderSubject)
|
||||||
|
channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel"))
|
||||||
|
channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id"))
|
||||||
|
channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject"))
|
||||||
|
metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims)
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ(providerType),
|
||||||
|
authidentity.ProviderKeyEQ(providerKey),
|
||||||
|
authidentity.ProviderSubjectEQ(providerSubject),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil && !dbent.IsNotFound(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if identity != nil && identity.UserID != userID {
|
||||||
|
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||||
|
}
|
||||||
|
|
||||||
|
var legacyOpenIDIdentity *dbent.AuthIdentity
|
||||||
|
if channelSubject != "" && channelSubject != providerSubject {
|
||||||
|
legacyOpenIDIdentity, err = client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ(providerType),
|
||||||
|
authidentity.ProviderKeyEQ(providerKey),
|
||||||
|
authidentity.ProviderSubjectEQ(channelSubject),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil && !dbent.IsNotFound(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID {
|
||||||
|
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case identity != nil:
|
||||||
|
update := client.AuthIdentity.UpdateOneID(identity.ID).
|
||||||
|
SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
|
||||||
|
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||||
|
update = update.SetIssuer(strings.TrimSpace(*issuer))
|
||||||
|
}
|
||||||
|
identity, err = update.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case legacyOpenIDIdentity != nil:
|
||||||
|
update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID).
|
||||||
|
SetProviderSubject(providerSubject).
|
||||||
|
SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata))
|
||||||
|
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||||
|
update = update.SetIssuer(strings.TrimSpace(*issuer))
|
||||||
|
}
|
||||||
|
identity, err = update.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
create := client.AuthIdentity.Create().
|
||||||
|
SetUserID(userID).
|
||||||
|
SetProviderType(providerType).
|
||||||
|
SetProviderKey(providerKey).
|
||||||
|
SetProviderSubject(providerSubject).
|
||||||
|
SetMetadata(metadata)
|
||||||
|
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||||
|
create = create.SetIssuer(strings.TrimSpace(*issuer))
|
||||||
|
}
|
||||||
|
identity, err = create.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if channel == "" || channelAppID == "" || channelSubject == "" {
|
||||||
|
return identity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
channelRecord, err := client.AuthIdentityChannel.Query().
|
||||||
|
Where(
|
||||||
|
authidentitychannel.ProviderTypeEQ(providerType),
|
||||||
|
authidentitychannel.ProviderKeyEQ(providerKey),
|
||||||
|
authidentitychannel.ChannelEQ(channel),
|
||||||
|
authidentitychannel.ChannelAppIDEQ(channelAppID),
|
||||||
|
authidentitychannel.ChannelSubjectEQ(channelSubject),
|
||||||
|
).
|
||||||
|
WithIdentity().
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil && !dbent.IsNotFound(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID {
|
||||||
|
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
|
||||||
|
}
|
||||||
|
|
||||||
|
channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata)
|
||||||
|
if channelRecord == nil {
|
||||||
|
if _, err := client.AuthIdentityChannel.Create().
|
||||||
|
SetIdentityID(identity.ID).
|
||||||
|
SetProviderType(providerType).
|
||||||
|
SetProviderKey(providerKey).
|
||||||
|
SetChannel(channel).
|
||||||
|
SetChannelAppID(channelAppID).
|
||||||
|
SetChannelSubject(channelSubject).
|
||||||
|
SetMetadata(channelMetadata).
|
||||||
|
Save(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return identity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
|
||||||
|
SetIdentityID(identity.ID).
|
||||||
|
SetMetadata(channelMetadata).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return identity, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
|
||||||
|
if channel == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
return cloneOAuthMetadata(channel.Metadata)
|
||||||
|
}
|
||||||
|
|
||||||
func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool {
|
func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool {
|
||||||
if session == nil || decision == nil {
|
if session == nil || decision == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user") {
|
switch strings.ToLower(strings.TrimSpace(session.Intent)) {
|
||||||
|
case "bind_current_user", "login", "adopt_existing_user_by_email":
|
||||||
return true
|
return true
|
||||||
|
default:
|
||||||
|
return decision.AdoptDisplayName || decision.AdoptAvatar
|
||||||
}
|
}
|
||||||
return decision.AdoptDisplayName || decision.AdoptAvatar
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyPendingOAuthBinding(
|
func applyPendingOAuthBinding(
|
||||||
|
|||||||
@@ -372,7 +372,7 @@ func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testi
|
|||||||
require.Nil(t, storedSession.ConsumedAt)
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *testing.T) {
|
func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) {
|
||||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
@@ -420,21 +420,22 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *tes
|
|||||||
|
|
||||||
require.Equal(t, http.StatusOK, recorder.Code)
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
identityCount, err := client.AuthIdentity.Query().
|
identity, err := client.AuthIdentity.Query().
|
||||||
Where(
|
Where(
|
||||||
authidentity.ProviderTypeEQ("linuxdo"),
|
authidentity.ProviderTypeEQ("linuxdo"),
|
||||||
authidentity.ProviderKeyEQ("linuxdo"),
|
authidentity.ProviderKeyEQ("linuxdo"),
|
||||||
authidentity.ProviderSubjectEQ("login-false-123"),
|
authidentity.ProviderSubjectEQ("login-false-123"),
|
||||||
).
|
).
|
||||||
Count(ctx)
|
Only(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Zero(t, identityCount)
|
require.Equal(t, userEntity.ID, identity.UserID)
|
||||||
|
|
||||||
decision, err := client.IdentityAdoptionDecision.Query().
|
decision, err := client.IdentityAdoptionDecision.Query().
|
||||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Nil(t, decision.IdentityID)
|
require.NotNil(t, decision.IdentityID)
|
||||||
|
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||||
require.False(t, decision.AdoptDisplayName)
|
require.False(t, decision.AdoptDisplayName)
|
||||||
require.False(t, decision.AdoptAvatar)
|
require.False(t, decision.AdoptAvatar)
|
||||||
|
|
||||||
|
|||||||
@@ -242,7 +242,18 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
|
|||||||
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
|
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if existingIdentityUser == nil {
|
||||||
|
existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid)
|
||||||
|
if err != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
if existingIdentityUser != nil {
|
if existingIdentityUser != nil {
|
||||||
|
if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
|
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
|
||||||
@@ -511,6 +522,91 @@ func (h *AuthHandler) ensureWeChatBindOwnership(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) findWeChatUserByLegacyOpenID(
|
||||||
|
ctx context.Context,
|
||||||
|
identity service.PendingAuthIdentityKey,
|
||||||
|
cfg wechatOAuthConfig,
|
||||||
|
openid string,
|
||||||
|
) (*dbent.User, error) {
|
||||||
|
client := h.entClient()
|
||||||
|
if client == nil {
|
||||||
|
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||||
|
}
|
||||||
|
|
||||||
|
openid = strings.TrimSpace(openid)
|
||||||
|
channel := strings.TrimSpace(cfg.mode)
|
||||||
|
channelAppID := strings.TrimSpace(cfg.appID)
|
||||||
|
if openid != "" && channel != "" && channelAppID != "" {
|
||||||
|
record, err := client.AuthIdentityChannel.Query().
|
||||||
|
Where(
|
||||||
|
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
|
||||||
|
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
|
||||||
|
authidentitychannel.ChannelEQ(channel),
|
||||||
|
authidentitychannel.ChannelAppIDEQ(channelAppID),
|
||||||
|
authidentitychannel.ChannelSubjectEQ(openid),
|
||||||
|
).
|
||||||
|
WithIdentity(func(q *dbent.AuthIdentityQuery) {
|
||||||
|
q.WithUser()
|
||||||
|
}).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil && !dbent.IsNotFound(err) {
|
||||||
|
return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
|
||||||
|
}
|
||||||
|
if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil {
|
||||||
|
return record.Edges.Identity.Edges.User, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if openid == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
|
||||||
|
authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
|
||||||
|
authidentity.ProviderSubjectEQ(openid),
|
||||||
|
).
|
||||||
|
WithUser().
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
|
}
|
||||||
|
return record.Edges.User, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding(
|
||||||
|
ctx context.Context,
|
||||||
|
userID int64,
|
||||||
|
identity service.PendingAuthIdentityKey,
|
||||||
|
upstreamClaims map[string]any,
|
||||||
|
) error {
|
||||||
|
client := h.entClient()
|
||||||
|
if client == nil {
|
||||||
|
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := client.Tx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
_, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{
|
||||||
|
ProviderType: strings.TrimSpace(identity.ProviderType),
|
||||||
|
ProviderKey: strings.TrimSpace(identity.ProviderKey),
|
||||||
|
ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
|
||||||
|
UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims),
|
||||||
|
}, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
|
func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
|
||||||
mode, err := resolveWeChatOAuthMode(rawMode, c)
|
mode, err := resolveWeChatOAuthMode(rawMode, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||||
@@ -563,6 +564,19 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
|
|||||||
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
|
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
|
||||||
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
|
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
|
||||||
|
|
||||||
|
channel, err := client.AuthIdentityChannel.Query().
|
||||||
|
Where(
|
||||||
|
authidentitychannel.ProviderTypeEQ("wechat"),
|
||||||
|
authidentitychannel.ProviderKeyEQ("wechat-main"),
|
||||||
|
authidentitychannel.ChannelEQ("open"),
|
||||||
|
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
|
||||||
|
authidentitychannel.ChannelSubjectEQ("openid-123"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, identity.ID, channel.IdentityID)
|
||||||
|
require.Equal(t, "union-456", channel.Metadata["unionid"])
|
||||||
|
|
||||||
decision, err := client.IdentityAdoptionDecision.Query().
|
decision, err := client.IdentityAdoptionDecision.Query().
|
||||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
@@ -579,6 +593,116 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
|
|||||||
require.NotNil(t, consumed.ConsumedAt)
|
require.NotNil(t, consumed.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
|
||||||
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
||||||
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
||||||
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
||||||
|
|
||||||
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||||
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
||||||
|
t.Cleanup(func() {
|
||||||
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
||||||
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
||||||
|
})
|
||||||
|
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
||||||
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
||||||
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
||||||
|
|
||||||
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
legacyUser, err := client.User.Create().
|
||||||
|
SetEmail("legacy@example.com").
|
||||||
|
SetUsername("legacy-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
legacyIdentity, err := client.AuthIdentity.Create().
|
||||||
|
SetUserID(legacyUser.ID).
|
||||||
|
SetProviderType("wechat").
|
||||||
|
SetProviderKey(wechatOAuthProviderKey).
|
||||||
|
SetProviderSubject("openid-123").
|
||||||
|
SetMetadata(map[string]any{"openid": "openid-123"}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
||||||
|
req.Host = "api.example.com"
|
||||||
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
||||||
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
||||||
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
||||||
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.WeChatOAuthCallback(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
|
||||||
|
|
||||||
|
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
||||||
|
require.NotNil(t, sessionCookie)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Query().
|
||||||
|
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, session.TargetUserID)
|
||||||
|
require.Equal(t, legacyUser.ID, *session.TargetUserID)
|
||||||
|
require.Equal(t, legacyUser.Email, session.ResolvedEmail)
|
||||||
|
|
||||||
|
repairedIdentity, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("wechat"),
|
||||||
|
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
|
||||||
|
authidentity.ProviderSubjectEQ("union-456"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
|
||||||
|
require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
|
||||||
|
|
||||||
|
openIDIdentityCount, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("wechat"),
|
||||||
|
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
|
||||||
|
authidentity.ProviderSubjectEQ("openid-123"),
|
||||||
|
).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, openIDIdentityCount)
|
||||||
|
|
||||||
|
channel, err := client.AuthIdentityChannel.Query().
|
||||||
|
Where(
|
||||||
|
authidentitychannel.ProviderTypeEQ("wechat"),
|
||||||
|
authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
|
||||||
|
authidentitychannel.ChannelEQ("open"),
|
||||||
|
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
|
||||||
|
authidentitychannel.ChannelSubjectEQ("openid-123"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, repairedIdentity.ID, channel.IdentityID)
|
||||||
|
}
|
||||||
|
|
||||||
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
|
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user