fix(unit): restore secure oidc defaults and wechat alias reuse

This commit is contained in:
IanShaw027
2026-04-22 16:51:23 +08:00
parent 66680a3056
commit a94d89efa7
3 changed files with 129 additions and 22 deletions

View File

@@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
if providerKey == "" || providerSubject == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
}
canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
var issuer *string
if input.Issuer != nil {
@@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
}
defer func() { _ = tx.Rollback() }()
identity, err := tx.AuthIdentity.Query().
identityRecords, err := tx.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderKeyIn(compatibleProviderKeys...),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
All(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if identity != nil && identity.UserID != userID {
if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
if identity == nil {
create := tx.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderKey(canonicalProviderKey).
SetProviderSubject(providerSubject).
SetVerifiedAt(verifiedAt)
if issuer != nil {
@@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
}
} else {
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
update := tx.AuthIdentity.UpdateOneID(identity.ID).
SetVerifiedAt(verifiedAt).
SetProviderKey(canonicalProviderKey)
if issuer != nil {
update = update.SetIssuer(*issuer)
}
@@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
var channel *dbent.AuthIdentityChannel
if channelInput != nil {
channel, err = tx.AuthIdentityChannel.Query().
channelRecords, err := tx.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(providerType),
authidentitychannel.ProviderKeyEQ(providerKey),
authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
authidentitychannel.ChannelEQ(channelInput.Channel),
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
).
WithIdentity().
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
All(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
}
channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
if channel == nil {
create := tx.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderKey(canonicalProviderKey).
SetChannel(channelInput.Channel).
SetChannelAppID(channelInput.ChannelAppID).
SetChannelSubject(channelInput.ChannelSubject)
@@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
}
} else {
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID).
SetProviderKey(canonicalProviderKey)
if channelInput.Metadata != nil {
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
}
@@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return buildAdminBoundAuthIdentity(identity, channel), nil
}
func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return []string{providerKey}
}
if providerType != "wechat" {
return []string{providerKey}
}
keys := []string{providerKey}
if !strings.EqualFold(providerKey, "wechat-main") {
keys = append(keys, "wechat-main")
}
if !strings.EqualFold(providerKey, "wechat") {
keys = append(keys, "wechat")
}
return keys
}
func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
existingKey = strings.TrimSpace(existingKey)
requestedKey = strings.TrimSpace(requestedKey)
if providerType != "wechat" {
if requestedKey != "" {
return requestedKey
}
return existingKey
}
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
return "wechat-main"
}
if requestedKey != "" {
return requestedKey
}
return existingKey
}
func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerType != "wechat" {
return 0
}
switch {
case strings.EqualFold(providerKey, "wechat-main"):
return 0
case strings.EqualFold(providerKey, "wechat"):
return 2
default:
return 1
}
}
func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
var selected *dbent.AuthIdentity
for _, record := range records {
if record.UserID != userID {
continue
}
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
for _, record := range records {
if record.UserID != userID {
return true
}
}
return false
}
func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
var selected *dbent.AuthIdentityChannel
for _, record := range records {
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
continue
}
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
return true
}
}
return false
}
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
if input == nil {
return nil

View File

@@ -818,14 +818,14 @@ func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.UsePKCEExplicit {
return base.UsePKCE
}
return false
return true
}
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.ValidateIDTokenExplicit {
return base.ValidateIDToken
}
return false
return true
}
// UpdateSettings 更新系统设置

View File

@@ -133,7 +133,7 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
require.True(t, got.OIDCConnectValidateIDToken)
}
func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) {
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
OIDC: config.OIDCConnectConfig{
UsePKCE: true,
@@ -145,8 +145,8 @@ func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettin
SettingKeyOIDCConnectEnabled: "true",
})
require.False(t, got.OIDCConnectUsePKCE)
require.False(t, got.OIDCConnectValidateIDToken)
require.True(t, got.OIDCConnectUsePKCE)
require.True(t, got.OIDCConnectValidateIDToken)
}
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
@@ -216,7 +216,7 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
require.True(t, got.ValidateIDToken)
}
func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) {
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
@@ -246,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsM
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.False(t, got.UsePKCE)
require.False(t, got.ValidateIDToken)
require.True(t, got.UsePKCE)
require.True(t, got.ValidateIDToken)
}