fix: finalize oauth identity bindings
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@@ -309,6 +310,14 @@ func cloneOAuthMetadata(values map[string]any) map[string]any {
|
||||
return cloned
|
||||
}
|
||||
|
||||
func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any {
|
||||
merged := cloneOAuthMetadata(base)
|
||||
for key, value := range overlay {
|
||||
merged[key] = value
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func normalizeAdoptedOAuthDisplayName(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if len([]rune(value)) > 100 {
|
||||
@@ -558,6 +567,10 @@ func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
|
||||
}
|
||||
|
||||
func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
|
||||
if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") {
|
||||
return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID)
|
||||
}
|
||||
|
||||
client := tx.Client()
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
@@ -588,14 +601,149 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio
|
||||
return create.Save(ctx)
|
||||
}
|
||||
|
||||
func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
|
||||
client := tx.Client()
|
||||
providerType := strings.TrimSpace(session.ProviderType)
|
||||
providerKey := strings.TrimSpace(session.ProviderKey)
|
||||
providerSubject := strings.TrimSpace(session.ProviderSubject)
|
||||
channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel"))
|
||||
channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id"))
|
||||
channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject"))
|
||||
metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(providerType),
|
||||
authidentity.ProviderKeyEQ(providerKey),
|
||||
authidentity.ProviderSubjectEQ(providerSubject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if identity != nil && identity.UserID != userID {
|
||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
}
|
||||
|
||||
var legacyOpenIDIdentity *dbent.AuthIdentity
|
||||
if channelSubject != "" && channelSubject != providerSubject {
|
||||
legacyOpenIDIdentity, err = client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(providerType),
|
||||
authidentity.ProviderKeyEQ(providerKey),
|
||||
authidentity.ProviderSubjectEQ(channelSubject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID {
|
||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case identity != nil:
|
||||
update := client.AuthIdentity.UpdateOneID(identity.ID).
|
||||
SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
|
||||
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||
update = update.SetIssuer(strings.TrimSpace(*issuer))
|
||||
}
|
||||
identity, err = update.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case legacyOpenIDIdentity != nil:
|
||||
update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID).
|
||||
SetProviderSubject(providerSubject).
|
||||
SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata))
|
||||
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||
update = update.SetIssuer(strings.TrimSpace(*issuer))
|
||||
}
|
||||
identity, err = update.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
create := client.AuthIdentity.Create().
|
||||
SetUserID(userID).
|
||||
SetProviderType(providerType).
|
||||
SetProviderKey(providerKey).
|
||||
SetProviderSubject(providerSubject).
|
||||
SetMetadata(metadata)
|
||||
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||
create = create.SetIssuer(strings.TrimSpace(*issuer))
|
||||
}
|
||||
identity, err = create.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if channel == "" || channelAppID == "" || channelSubject == "" {
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
channelRecord, err := client.AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ(providerType),
|
||||
authidentitychannel.ProviderKeyEQ(providerKey),
|
||||
authidentitychannel.ChannelEQ(channel),
|
||||
authidentitychannel.ChannelAppIDEQ(channelAppID),
|
||||
authidentitychannel.ChannelSubjectEQ(channelSubject),
|
||||
).
|
||||
WithIdentity().
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID {
|
||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
|
||||
}
|
||||
|
||||
channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata)
|
||||
if channelRecord == nil {
|
||||
if _, err := client.AuthIdentityChannel.Create().
|
||||
SetIdentityID(identity.ID).
|
||||
SetProviderType(providerType).
|
||||
SetProviderKey(providerKey).
|
||||
SetChannel(channel).
|
||||
SetChannelAppID(channelAppID).
|
||||
SetChannelSubject(channelSubject).
|
||||
SetMetadata(channelMetadata).
|
||||
Save(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
_, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
|
||||
SetIdentityID(identity.ID).
|
||||
SetMetadata(channelMetadata).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
|
||||
if channel == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
return cloneOAuthMetadata(channel.Metadata)
|
||||
}
|
||||
|
||||
func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool {
|
||||
if session == nil || decision == nil {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user") {
|
||||
switch strings.ToLower(strings.TrimSpace(session.Intent)) {
|
||||
case "bind_current_user", "login", "adopt_existing_user_by_email":
|
||||
return true
|
||||
default:
|
||||
return decision.AdoptDisplayName || decision.AdoptAvatar
|
||||
}
|
||||
return decision.AdoptDisplayName || decision.AdoptAvatar
|
||||
}
|
||||
|
||||
func applyPendingOAuthBinding(
|
||||
|
||||
Reference in New Issue
Block a user