fix(unit): restore secure oidc defaults and wechat alias reuse

This commit is contained in:
IanShaw027
2026-04-22 16:51:23 +08:00
parent 66680a3056
commit a94d89efa7
3 changed files with 129 additions and 22 deletions

View File

@@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
if providerKey == "" || providerSubject == "" { if providerKey == "" || providerSubject == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
} }
canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
var issuer *string var issuer *string
if input.Issuer != nil { if input.Issuer != nil {
@@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
identity, err := tx.AuthIdentity.Query(). identityRecords, err := tx.AuthIdentity.Query().
Where( Where(
authidentity.ProviderTypeEQ(providerType), authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey), authidentity.ProviderKeyIn(compatibleProviderKeys...),
authidentity.ProviderSubjectEQ(providerSubject), authidentity.ProviderSubjectEQ(providerSubject),
). ).
Only(ctx) All(ctx)
if err != nil && !dbent.IsNotFound(err) { if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
} }
if identity != nil && identity.UserID != userID { if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
} }
identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
if identity == nil { if identity == nil {
create := tx.AuthIdentity.Create(). create := tx.AuthIdentity.Create().
SetUserID(userID). SetUserID(userID).
SetProviderType(providerType). SetProviderType(providerType).
SetProviderKey(providerKey). SetProviderKey(canonicalProviderKey).
SetProviderSubject(providerSubject). SetProviderSubject(providerSubject).
SetVerifiedAt(verifiedAt) SetVerifiedAt(verifiedAt)
if issuer != nil { if issuer != nil {
@@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
} }
} else { } else {
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt) update := tx.AuthIdentity.UpdateOneID(identity.ID).
SetVerifiedAt(verifiedAt).
SetProviderKey(canonicalProviderKey)
if issuer != nil { if issuer != nil {
update = update.SetIssuer(*issuer) update = update.SetIssuer(*issuer)
} }
@@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
var channel *dbent.AuthIdentityChannel var channel *dbent.AuthIdentityChannel
if channelInput != nil { if channelInput != nil {
channel, err = tx.AuthIdentityChannel.Query(). channelRecords, err := tx.AuthIdentityChannel.Query().
Where( Where(
authidentitychannel.ProviderTypeEQ(providerType), authidentitychannel.ProviderTypeEQ(providerType),
authidentitychannel.ProviderKeyEQ(providerKey), authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
authidentitychannel.ChannelEQ(channelInput.Channel), authidentitychannel.ChannelEQ(channelInput.Channel),
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
). ).
WithIdentity(). WithIdentity().
Only(ctx) All(ctx)
if err != nil && !dbent.IsNotFound(err) { if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
} }
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
} }
channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
if channel == nil { if channel == nil {
create := tx.AuthIdentityChannel.Create(). create := tx.AuthIdentityChannel.Create().
SetIdentityID(identity.ID). SetIdentityID(identity.ID).
SetProviderType(providerType). SetProviderType(providerType).
SetProviderKey(providerKey). SetProviderKey(canonicalProviderKey).
SetChannel(channelInput.Channel). SetChannel(channelInput.Channel).
SetChannelAppID(channelInput.ChannelAppID). SetChannelAppID(channelInput.ChannelAppID).
SetChannelSubject(channelInput.ChannelSubject) SetChannelSubject(channelInput.ChannelSubject)
@@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
} }
} else { } else {
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID) update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID).
SetProviderKey(canonicalProviderKey)
if channelInput.Metadata != nil { if channelInput.Metadata != nil {
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
} }
@@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return buildAdminBoundAuthIdentity(identity, channel), nil return buildAdminBoundAuthIdentity(identity, channel), nil
} }
func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return []string{providerKey}
}
if providerType != "wechat" {
return []string{providerKey}
}
keys := []string{providerKey}
if !strings.EqualFold(providerKey, "wechat-main") {
keys = append(keys, "wechat-main")
}
if !strings.EqualFold(providerKey, "wechat") {
keys = append(keys, "wechat")
}
return keys
}
func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
existingKey = strings.TrimSpace(existingKey)
requestedKey = strings.TrimSpace(requestedKey)
if providerType != "wechat" {
if requestedKey != "" {
return requestedKey
}
return existingKey
}
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
return "wechat-main"
}
if requestedKey != "" {
return requestedKey
}
return existingKey
}
func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerType != "wechat" {
return 0
}
switch {
case strings.EqualFold(providerKey, "wechat-main"):
return 0
case strings.EqualFold(providerKey, "wechat"):
return 2
default:
return 1
}
}
func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
var selected *dbent.AuthIdentity
for _, record := range records {
if record.UserID != userID {
continue
}
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
for _, record := range records {
if record.UserID != userID {
return true
}
}
return false
}
func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
var selected *dbent.AuthIdentityChannel
for _, record := range records {
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
continue
}
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
return true
}
}
return false
}
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
if input == nil { if input == nil {
return nil return nil

View File

@@ -818,14 +818,14 @@ func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.UsePKCEExplicit { if base.UsePKCEExplicit {
return base.UsePKCE return base.UsePKCE
} }
return false return true
} }
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool { func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.ValidateIDTokenExplicit { if base.ValidateIDTokenExplicit {
return base.ValidateIDToken return base.ValidateIDToken
} }
return false return true
} }
// UpdateSettings 更新系统设置 // UpdateSettings 更新系统设置

View File

@@ -133,7 +133,7 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
require.True(t, got.OIDCConnectValidateIDToken) require.True(t, got.OIDCConnectValidateIDToken)
} }
func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) {
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
OIDC: config.OIDCConnectConfig{ OIDC: config.OIDCConnectConfig{
UsePKCE: true, UsePKCE: true,
@@ -145,8 +145,8 @@ func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettin
SettingKeyOIDCConnectEnabled: "true", SettingKeyOIDCConnectEnabled: "true",
}) })
require.False(t, got.OIDCConnectUsePKCE) require.True(t, got.OIDCConnectUsePKCE)
require.False(t, got.OIDCConnectValidateIDToken) require.True(t, got.OIDCConnectValidateIDToken)
} }
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
@@ -216,7 +216,7 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
require.True(t, got.ValidateIDToken) require.True(t, got.ValidateIDToken)
} }
func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) {
cfg := &config.Config{ cfg := &config.Config{
OIDC: config.OIDCConnectConfig{ OIDC: config.OIDCConnectConfig{
Enabled: true, Enabled: true,
@@ -246,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsM
got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.False(t, got.UsePKCE) require.True(t, got.UsePKCE)
require.False(t, got.ValidateIDToken) require.True(t, got.ValidateIDToken)
} }