diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 110c9008..4ae66613 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 if providerKey == "" || providerSubject == "" { return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") } + canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey) + compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey) var issuer *string if input.Issuer != nil { @@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 } defer func() { _ = tx.Rollback() }() - identity, err := tx.AuthIdentity.Query(). + identityRecords, err := tx.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderKeyIn(compatibleProviderKeys...), authidentity.ProviderSubjectEQ(providerSubject), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - if identity != nil && identity.UserID != userID { + if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) { return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") } + identity := selectOwnedAdminAuthIdentity(identityRecords, userID) if identity == nil { create := tx.AuthIdentity.Create(). SetUserID(userID). SetProviderType(providerType). - SetProviderKey(providerKey). + SetProviderKey(canonicalProviderKey). SetProviderSubject(providerSubject). SetVerifiedAt(verifiedAt) if issuer != nil { @@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) } } else { - update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt) + update := tx.AuthIdentity.UpdateOneID(identity.ID). + SetVerifiedAt(verifiedAt). + SetProviderKey(canonicalProviderKey) if issuer != nil { update = update.SetIssuer(*issuer) } @@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 var channel *dbent.AuthIdentityChannel if channelInput != nil { - channel, err = tx.AuthIdentityChannel.Query(). + channelRecords, err := tx.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ(providerType), - authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ProviderKeyIn(compatibleProviderKeys...), authidentitychannel.ChannelEQ(channelInput.Channel), authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), ). WithIdentity(). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) } - if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) { return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") } + channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID) if channel == nil { create := tx.AuthIdentityChannel.Create(). SetIdentityID(identity.ID). SetProviderType(providerType). - SetProviderKey(providerKey). + SetProviderKey(canonicalProviderKey). SetChannel(channelInput.Channel). SetChannelAppID(channelInput.ChannelAppID). SetChannelSubject(channelInput.ChannelSubject) @@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) } } else { - update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID) + update := tx.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID). + SetProviderKey(canonicalProviderKey) if channelInput.Metadata != nil { update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) } @@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 return buildAdminBoundAuthIdentity(identity, channel), nil } +func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" { + return []string{providerKey} + } + if providerType != "wechat" { + return []string{providerKey} + } + + keys := []string{providerKey} + if !strings.EqualFold(providerKey, "wechat-main") { + keys = append(keys, "wechat-main") + } + if !strings.EqualFold(providerKey, "wechat") { + keys = append(keys, "wechat") + } + return keys +} + +func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + existingKey = strings.TrimSpace(existingKey) + requestedKey = strings.TrimSpace(requestedKey) + if providerType != "wechat" { + if requestedKey != "" { + return requestedKey + } + return existingKey + } + if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") { + return "wechat-main" + } + if requestedKey != "" { + return requestedKey + } + return existingKey +} + +func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerType != "wechat" { + return 0 + } + switch { + case strings.EqualFold(providerKey, "wechat-main"): + return 0 + case strings.EqualFold(providerKey, "wechat"): + return 2 + default: + return 1 + } +} + +func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity { + var selected *dbent.AuthIdentity + for _, record := range records { + if record.UserID != userID { + continue + } + if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record + } + } + return selected +} + +func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool { + for _, record := range records { + if record.UserID != userID { + return true + } + } + return false +} + +func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel { + var selected *dbent.AuthIdentityChannel + for _, record := range records { + if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID { + continue + } + if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record + } + } + return selected +} + +func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool { + for _, record := range records { + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + return true + } + } + return false +} + func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { if input == nil { return nil diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index aac60b08..93b7def1 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -818,14 +818,14 @@ func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool { if base.UsePKCEExplicit { return base.UsePKCE } - return false + return true } func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool { if base.ValidateIDTokenExplicit { return base.ValidateIDToken } - return false + return true } // UpdateSettings 更新系统设置 diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go index 1ece6405..61324204 100644 --- a/backend/internal/service/setting_service_oidc_config_test.go +++ b/backend/internal/service/setting_service_oidc_config_test.go @@ -133,7 +133,7 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue require.True(t, got.OIDCConnectValidateIDToken) } -func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { +func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) { svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ OIDC: config.OIDCConnectConfig{ UsePKCE: true, @@ -145,8 +145,8 @@ func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettin SettingKeyOIDCConnectEnabled: "true", }) - require.False(t, got.OIDCConnectUsePKCE) - require.False(t, got.OIDCConnectValidateIDToken) + require.True(t, got.OIDCConnectUsePKCE) + require.True(t, got.OIDCConnectValidateIDToken) } func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { @@ -216,7 +216,7 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t require.True(t, got.ValidateIDToken) } -func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { +func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) { cfg := &config.Config{ OIDC: config.OIDCConnectConfig{ Enabled: true, @@ -246,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsM got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) require.NoError(t, err) - require.False(t, got.UsePKCE) - require.False(t, got.ValidateIDToken) + require.True(t, got.UsePKCE) + require.True(t, got.ValidateIDToken) }