fix(unit): restore secure oidc defaults and wechat alias reuse
This commit is contained in:
@@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
||||
if providerKey == "" || providerSubject == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
|
||||
}
|
||||
canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
|
||||
compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
|
||||
|
||||
var issuer *string
|
||||
if input.Issuer != nil {
|
||||
@@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
identity, err := tx.AuthIdentity.Query().
|
||||
identityRecords, err := tx.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(providerType),
|
||||
authidentity.ProviderKeyEQ(providerKey),
|
||||
authidentity.ProviderKeyIn(compatibleProviderKeys...),
|
||||
authidentity.ProviderSubjectEQ(providerSubject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||
}
|
||||
if identity != nil && identity.UserID != userID {
|
||||
if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
|
||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
}
|
||||
identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
|
||||
|
||||
if identity == nil {
|
||||
create := tx.AuthIdentity.Create().
|
||||
SetUserID(userID).
|
||||
SetProviderType(providerType).
|
||||
SetProviderKey(providerKey).
|
||||
SetProviderKey(canonicalProviderKey).
|
||||
SetProviderSubject(providerSubject).
|
||||
SetVerifiedAt(verifiedAt)
|
||||
if issuer != nil {
|
||||
@@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
||||
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
|
||||
}
|
||||
} else {
|
||||
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
|
||||
update := tx.AuthIdentity.UpdateOneID(identity.ID).
|
||||
SetVerifiedAt(verifiedAt).
|
||||
SetProviderKey(canonicalProviderKey)
|
||||
if issuer != nil {
|
||||
update = update.SetIssuer(*issuer)
|
||||
}
|
||||
@@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
||||
|
||||
var channel *dbent.AuthIdentityChannel
|
||||
if channelInput != nil {
|
||||
channel, err = tx.AuthIdentityChannel.Query().
|
||||
channelRecords, err := tx.AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ(providerType),
|
||||
authidentitychannel.ProviderKeyEQ(providerKey),
|
||||
authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
|
||||
authidentitychannel.ChannelEQ(channelInput.Channel),
|
||||
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
|
||||
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
|
||||
).
|
||||
WithIdentity().
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
|
||||
}
|
||||
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
|
||||
if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
|
||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
|
||||
}
|
||||
channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
|
||||
if channel == nil {
|
||||
create := tx.AuthIdentityChannel.Create().
|
||||
SetIdentityID(identity.ID).
|
||||
SetProviderType(providerType).
|
||||
SetProviderKey(providerKey).
|
||||
SetProviderKey(canonicalProviderKey).
|
||||
SetChannel(channelInput.Channel).
|
||||
SetChannelAppID(channelInput.ChannelAppID).
|
||||
SetChannelSubject(channelInput.ChannelSubject)
|
||||
@@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
||||
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
|
||||
}
|
||||
} else {
|
||||
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
|
||||
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
|
||||
SetIdentityID(identity.ID).
|
||||
SetProviderKey(canonicalProviderKey)
|
||||
if channelInput.Metadata != nil {
|
||||
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
|
||||
}
|
||||
@@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
||||
return buildAdminBoundAuthIdentity(identity, channel), nil
|
||||
}
|
||||
|
||||
func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
providerKey = strings.TrimSpace(providerKey)
|
||||
if providerKey == "" {
|
||||
return []string{providerKey}
|
||||
}
|
||||
if providerType != "wechat" {
|
||||
return []string{providerKey}
|
||||
}
|
||||
|
||||
keys := []string{providerKey}
|
||||
if !strings.EqualFold(providerKey, "wechat-main") {
|
||||
keys = append(keys, "wechat-main")
|
||||
}
|
||||
if !strings.EqualFold(providerKey, "wechat") {
|
||||
keys = append(keys, "wechat")
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
existingKey = strings.TrimSpace(existingKey)
|
||||
requestedKey = strings.TrimSpace(requestedKey)
|
||||
if providerType != "wechat" {
|
||||
if requestedKey != "" {
|
||||
return requestedKey
|
||||
}
|
||||
return existingKey
|
||||
}
|
||||
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
|
||||
return "wechat-main"
|
||||
}
|
||||
if requestedKey != "" {
|
||||
return requestedKey
|
||||
}
|
||||
return existingKey
|
||||
}
|
||||
|
||||
func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
providerKey = strings.TrimSpace(providerKey)
|
||||
if providerType != "wechat" {
|
||||
return 0
|
||||
}
|
||||
switch {
|
||||
case strings.EqualFold(providerKey, "wechat-main"):
|
||||
return 0
|
||||
case strings.EqualFold(providerKey, "wechat"):
|
||||
return 2
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
|
||||
var selected *dbent.AuthIdentity
|
||||
for _, record := range records {
|
||||
if record.UserID != userID {
|
||||
continue
|
||||
}
|
||||
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||
selected = record
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
|
||||
for _, record := range records {
|
||||
if record.UserID != userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
|
||||
var selected *dbent.AuthIdentityChannel
|
||||
for _, record := range records {
|
||||
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
|
||||
continue
|
||||
}
|
||||
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||
selected = record
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
|
||||
for _, record := range records {
|
||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
|
||||
if input == nil {
|
||||
return nil
|
||||
|
||||
@@ -818,14 +818,14 @@ func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
|
||||
if base.UsePKCEExplicit {
|
||||
return base.UsePKCE
|
||||
}
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
|
||||
if base.ValidateIDTokenExplicit {
|
||||
return base.ValidateIDToken
|
||||
}
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
|
||||
@@ -133,7 +133,7 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
|
||||
require.True(t, got.OIDCConnectValidateIDToken)
|
||||
}
|
||||
|
||||
func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
|
||||
func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) {
|
||||
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
|
||||
OIDC: config.OIDCConnectConfig{
|
||||
UsePKCE: true,
|
||||
@@ -145,8 +145,8 @@ func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettin
|
||||
SettingKeyOIDCConnectEnabled: "true",
|
||||
})
|
||||
|
||||
require.False(t, got.OIDCConnectUsePKCE)
|
||||
require.False(t, got.OIDCConnectValidateIDToken)
|
||||
require.True(t, got.OIDCConnectUsePKCE)
|
||||
require.True(t, got.OIDCConnectValidateIDToken)
|
||||
}
|
||||
|
||||
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
|
||||
@@ -216,7 +216,7 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
|
||||
require.True(t, got.ValidateIDToken)
|
||||
}
|
||||
|
||||
func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
|
||||
func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OIDC: config.OIDCConnectConfig{
|
||||
Enabled: true,
|
||||
@@ -246,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsM
|
||||
|
||||
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.False(t, got.UsePKCE)
|
||||
require.False(t, got.ValidateIDToken)
|
||||
require.True(t, got.UsePKCE)
|
||||
require.True(t, got.ValidateIDToken)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user