fix(unit): restore secure oidc defaults and wechat alias reuse
This commit is contained in:
@@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
|||||||
if providerKey == "" || providerSubject == "" {
|
if providerKey == "" || providerSubject == "" {
|
||||||
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
|
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
|
||||||
}
|
}
|
||||||
|
canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
|
||||||
|
compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
|
||||||
|
|
||||||
var issuer *string
|
var issuer *string
|
||||||
if input.Issuer != nil {
|
if input.Issuer != nil {
|
||||||
@@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
|||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
identity, err := tx.AuthIdentity.Query().
|
identityRecords, err := tx.AuthIdentity.Query().
|
||||||
Where(
|
Where(
|
||||||
authidentity.ProviderTypeEQ(providerType),
|
authidentity.ProviderTypeEQ(providerType),
|
||||||
authidentity.ProviderKeyEQ(providerKey),
|
authidentity.ProviderKeyIn(compatibleProviderKeys...),
|
||||||
authidentity.ProviderSubjectEQ(providerSubject),
|
authidentity.ProviderSubjectEQ(providerSubject),
|
||||||
).
|
).
|
||||||
Only(ctx)
|
All(ctx)
|
||||||
if err != nil && !dbent.IsNotFound(err) {
|
if err != nil {
|
||||||
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
}
|
}
|
||||||
if identity != nil && identity.UserID != userID {
|
if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
|
||||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||||
}
|
}
|
||||||
|
identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
|
||||||
|
|
||||||
if identity == nil {
|
if identity == nil {
|
||||||
create := tx.AuthIdentity.Create().
|
create := tx.AuthIdentity.Create().
|
||||||
SetUserID(userID).
|
SetUserID(userID).
|
||||||
SetProviderType(providerType).
|
SetProviderType(providerType).
|
||||||
SetProviderKey(providerKey).
|
SetProviderKey(canonicalProviderKey).
|
||||||
SetProviderSubject(providerSubject).
|
SetProviderSubject(providerSubject).
|
||||||
SetVerifiedAt(verifiedAt)
|
SetVerifiedAt(verifiedAt)
|
||||||
if issuer != nil {
|
if issuer != nil {
|
||||||
@@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
|||||||
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
|
update := tx.AuthIdentity.UpdateOneID(identity.ID).
|
||||||
|
SetVerifiedAt(verifiedAt).
|
||||||
|
SetProviderKey(canonicalProviderKey)
|
||||||
if issuer != nil {
|
if issuer != nil {
|
||||||
update = update.SetIssuer(*issuer)
|
update = update.SetIssuer(*issuer)
|
||||||
}
|
}
|
||||||
@@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
|||||||
|
|
||||||
var channel *dbent.AuthIdentityChannel
|
var channel *dbent.AuthIdentityChannel
|
||||||
if channelInput != nil {
|
if channelInput != nil {
|
||||||
channel, err = tx.AuthIdentityChannel.Query().
|
channelRecords, err := tx.AuthIdentityChannel.Query().
|
||||||
Where(
|
Where(
|
||||||
authidentitychannel.ProviderTypeEQ(providerType),
|
authidentitychannel.ProviderTypeEQ(providerType),
|
||||||
authidentitychannel.ProviderKeyEQ(providerKey),
|
authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
|
||||||
authidentitychannel.ChannelEQ(channelInput.Channel),
|
authidentitychannel.ChannelEQ(channelInput.Channel),
|
||||||
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
|
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
|
||||||
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
|
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
|
||||||
).
|
).
|
||||||
WithIdentity().
|
WithIdentity().
|
||||||
Only(ctx)
|
All(ctx)
|
||||||
if err != nil && !dbent.IsNotFound(err) {
|
if err != nil {
|
||||||
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
|
||||||
}
|
}
|
||||||
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
|
if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
|
||||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
|
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
|
||||||
}
|
}
|
||||||
|
channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
create := tx.AuthIdentityChannel.Create().
|
create := tx.AuthIdentityChannel.Create().
|
||||||
SetIdentityID(identity.ID).
|
SetIdentityID(identity.ID).
|
||||||
SetProviderType(providerType).
|
SetProviderType(providerType).
|
||||||
SetProviderKey(providerKey).
|
SetProviderKey(canonicalProviderKey).
|
||||||
SetChannel(channelInput.Channel).
|
SetChannel(channelInput.Channel).
|
||||||
SetChannelAppID(channelInput.ChannelAppID).
|
SetChannelAppID(channelInput.ChannelAppID).
|
||||||
SetChannelSubject(channelInput.ChannelSubject)
|
SetChannelSubject(channelInput.ChannelSubject)
|
||||||
@@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
|||||||
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
|
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
|
||||||
|
SetIdentityID(identity.ID).
|
||||||
|
SetProviderKey(canonicalProviderKey)
|
||||||
if channelInput.Metadata != nil {
|
if channelInput.Metadata != nil {
|
||||||
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
|
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
|
||||||
}
|
}
|
||||||
@@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
|||||||
return buildAdminBoundAuthIdentity(identity, channel), nil
|
return buildAdminBoundAuthIdentity(identity, channel), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
|
||||||
|
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||||
|
providerKey = strings.TrimSpace(providerKey)
|
||||||
|
if providerKey == "" {
|
||||||
|
return []string{providerKey}
|
||||||
|
}
|
||||||
|
if providerType != "wechat" {
|
||||||
|
return []string{providerKey}
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := []string{providerKey}
|
||||||
|
if !strings.EqualFold(providerKey, "wechat-main") {
|
||||||
|
keys = append(keys, "wechat-main")
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(providerKey, "wechat") {
|
||||||
|
keys = append(keys, "wechat")
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
|
||||||
|
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||||
|
existingKey = strings.TrimSpace(existingKey)
|
||||||
|
requestedKey = strings.TrimSpace(requestedKey)
|
||||||
|
if providerType != "wechat" {
|
||||||
|
if requestedKey != "" {
|
||||||
|
return requestedKey
|
||||||
|
}
|
||||||
|
return existingKey
|
||||||
|
}
|
||||||
|
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
|
||||||
|
return "wechat-main"
|
||||||
|
}
|
||||||
|
if requestedKey != "" {
|
||||||
|
return requestedKey
|
||||||
|
}
|
||||||
|
return existingKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
|
||||||
|
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||||
|
providerKey = strings.TrimSpace(providerKey)
|
||||||
|
if providerType != "wechat" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.EqualFold(providerKey, "wechat-main"):
|
||||||
|
return 0
|
||||||
|
case strings.EqualFold(providerKey, "wechat"):
|
||||||
|
return 2
|
||||||
|
default:
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
|
||||||
|
var selected *dbent.AuthIdentity
|
||||||
|
for _, record := range records {
|
||||||
|
if record.UserID != userID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||||
|
selected = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
|
||||||
|
for _, record := range records {
|
||||||
|
if record.UserID != userID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
|
||||||
|
var selected *dbent.AuthIdentityChannel
|
||||||
|
for _, record := range records {
|
||||||
|
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||||
|
selected = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
|
||||||
|
for _, record := range records {
|
||||||
|
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
|
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
|
||||||
if input == nil {
|
if input == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -818,14 +818,14 @@ func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
|
|||||||
if base.UsePKCEExplicit {
|
if base.UsePKCEExplicit {
|
||||||
return base.UsePKCE
|
return base.UsePKCE
|
||||||
}
|
}
|
||||||
return false
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
|
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
|
||||||
if base.ValidateIDTokenExplicit {
|
if base.ValidateIDTokenExplicit {
|
||||||
return base.ValidateIDToken
|
return base.ValidateIDToken
|
||||||
}
|
}
|
||||||
return false
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSettings 更新系统设置
|
// UpdateSettings 更新系统设置
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
|
|||||||
require.True(t, got.OIDCConnectValidateIDToken)
|
require.True(t, got.OIDCConnectValidateIDToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
|
func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) {
|
||||||
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
|
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
|
||||||
OIDC: config.OIDCConnectConfig{
|
OIDC: config.OIDCConnectConfig{
|
||||||
UsePKCE: true,
|
UsePKCE: true,
|
||||||
@@ -145,8 +145,8 @@ func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettin
|
|||||||
SettingKeyOIDCConnectEnabled: "true",
|
SettingKeyOIDCConnectEnabled: "true",
|
||||||
})
|
})
|
||||||
|
|
||||||
require.False(t, got.OIDCConnectUsePKCE)
|
require.True(t, got.OIDCConnectUsePKCE)
|
||||||
require.False(t, got.OIDCConnectValidateIDToken)
|
require.True(t, got.OIDCConnectValidateIDToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
|
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
|
||||||
@@ -216,7 +216,7 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
|
|||||||
require.True(t, got.ValidateIDToken)
|
require.True(t, got.ValidateIDToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
|
func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
OIDC: config.OIDCConnectConfig{
|
OIDC: config.OIDCConnectConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
@@ -246,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsM
|
|||||||
|
|
||||||
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
|
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.False(t, got.UsePKCE)
|
require.True(t, got.UsePKCE)
|
||||||
require.False(t, got.ValidateIDToken)
|
require.True(t, got.ValidateIDToken)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user