diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go index de55dd69..fbb93236 100644 --- a/backend/ent/schema/auth_identity_schema_test.go +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -3,7 +3,9 @@ package schema import ( "testing" + "entgo.io/ent" "entgo.io/ent/entc/load" + "entgo.io/ent/schema/field" "github.com/stretchr/testify/require" ) @@ -74,6 +76,17 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) { userSchema := requireSchema(t, schemas, "User") requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at") + signupSource := requireSchemaField(t, userSchema, "signup_source") + require.Equal(t, field.TypeString, signupSource.Info.Type) + require.True(t, signupSource.Default) + require.Equal(t, "email", signupSource.DefaultValue) + require.Equal(t, 1, signupSource.Validators) + + validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source") + for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} { + require.NoError(t, validator(value)) + } + require.Error(t, validator("github")) } func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { @@ -98,6 +111,37 @@ func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) { } } +func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field { + t.Helper() + + for _, schemaField := range schema.Fields { + if schemaField.Name == name { + return schemaField + } + } + + require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name) + return nil +} + +func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error { + t.Helper() + + for _, entField := range fields { + descriptor := entField.Descriptor() + if descriptor.Name != name { + continue + } + require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name) + validator, ok := descriptor.Validators[0].(func(string) error) + require.True(t, ok, "field %s validator should be func(string) error", name) + return validator + } + + require.Failf(t, "missing field validator", "schema should include field %s", name) + return nil +} + func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) { t.Helper() diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index f307bda8..c0f0bdc1 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -1,6 +1,8 @@ package schema import ( + "fmt" + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -73,7 +75,14 @@ func (User) Fields() []ent.Field { Optional(). Nillable(), field.String("signup_source"). - MaxLen(20). + Validate(func(value string) error { + switch value { + case "email", "linuxdo", "wechat", "oidc": + return nil + default: + return fmt.Errorf("must be one of email, linuxdo, wechat, oidc") + } + }). Default("email"), field.Time("last_login_at"). Optional(). diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d47eadd4..87263db0 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -211,25 +211,27 @@ type WeChatConnectConfig struct { } type OIDCConnectConfig struct { - Enabled bool `mapstructure:"enabled"` - ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 - ClientID string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` - IssuerURL string `mapstructure:"issuer_url"` - DiscoveryURL string `mapstructure:"discovery_url"` - AuthorizeURL string `mapstructure:"authorize_url"` - TokenURL string `mapstructure:"token_url"` - UserInfoURL string `mapstructure:"userinfo_url"` - JWKSURL string `mapstructure:"jwks_url"` - Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" - RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) - FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) - TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none - UsePKCE bool `mapstructure:"use_pkce"` - ValidateIDToken bool `mapstructure:"validate_id_token"` - AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256" - ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 - RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false + Enabled bool `mapstructure:"enabled"` + ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + IssuerURL string `mapstructure:"issuer_url"` + DiscoveryURL string `mapstructure:"discovery_url"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + JWKSURL string `mapstructure:"jwks_url"` + Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" + RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + ValidateIDToken bool `mapstructure:"validate_id_token"` + UsePKCEExplicit bool `mapstructure:"-" yaml:"-"` + ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"` + AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256" + ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 + RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 // 为空时,服务端会尝试一组常见字段名。 @@ -329,6 +331,14 @@ func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool { return !hasNewEnv } +func hasExplicitConfigOrEnv(configKey, envKey string) bool { + if viper.InConfig(configKey) { + return true + } + _, ok := os.LookupEnv(envKey) + return ok +} + func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) { if cfg == nil { return @@ -1262,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath) cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath) cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath) + cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE") + cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN") cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 8b59ef5f..6ba86aa1 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -254,6 +254,21 @@ func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) { require.NoError(t, err) require.True(t, cfg.OIDC.UsePKCE) require.True(t, cfg.OIDC.ValidateIDToken) + require.False(t, cfg.OIDC.UsePKCEExplicit) + require.False(t, cfg.OIDC.ValidateIDTokenExplicit) +} + +func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("OIDC_CONNECT_USE_PKCE", "false") + t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false") + + cfg, err := Load() + require.NoError(t, err) + require.False(t, cfg.OIDC.UsePKCE) + require.False(t, cfg.OIDC.ValidateIDToken) + require.True(t, cfg.OIDC.UsePKCEExplicit) + require.True(t, cfg.OIDC.ValidateIDTokenExplicit) } func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index 8045d0c9..9a33a93a 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -335,6 +335,75 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla require.Equal(t, false, data["oidc_connect_validate_id_token"]) } +func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyOIDCConnectEnabled: "true", + service.SettingKeyOIDCConnectProviderName: "OIDC", + service.SettingKeyOIDCConnectClientID: "oidc-client", + service.SettingKeyOIDCConnectClientSecret: "oidc-secret", + service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com", + service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth", + service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token", + service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo", + service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks", + service.SettingKeyOIDCConnectScopes: "openid email profile", + service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256", + service.SettingKeyOIDCConnectClockSkewSeconds: "120", + service.SettingKeyOIDCConnectRequireEmailVerified: "false", + service.SettingKeyOIDCConnectUserInfoEmailPath: "", + service.SettingKeyOIDCConnectUserInfoIDPath: "", + service.SettingKeyOIDCConnectUserInfoUsernamePath: "", + }, + } + svc := service.NewSettingService(repo, &config.Config{ + Default: config.DefaultConfig{UserConcurrency: 5}, + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + JWKSURL: "https://issuer.example.com/jwks", + Scopes: "openid email profile", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + }) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "oidc_connect_enabled": true, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE]) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken]) +} + func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { gin.SetMode(gin.TestMode) repo := &settingHandlerRepoStub{ diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index a7e77c09..2ef05963 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -355,15 +355,20 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri } userEntity, err := client.User.Query(). - Where(dbuser.EmailEqualFold(email)). - Only(ctx) + Where(userNormalizedEmailPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) if err != nil { - if dbent.IsNotFound(err) { - return nil, nil - } return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) } - return userEntity, nil + switch len(userEntity) { + case 0: + return nil, nil + case 1: + return userEntity[0], nil + default: + return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users") + } } func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( @@ -411,9 +416,15 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( completionResponse["choice_reason"] = "force_email_on_signup" } + var targetUserID *int64 + if compatEmailUser != nil && compatEmailUser.ID > 0 { + targetUserID = &compatEmailUser.ID + } + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ Intent: oauthIntentLogin, Identity: identity, + TargetUserID: targetUserID, ResolvedEmail: resolvedChoiceEmail, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -490,9 +501,13 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) - if err != nil { - response.ErrorFrom(c, err) + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil { + respondPendingOAuthBindingApplyError(c, err) return } decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ @@ -503,17 +518,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) - return - } - h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) - if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { - clearOAuthPendingSessionCookie(c, secureCookie) - clearOAuthPendingBrowserCookie(c, secureCookie) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { response.ErrorFrom(c, err) return } + if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index d535c178..8b01ab41 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -508,7 +508,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test ctx := context.Background() existingUser, err := client.User.Create(). - SetEmail("legacy@example.com"). + SetEmail(" Legacy@Example.com "). SetUsername("legacy-user"). SetPasswordHash("hash"). SetRole(service.RoleUser). @@ -539,16 +539,17 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test Only(ctx) require.NoError(t, err) require.Equal(t, oauthIntentLogin, session.Intent) - require.Nil(t, session.TargetUserID) - require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail) require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) require.True(t, ok) require.Equal(t, "/dashboard", completion["redirect"]) require.Equal(t, oauthPendingChoiceStep, completion["step"]) - require.Equal(t, existingUser.Email, completion["email"]) - require.Equal(t, existingUser.Email, completion["existing_account_email"]) + require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"]) + require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"]) require.Equal(t, true, completion["existing_account_bindable"]) require.Equal(t, "compat_email_match", completion["choice_reason"]) _, hasAccessToken := completion["access_token"] @@ -943,6 +944,68 @@ func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *te require.False(t, decision.AdoptAvatar) } +func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(existingOwner.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-conflict-subject"). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-conflict-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-conflict-subject"). + SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-conflict-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusConflict, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"]) + + userCount, err := client.User.Query(). + Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")). + Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { t.Helper() handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 7be01e74..ab854d24 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -519,7 +519,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { email := strings.TrimSpace(strings.ToLower(req.Email)) if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil { - session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) if err != nil { response.ErrorFrom(c, err) return @@ -704,6 +704,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email return matches[0], nil } +func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error { + if client == nil || session == nil { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil + } + return err + } + if identity == nil || identity.UserID <= 0 { + return nil + } + + activeOwner, err := findActiveUserByID(ctx, client, identity.UserID) + if err != nil { + return err + } + if activeOwner != nil { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + return nil +} + func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { if session == nil { return nil @@ -1206,6 +1238,38 @@ func consumePendingOAuthBrowserSessionTx( return nil } +func applyPendingOAuthAdoptionAndConsumeSession( + ctx context.Context, + client *dbent.Client, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + userID int64, +) error { + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + if session == nil || userID <= 0 { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil { + return err + } + if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil { + return err + } + return tx.Commit() +} + func applyPendingOAuthAdoption( ctx context.Context, client *dbent.Client, @@ -1448,16 +1512,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState( c *gin.Context, client *dbent.Client, session *dbent.PendingAuthSession, + targetUser *dbent.User, email string, ) (*dbent.PendingAuthSession, error) { completionResponse := pendingOAuthChoiceCompletionResponse(session, email) + var targetUserID *int64 + if targetUser != nil && targetUser.ID > 0 { + targetUserID = &targetUser.ID + } session, err := updatePendingOAuthSessionProgress( c.Request.Context(), client, session, strings.TrimSpace(session.Intent), email, - nil, + targetUserID, completionResponse, ) if err != nil { @@ -1601,7 +1670,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) } } if existingUser != nil { - session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) if err != nil { response.ErrorFrom(c, err) return @@ -1624,7 +1693,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) ) if err != nil { if errors.Is(err, service.ErrEmailExists) { - session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email) + if lookupErr != nil { + response.ErrorFrom(c, lookupErr) + return + } + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index bc8fe7eb..9f9e497b 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -1045,7 +1045,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t * handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - _, err := client.User.Create(). + existingUser, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1099,7 +1099,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t * storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) require.Equal(t, oauthIntentLogin, storedSession.Intent) - require.Nil(t, storedSession.TargetUserID) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) require.Nil(t, storedSession.ConsumedAt) @@ -1118,7 +1119,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - _, err := client.User.Create(). + existingUser, err := client.User.Create(). SetEmail(" Owner@Example.com "). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1164,7 +1165,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) - require.Nil(t, storedSession.TargetUserID) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) } @@ -1172,7 +1174,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - _, err := client.User.Create(). + existingUser, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1220,7 +1222,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) require.Equal(t, oauthIntentLogin, storedSession.Intent) - require.Nil(t, storedSession.TargetUserID) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) } diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 3c67e421..0ac8871b 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -563,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession( if compatEmailUser != nil { resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) } + var targetUserID *int64 + if compatEmailUser != nil && compatEmailUser.ID > 0 { + targetUserID = &compatEmailUser.ID + } return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ Intent: oauthIntentLogin, Identity: identity, + TargetUserID: targetUserID, ResolvedEmail: resolvedChoiceEmail, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -643,9 +648,13 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) - if err != nil { - response.ErrorFrom(c, err) + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil { + respondPendingOAuthBindingApplyError(c, err) return } decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ @@ -656,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) - return - } - h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) - if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { - clearOAuthPendingSessionCookie(c, secureCookie) - clearOAuthPendingBrowserCookie(c, secureCookie) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { response.ErrorFrom(c, err) return } + if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index c2855dc9..3216d51e 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -438,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing Only(ctx) require.NoError(t, err) require.Equal(t, oauthIntentLogin, session.Intent) - require.Nil(t, session.TargetUserID) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) require.Equal(t, existingUser.Email, session.ResolvedEmail) require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) @@ -862,6 +863,69 @@ func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testi require.False(t, decision.AdoptAvatar) } +func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(existingOwner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-conflict-subject"). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-conflict-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-conflict-subject"). + SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-conflict-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusConflict, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"]) + + userCount, err := client.User.Query(). + Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")). + Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + type oidcProviderFixture struct { Subject string PreferredUsername string diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go index 41b64de7..e64934c5 100644 --- a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go +++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go @@ -576,6 +576,258 @@ FROM auth_identity_migration_reports require.Equal(t, beforeCount, afterCount) } +func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + var linuxDoFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoFirstUserID)) + + var linuxDoSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoSecondUserID)) + + var wechatFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatFirstUserID)) + + var wechatSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatSecondUserID)) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}') +RETURNING id +`, linuxDoFirstUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}') +RETURNING id +`, linuxDoSecondUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}') +RETURNING id +`, wechatFirstUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}') +RETURNING id +`, wechatSecondUserID).Scan(new(int64))) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var linuxDoIdentityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-ambiguous-subject' +`).Scan(&linuxDoIdentityCount)) + require.Zero(t, linuxDoIdentityCount) + + var wechatIdentityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND provider_subject = 'union-ambiguous-subject' +`).Scan(&wechatIdentityCount)) + require.Zero(t, wechatIdentityCount) + + var wechatChannelCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_channels +WHERE provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND channel = 'oa' + AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b') +`).Scan(&wechatChannelCount)) + require.Zero(t, wechatChannelCount) +} + +func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migration115SQL, err := os.ReadFile(migration115Path) + require.NoError(t, err) + + migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migration116SQL, err := os.ReadFile(migration116Path) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + var linuxDoFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoFirstUserID)) + + var linuxDoSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoSecondUserID)) + + var wechatFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatFirstUserID)) + + var wechatSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatSecondUserID)) + + var linuxDoFirstLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}') +RETURNING id +`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID)) + + var linuxDoSecondLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}') +RETURNING id +`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID)) + + var wechatFirstLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}') +RETURNING id +`, wechatFirstUserID).Scan(&wechatFirstLegacyID)) + + var wechatSecondLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}') +RETURNING id +`, wechatSecondUserID).Scan(&wechatSecondLegacyID)) + + _, err = tx.ExecContext(ctx, string(migration115SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration116SQL)) + require.NoError(t, err) + + var identityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject') + OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject') +`).Scan(&identityCount)) + require.Zero(t, identityCount) + + var conflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key IN ($1, $2, $3, $4) +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount)) + require.Equal(t, 4, conflictReportCount) + + var winnerAttributedReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key IN ($1, $2, $3, $4) + AND details ->> 'existing_identity_id' IS NOT NULL +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount)) + require.Zero(t, winnerAttributedReportCount) +} + func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) { tx := testTx(t) ctx := context.Background() diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 662a3972..f5798486 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -51,6 +51,8 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsLockRetryInterval = 500 * time.Millisecond const nonTransactionalMigrationSuffix = "_notx.sql" +const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql" +const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique" type migrationChecksumCompatibilityRule struct { fileChecksum string @@ -65,9 +67,11 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"), "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"), "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), + "115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"), + "116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"), "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"), "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"), - "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"), + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"), "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"), } @@ -195,6 +199,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { } if nonTx { + if err := prepareNonTransactionalMigration(ctx, db, name); err != nil { + return fmt.Errorf("prepare migration %s: %w", name, err) + } + // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。 // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。 statements := splitSQLStatements(content) @@ -244,6 +252,88 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { return nil } +func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error { + switch name { + case paymentOrdersOutTradeNoUniqueMigration: + return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db) + default: + return nil + } +} + +func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error { + duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db) + if err != nil { + return fmt.Errorf("precheck duplicate out_trade_no: %w", err) + } + if len(duplicates) > 0 { + return fmt.Errorf( + "duplicate out_trade_no values block %s; remediate duplicates before retrying: %s", + paymentOrdersOutTradeNoUniqueMigration, + strings.Join(duplicates, ", "), + ) + } + + invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex) + if err != nil { + return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err) + } + if !invalid { + return nil + } + + if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil { + return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err) + } + return nil +} + +func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) { + rows, err := db.QueryContext(ctx, ` + SELECT out_trade_no, COUNT(*) AS duplicate_count + FROM payment_orders + WHERE out_trade_no <> '' + GROUP BY out_trade_no + HAVING COUNT(*) > 1 + ORDER BY duplicate_count DESC, out_trade_no + LIMIT 5 + `) + if err != nil { + return nil, err + } + defer rows.Close() + + duplicates := make([]string, 0, 5) + for rows.Next() { + var outTradeNo string + var duplicateCount int + if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil { + return nil, err + } + duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount)) + } + if err := rows.Err(); err != nil { + return nil, err + } + return duplicates, nil +} + +func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) { + var invalid bool + err := db.QueryRowContext(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM pg_class idx + JOIN pg_namespace ns ON ns.oid = idx.relnamespace + JOIN pg_index i ON i.indexrelid = idx.oid + WHERE ns.nspname = 'public' + AND idx.relname = $1 + AND NOT i.indisvalid + ) + `, indexName).Scan(&invalid) + return invalid, err +} + func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error { hasLegacy, err := tableExists(ctx, db, "schema_migrations") if err != nil { diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go index dc241a75..57647093 100644 --- a/backend/internal/repository/migrations_runner_checksum_test.go +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -70,6 +70,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { require.True(t, ok) }) + t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "115_auth_identity_legacy_external_backfill.sql", + "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f", + "022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", + ) + require.True(t, ok) + }) + + t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "116_auth_identity_legacy_external_safety_reports.sql", + "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877", + "07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", + ) + require.True(t, ok) + }) + t.Run("119历史checksum可兼容占位文件", func(t *testing.T) { ok := isMigrationChecksumCompatible( "119_enforce_payment_orders_out_trade_no_unique.sql", @@ -79,6 +97,21 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { require.True(t, ok) }) + t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) { + for _, dbChecksum := range []string{ + "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", + "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", + "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a", + } { + ok := isMigrationChecksumCompatible( + "120_enforce_payment_orders_out_trade_no_unique_notx.sql", + dbChecksum, + "34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", + ) + require.True(t, ok) + } + }) + t.Run("119未知checksum不兼容", func(t *testing.T) { ok := isMigrationChecksumCompatible( "119_enforce_payment_orders_out_trade_no_unique.sql", diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go index af1adc50..a8bc15bc 100644 --- a/backend/internal/repository/migrations_runner_extra_test.go +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -96,6 +96,8 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) { for _, name := range []string{ + "115_auth_identity_legacy_external_backfill.sql", + "116_auth_identity_legacy_external_safety_reports.sql", "118_wechat_dual_mode_and_auth_source_defaults.sql", "120_enforce_payment_orders_out_trade_no_unique_notx.sql", "123_fix_legacy_auth_source_grant_on_signup_defaults.sql", diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go index db1183cd..b7cb396c 100644 --- a/backend/internal/repository/migrations_runner_notx_test.go +++ b/backend/internal/repository/migrations_runner_notx_test.go @@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b); require.NoError(t, mock.ExpectationsWereMet()) } +func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders"). + WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{ + Data: []byte(` +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; + +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no; +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "duplicate out_trade_no") + require.Contains(t, err.Error(), "dup-out-trade-no") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders"). + WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"})) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("paymentorder_out_trade_no_unique"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{ + Data: []byte(` +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; + +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no; +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index ac4dea18..eeee5c23 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -93,6 +93,19 @@ func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) tx := testTx(t) requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false) + requireColumn(t, tx, "users", "signup_source", "character varying", 20, false) + requireColumnDefaultContains(t, tx, "users", "signup_source", "email") + requireConstraintDefinitionContains( + t, + tx, + "users", + "users_signup_source_check", + "signup_source", + "'email'", + "'linuxdo'", + "'wechat'", + "'oidc'", + ) requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE") requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE") @@ -195,6 +208,45 @@ LIMIT 1 require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable) } +func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) { + t.Helper() + + var def string + err := tx.QueryRowContext(context.Background(), ` +SELECT pg_get_constraintdef(c.oid) +FROM pg_constraint c +JOIN pg_class tbl ON tbl.oid = c.conrelid +JOIN pg_namespace ns ON ns.oid = tbl.relnamespace +WHERE ns.nspname = 'public' + AND tbl.relname = $1 + AND c.conname = $2 +`, table, constraint).Scan(&def) + require.NoError(t, err, "query constraint definition for %s.%s", table, constraint) + + for _, fragment := range fragments { + require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment) + } +} + +func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) { + t.Helper() + + var columnDefault sql.NullString + err := tx.QueryRowContext(context.Background(), ` +SELECT column_default +FROM information_schema.columns +WHERE table_schema = 'public' + AND table_name = $1 + AND column_name = $2 +`, table, column).Scan(&columnDefault) + require.NoError(t, err, "query column_default for %s.%s", table, column) + require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column) + + for _, fragment := range fragments { + require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment) + } +} + func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { t.Helper() diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go index 87094ad7..b2b03746 100644 --- a/backend/internal/repository/user_profile_identity_repo.go +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -4,11 +4,15 @@ import ( "context" "database/sql" "fmt" + "hash/fnv" "reflect" + "sort" "strings" + "sync" "time" "unsafe" + "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" @@ -120,6 +124,113 @@ type sqlQueryExecutor interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } +var repositoryScopedKeyLocks = newScopedKeyLockRegistry() + +type scopedKeyLockRegistry struct { + mu sync.Mutex + locks map[string]*scopedKeyLockEntry +} + +type scopedKeyLockEntry struct { + mu sync.Mutex + refs int +} + +func newScopedKeyLockRegistry() *scopedKeyLockRegistry { + return &scopedKeyLockRegistry{ + locks: make(map[string]*scopedKeyLockEntry), + } +} + +func (r *scopedKeyLockRegistry) lock(keys ...string) func() { + normalized := normalizeLockKeys(keys...) + if len(normalized) == 0 { + return func() {} + } + + entries := make([]*scopedKeyLockEntry, 0, len(normalized)) + r.mu.Lock() + for _, key := range normalized { + entry := r.locks[key] + if entry == nil { + entry = &scopedKeyLockEntry{} + r.locks[key] = entry + } + entry.refs++ + entries = append(entries, entry) + } + r.mu.Unlock() + + for _, entry := range entries { + entry.mu.Lock() + } + + return func() { + for i := len(entries) - 1; i >= 0; i-- { + entries[i].mu.Unlock() + } + + r.mu.Lock() + defer r.mu.Unlock() + for idx, key := range normalized { + entry := entries[idx] + entry.refs-- + if entry.refs == 0 { + delete(r.locks, key) + } + } + } +} + +func normalizeLockKeys(keys ...string) []string { + if len(keys) == 0 { + return nil + } + + deduped := make(map[string]struct{}, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + continue + } + deduped[trimmed] = struct{}{} + } + if len(deduped) == 0 { + return nil + } + + normalized := make([]string, 0, len(deduped)) + for key := range deduped { + normalized = append(normalized, key) + } + sort.Strings(normalized) + return normalized +} + +func advisoryLockHash(key string) int64 { + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(key)) + return int64(hasher.Sum64()) +} + +func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) { + release := repositoryScopedKeyLocks.lock(keys...) + normalized := normalizeLockKeys(keys...) + if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres { + return release, nil + } + + for _, key := range normalized { + rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key)) + if err != nil { + release() + return nil, err + } + _ = rows.Close() + } + return release, nil +} + func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { if dbent.TxFromContext(ctx) != nil { return fn(ctx) @@ -329,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA return err } } else { + targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey) update := client.AuthIdentity.UpdateOneID(identity.ID) + if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) { + update = update.SetProviderKey(targetProviderKey) + } if input.Metadata != nil { update = update.SetMetadata(copyMetadata(input.Metadata)) } @@ -378,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA return err } } else { + targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey) update := client.AuthIdentityChannel.UpdateOneID(channel.ID). SetIdentityID(identity.ID) + if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) { + update = update.SetProviderKey(targetProviderKey) + } if input.ChannelMetadata != nil { update = update.SetMetadata(copyMetadata(input.ChannelMetadata)) } @@ -418,13 +537,52 @@ func compatibleIdentityProviderKeys(providerType, providerKey string) []string { return keys } +func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + existingKey = strings.TrimSpace(existingKey) + requestedKey = strings.TrimSpace(requestedKey) + if providerType != "wechat" { + if requestedKey != "" { + return requestedKey + } + return existingKey + } + if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") { + return "wechat-main" + } + if requestedKey != "" { + return requestedKey + } + return existingKey +} + +func compatibleIdentityProviderKeyRank(providerType, providerKey string) int { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerType != "wechat" { + return 0 + } + switch { + case strings.EqualFold(providerKey, "wechat-main"): + return 0 + case strings.EqualFold(providerKey, "wechat"): + return 2 + default: + return 1 + } +} + func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity { + var selected *dbent.AuthIdentity for _, record := range records { - if record.UserID == userID { - return record + if record.UserID != userID { + continue + } + if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record } } - return nil + return selected } func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool { @@ -437,12 +595,16 @@ func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) } func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel { + var selected *dbent.AuthIdentityChannel for _, record := range records { - if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID { - return record + if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID { + continue + } + if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record } } - return nil + return selected } func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool { @@ -479,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, } func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { - client := clientFromContext(ctx, r.client) - if input.IdentityID != nil && *input.IdentityID > 0 { - if _, err := client.IdentityAdoptionDecision.Update(). - Where( - identityadoptiondecision.IdentityIDEQ(*input.IdentityID), - dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { - col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) - s.Where(entsql.Or( - entsql.IsNull(col), - entsql.NEQ(col, input.PendingAuthSessionID), - )) - }), - ). - ClearIdentityID(). - Save(ctx); err != nil { - return nil, err + var result *dbent.IdentityAdoptionDecision + err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + releaseLocks, err := lockRepositoryScopedKeys( + txCtx, + client, + txAwareSQLExecutor(txCtx, r.sql, r.client), + identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)..., + ) + if err != nil { + return err + } + defer releaseLocks() + + if input.IdentityID != nil && *input.IdentityID > 0 { + if _, err := client.IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(*input.IdentityID), + dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { + col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.NEQ(col, input.PendingAuthSessionID), + )) + }), + ). + ClearIdentityID(). + Save(txCtx); err != nil { + return err + } } - } - current, err := client.IdentityAdoptionDecision.Query(). - Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { - return nil, err - } - now := time.Now().UTC() - if current == nil { create := client.IdentityAdoptionDecision.Create(). SetPendingAuthSessionID(input.PendingAuthSessionID). SetAdoptDisplayName(input.AdoptDisplayName). SetAdoptAvatar(input.AdoptAvatar). - SetDecidedAt(now) - if input.IdentityID != nil { + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil && *input.IdentityID > 0 { create = create.SetIdentityID(*input.IdentityID) } - return create.Save(ctx) - } - update := client.IdentityAdoptionDecision.UpdateOneID(current.ID). - SetAdoptDisplayName(input.AdoptDisplayName). - SetAdoptAvatar(input.AdoptAvatar) - if input.IdentityID != nil { - update = update.SetIdentityID(*input.IdentityID) + decisionID, err := create. + OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID). + UpdateNewValues(). + ID(txCtx) + if err != nil { + return err + } + + result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID) + return err + }) + if err != nil { + return nil, err } - return update.Save(ctx) + return result, nil +} + +func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string { + keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)} + if identityID != nil && *identityID > 0 { + keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID)) + } + return keys } func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) { diff --git a/backend/internal/repository/user_profile_identity_repo_unit_test.go b/backend/internal/repository/user_profile_identity_repo_unit_test.go new file mode 100644 index 00000000..689f32f9 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo_unit_test.go @@ -0,0 +1,212 @@ +package repository + +import ( + "context" + "sync" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + user := &service.User{ + Email: "wechat-legacy@example.com", + Username: "wechat-legacy", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, user)) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetProviderSubject("union-legacy-123"). + SetMetadata(map[string]any{"source": "legacy-alias"}). + Save(ctx) + require.NoError(t, err) + + legacyChannel, err := client.AuthIdentityChannel.Create(). + SetIdentityID(legacyIdentity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("oa"). + SetChannelAppID("wx-app-legacy"). + SetChannelSubject("openid-legacy-123"). + SetMetadata(map[string]any{"scene": "legacy-alias"}). + Save(ctx) + require.NoError(t, err) + + bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-legacy-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + Channel: "oa", + ChannelAppID: "wx-app-legacy", + ChannelSubject: "openid-legacy-123", + }, + Metadata: map[string]any{"source": "canonical-bind"}, + ChannelMetadata: map[string]any{"scene": "canonical-bind"}, + }) + require.NoError(t, err) + require.NotNil(t, bound) + require.NotNil(t, bound.Identity) + require.NotNil(t, bound.Channel) + require.Equal(t, legacyIdentity.ID, bound.Identity.ID) + require.Equal(t, legacyChannel.ID, bound.Channel.ID) + require.Equal(t, "wechat-main", bound.Identity.ProviderKey) + require.Equal(t, "wechat-main", bound.Channel.ProviderKey) + + reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey) + require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"]) + + reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", reloadedChannel.ProviderKey) + require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderSubjectEQ("union-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + + channelCount, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ChannelEQ("oa"), + authidentitychannel.ChannelAppIDEQ("wx-app-legacy"), + authidentitychannel.ChannelSubjectEQ("openid-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, channelCount) +} + +func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + user := &service.User{ + Email: "repo-adoption@example.com", + Username: "repo-adoption", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, user)) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-repo-adoption"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-repo-adoption"). + SetIntent("bind_current_user"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-repo-adoption"). + SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)). + SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}). + SetLocalFlowState(map[string]any{"step": "pending"}). + Save(ctx) + require.NoError(t, err) + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type adoptionResult struct { + decision *dbent.IdentityAdoptionDecision + err error + } + + input := IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + } + + results := make(chan adoptionResult, 2) + go func() { + decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + <-firstCreateStarted + + go func() { + decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + second := <-results + + require.NoError(t, first.err) + require.NoError(t, second.err) + require.NotNil(t, first.decision) + require.NotNil(t, second.decision) + require.Equal(t, first.decision.ID, second.decision.ID) + + count, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, count) + + loaded, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, loaded.IdentityID) + require.Equal(t, identity.ID, *loaded.IdentityID) + require.True(t, loaded.AdoptDisplayName) + require.True(t, loaded.AdoptAvatar) +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 68e51eeb..3d526e7b 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error if userIn == nil { return nil } - if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil { - return err - } // 统一使用 ent 的事务:保证用户与允许分组的更新原子化, // 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。 @@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error } var txClient *dbent.Client + txCtx := ctx if err == nil { defer func() { _ = tx.Rollback() }() txClient = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) } else { // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 if existingTx := dbent.TxFromContext(ctx); existingTx != nil { @@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error } } + releaseEmailLock, err := lockRepositoryScopedKeys( + txCtx, + txClient, + txAwareSQLExecutor(txCtx, r.sql, r.client), + normalizedEmailUniquenessLockKey(userIn.Email), + ) + if err != nil { + return err + } + defer releaseEmailLock() + + if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil { + return err + } + created, err := txClient.User.Create(). SetEmail(userIn.Email). SetUsername(userIn.Username). @@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). SetNillableLastLoginAt(userIn.LastLoginAt). SetNillableLastActiveAt(userIn.LastActiveAt). - Save(ctx) + Save(txCtx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) } - if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil { + if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil { return err } - if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil { + if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil { return err } @@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error if userIn == nil { return nil } - if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil { - return err - } // 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。 tx, err := r.client.Tx(ctx) @@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error } var txClient *dbent.Client + txCtx := ctx if err == nil { defer func() { _ = tx.Rollback() }() txClient = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) } else { // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 if existingTx := dbent.TxFromContext(ctx); existingTx != nil { @@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error txClient = r.client } } - existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID) + + releaseEmailLock, err := lockRepositoryScopedKeys( + txCtx, + txClient, + txAwareSQLExecutor(txCtx, r.sql, r.client), + normalizedEmailUniquenessLockKey(userIn.Email), + ) + if err != nil { + return err + } + defer releaseEmailLock() + + if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil { + return err + } + + existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, nil) } @@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error if userIn.BalanceNotifyThreshold == nil { updateOp = updateOp.ClearBalanceNotifyThreshold() } - updated, err := updateOp.Save(ctx) + updated, err := updateOp.Save(txCtx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) } - if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil { + if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil { return err } - if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil { + if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil { return err } @@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, } func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error { - matches, err := r.client.User.Query(). + return ensureNormalizedEmailAvailableWithClient(ctx, clientFromContext(ctx, r.client), userID, email) +} + +func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error { + client = clientFromContext(ctx, client) + if client == nil { + return nil + } + + matches, err := client.User.Query(). Where(userEmailLookupPredicate(email)). All(ctx) if err != nil { @@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use } func userEmailLookupPredicate(email string) predicate.User { - normalized := strings.ToLower(strings.TrimSpace(email)) + normalized := normalizeEmailLookupValue(email) if normalized == "" { return dbuser.EmailEQ(email) } @@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User { }) } +func normalizeEmailLookupValue(email string) string { + return strings.ToLower(strings.TrimSpace(email)) +} + +func normalizedEmailUniquenessLockKey(email string) string { + normalized := normalizeEmailLookupValue(email) + if normalized == "" { + return "" + } + return "users:normalized-email:" + normalized +} + func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { client := clientFromContext(ctx, r.client) err := client.UserAllowedGroup.Create(). @@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { } func userSignupSourceOrDefault(signupSource string) string { - signupSource = strings.TrimSpace(signupSource) - if signupSource == "" { + switch strings.TrimSpace(strings.ToLower(signupSource)) { + case "", "email": + return "email" + case "linuxdo", "wechat", "oidc": + return strings.TrimSpace(strings.ToLower(signupSource)) + default: return "email" } - return signupSource } // marshalExtraEmails serializes notify email entries to JSON for storage. diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go index b2b02ef5..2ef9d761 100644 --- a/backend/internal/repository/user_repo_email_lookup_unit_test.go +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -3,7 +3,10 @@ package repository import ( "context" "database/sql" + "fmt" + "sync" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/enttest" @@ -18,9 +21,10 @@ import ( func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) { t.Helper() - db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared") + db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name())) require.NoError(t, err) t.Cleanup(func() { _ = db.Close() }) + db.SetMaxOpenConns(10) _, err = db.Exec("PRAGMA foreign_keys = ON") require.NoError(t, err) @@ -144,3 +148,80 @@ func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "normalized email lookup matched multiple users") } + +func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.User.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type createResult struct { + err error + } + + results := make(chan createResult, 2) + go func() { + results <- createResult{err: repo.Create(ctx, &service.User{ + Email: " Race@Example.com ", + Username: "race-user-1", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + })} + }() + + <-firstCreateStarted + + go func() { + results <- createResult{err: repo.Create(ctx, &service.User{ + Email: "race@example.com", + Username: "race-user-2", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + })} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + second := <-results + + errors := []error{first.err, second.err} + successes := 0 + conflicts := 0 + for _, err := range errors { + switch { + case err == nil: + successes++ + case err == service.ErrEmailExists: + conflicts++ + default: + t.Fatalf("unexpected create error: %v", err) + } + } + require.Equal(t, 1, successes) + require.Equal(t, 1, conflicts) + + count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, count) +} diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index ea558ae2..a18cf39c 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -14,10 +14,14 @@ import ( func normalizeOAuthSignupSource(signupSource string) string { signupSource = strings.TrimSpace(strings.ToLower(signupSource)) - if signupSource == "" { + switch signupSource { + case "", "email": + return "email" + case "linuxdo", "wechat", "oidc": + return signupSource + default: return "email" } - return signupSource } // SendPendingOAuthVerifyCode sends a local verification code for pending OAuth @@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( return nil, nil, fmt.Errorf("hash password: %w", err) } - signupSource = strings.TrimSpace(strings.ToLower(signupSource)) - if signupSource == "" { - signupSource = "email" - } + signupSource = normalizeOAuthSignupSource(signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) user := &User{ @@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, Status: StatusActive, + SignupSource: signupSource, } if err := s.userRepo.Create(ctx, user); err != nil { diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go index a77dda72..e3fb2f85 100644 --- a/backend/internal/service/auth_oauth_email_flow_test.go +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai require.Empty(t, redeemRepo.updateCalls) } +func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) { + userRepo := &userRepoStub{nextID: 42} + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fresh@example.com", + "secret-123", + "246810", + "", + " OIDC ", + ) + + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Len(t, userRepo.created, 1) + require.Equal(t, "oidc", userRepo.created[0].SignupSource) +} + +func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) { + userRepo := &userRepoStub{nextID: 43} + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fallback@example.com", + "secret-123", + "246810", + "", + "github", + ) + + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Len(t, userRepo.created, 1) + require.Equal(t, "email", userRepo.created[0].SignupSource) +} + func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) { userRepo := &userRepoStub{} redeemRepo := &redeemCodeRepoStub{ diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go index cc0522ab..6e69c121 100644 --- a/backend/internal/service/auth_pending_identity_service.go +++ b/backend/internal/service/auth_pending_identity_service.go @@ -5,10 +5,15 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" + "errors" "fmt" + "hash/fnv" + "sort" "strings" + "sync" "time" + "entgo.io/ent/dialect" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" @@ -75,6 +80,122 @@ type AuthPendingIdentityService struct { entClient *dbent.Client } +var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry() + +type authPendingIdentityScopedKeyLockRegistry struct { + mu sync.Mutex + locks map[string]*authPendingIdentityScopedKeyLockEntry +} + +type authPendingIdentityScopedKeyLockEntry struct { + mu sync.Mutex + refs int +} + +func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry { + return &authPendingIdentityScopedKeyLockRegistry{ + locks: make(map[string]*authPendingIdentityScopedKeyLockEntry), + } +} + +func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() { + normalized := normalizeAuthPendingIdentityLockKeys(keys...) + if len(normalized) == 0 { + return func() {} + } + + entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized)) + r.mu.Lock() + for _, key := range normalized { + entry := r.locks[key] + if entry == nil { + entry = &authPendingIdentityScopedKeyLockEntry{} + r.locks[key] = entry + } + entry.refs++ + entries = append(entries, entry) + } + r.mu.Unlock() + + for _, entry := range entries { + entry.mu.Lock() + } + + return func() { + for i := len(entries) - 1; i >= 0; i-- { + entries[i].mu.Unlock() + } + + r.mu.Lock() + defer r.mu.Unlock() + for idx, key := range normalized { + entry := entries[idx] + entry.refs-- + if entry.refs == 0 { + delete(r.locks, key) + } + } + } +} + +func normalizeAuthPendingIdentityLockKeys(keys ...string) []string { + if len(keys) == 0 { + return nil + } + + deduped := make(map[string]struct{}, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + continue + } + deduped[trimmed] = struct{}{} + } + if len(deduped) == 0 { + return nil + } + + normalized := make([]string, 0, len(deduped)) + for key := range deduped { + normalized = append(normalized, key) + } + sort.Strings(normalized) + return normalized +} + +func authPendingIdentityAdvisoryLockHash(key string) int64 { + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(key)) + return int64(hasher.Sum64()) +} + +func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) { + release := authPendingIdentityScopedKeyLocks.lock(keys...) + normalized := normalizeAuthPendingIdentityLockKeys(keys...) + if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres { + return release, nil + } + + for _, key := range normalized { + var rows entsql.Rows + if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil { + release() + return nil, err + } + _ = rows.Close() + } + + return release, nil +} + +func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string { + keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)} + if identityID != nil && *identityID > 0 { + keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID)) + } + return keys +} + func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { return &AuthPendingIdentityService{entClient: entClient} } @@ -324,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, return nil, fmt.Errorf("pending auth ent client is not configured") } + tx, err := s.entClient.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return nil, err + } + + client := s.entClient + txCtx := ctx + if err == nil { + defer func() { _ = tx.Rollback() }() + client = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) + } else if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + client = existingTx.Client() + } + + releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...) + if err != nil { + return nil, err + } + defer releaseLocks() + if input.IdentityID != nil && *input.IdentityID > 0 { - if _, err := s.entClient.IdentityAdoptionDecision.Update(). + if _, err := client.IdentityAdoptionDecision.Update(). Where( identityadoptiondecision.IdentityIDEQ(*input.IdentityID), dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { @@ -337,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, }), ). ClearIdentityID(). - Save(ctx); err != nil { + Save(txCtx); err != nil { return nil, err } } - existing, err := s.entClient.IdentityAdoptionDecision.Query(). - Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { - return nil, err - } - if existing == nil { - create := s.entClient.IdentityAdoptionDecision.Create(). - SetPendingAuthSessionID(input.PendingAuthSessionID). - SetAdoptDisplayName(input.AdoptDisplayName). - SetAdoptAvatar(input.AdoptAvatar). - SetDecidedAt(time.Now().UTC()) - if input.IdentityID != nil { - create = create.SetIdentityID(*input.IdentityID) - } - return create.Save(ctx) + create := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil && *input.IdentityID > 0 { + create = create.SetIdentityID(*input.IdentityID) } - update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID). - SetAdoptDisplayName(input.AdoptDisplayName). - SetAdoptAvatar(input.AdoptAvatar) - if input.IdentityID != nil { - update = update.SetIdentityID(*input.IdentityID) + decisionID, err := create. + OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID). + UpdateNewValues(). + ID(txCtx) + if err != nil { + return nil, err } - return update.Save(ctx) + + decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID) + if err != nil { + return nil, err + } + + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, err + } + } + + return decision, nil } func copyPendingMap(in map[string]any) map[string]any { diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go index deeeeb06..555bb0e7 100644 --- a/backend/internal/service/auth_pending_identity_service_test.go +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -5,6 +5,7 @@ package service import ( "context" "database/sql" + "sync" "testing" "time" @@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden require.Nil(t, reloadedFirst.IdentityID) } +func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption-concurrent@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-concurrent"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-concurrent", + }, + }) + require.NoError(t, err) + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type adoptionResult struct { + decision *dbent.IdentityAdoptionDecision + err error + } + + input := PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + } + + results := make(chan adoptionResult, 2) + go func() { + decision, err := svc.UpsertAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + <-firstCreateStarted + + go func() { + decision, err := svc.UpsertAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + second := <-results + + require.NoError(t, first.err) + require.NoError(t, second.err) + require.NotNil(t, first.decision) + require.NotNil(t, second.decision) + require.Equal(t, first.decision.ID, second.decision.ID) + + count, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, count) + + loaded, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, loaded.IdentityID) + require.Equal(t, identity.ID, *loaded.IdentityID) +} + func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) { t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL") diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index efe08644..59442d1f 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "crypto/sha256" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, Status: StatusActive, + SignupSource: signupSource, } if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, Status: StatusActive, + SignupSource: signupSource, } if s.entClient != nil && invitationRedeemCode != nil { @@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) { UserID: user.ID, Email: user.Email, Role: user.Role, - TokenVersion: user.TokenVersion, + TokenVersion: resolvedTokenVersion(user), RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expiresAt), IssuedAt: jwt.NewNumericDate(now), @@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( // Security: Check TokenVersion to prevent refreshing revoked tokens // This ensures tokens issued before a password change cannot be refreshed - if claims.TokenVersion != user.TokenVersion { + if claims.TokenVersion != resolvedTokenVersion(user) { return "", ErrTokenRevoked } @@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami data := &RefreshTokenData{ UserID: user.ID, - TokenVersion: user.TokenVersion, + TokenVersion: resolvedTokenVersion(user), FamilyID: familyID, CreatedAt: now, ExpiresAt: now.Add(ttl), @@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) } // 检查TokenVersion(密码更改后所有Token失效) - if data.TokenVersion != user.TokenVersion { + if data.TokenVersion != resolvedTokenVersion(user) { // TokenVersion不匹配,撤销整个Token家族 _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrTokenRevoked @@ -1492,3 +1495,14 @@ func hashToken(token string) string { hash := sha256.Sum256([]byte(token)) return hex.EncodeToString(hash[:]) } + +func resolvedTokenVersion(user *User) int64 { + if user == nil { + return 0 + } + + material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash + sum := sha256.Sum256([]byte(material)) + fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff) + return user.TokenVersion ^ fingerprint +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f08274c7..aac60b08 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -814,6 +814,20 @@ func parseCustomMenuItemURLs(raw string) []string { return urls } +func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool { + if base.UsePKCEExplicit { + return base.UsePKCE + } + return false +} + +func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool { + if base.ValidateIDTokenExplicit { + return base.ValidateIDToken + } + return false +} + // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { updates, err := s.buildSystemSettingsUpdates(ctx, settings) @@ -1479,6 +1493,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { return fmt.Errorf("check existing settings: %w", err) } + oidcUsePKCEDefault := true + oidcValidateIDTokenDefault := true + if s != nil && s.cfg != nil { + if s.cfg.OIDC.UsePKCEExplicit { + oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE + } + if s.cfg.OIDC.ValidateIDTokenExplicit { + oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken + } + } + // 初始化默认设置 defaults := map[string]string{ SettingKeyRegistrationEnabled: "true", @@ -1523,8 +1548,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOIDCConnectRedirectURL: "", SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", - SettingKeyOIDCConnectUsePKCE: "true", - SettingKeyOIDCConnectValidateIDToken: "true", + SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault), + SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault), SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", SettingKeyOIDCConnectClockSkewSeconds: "120", SettingKeyOIDCConnectRequireEmailVerified: "false", @@ -1767,12 +1792,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { result.OIDCConnectUsePKCE = raw == "true" } else { - result.OIDCConnectUsePKCE = oidcBase.UsePKCE + result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase) } if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { result.OIDCConnectValidateIDToken = raw == "true" } else { - result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken + result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase) } if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) @@ -2482,9 +2507,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. } if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { effective.UsePKCE = raw == "true" + } else { + effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective) } if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { effective.ValidateIDToken = raw == "true" + } else { + effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective) } if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { effective.AllowedSigningAlgs = strings.TrimSpace(v) diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go index eb312d2c..1ece6405 100644 --- a/backend/internal/service/setting_service_oidc_config_test.go +++ b/backend/internal/service/setting_service_oidc_config_test.go @@ -118,8 +118,10 @@ func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) { svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ OIDC: config.OIDCConnectConfig{ - UsePKCE: true, - ValidateIDToken: true, + UsePKCE: true, + UsePKCEExplicit: true, + ValidateIDToken: true, + ValidateIDTokenExplicit: true, }, }) @@ -131,6 +133,22 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue require.True(t, got.OIDCConnectValidateIDToken) } +func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { + svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ + OIDC: config.OIDCConnectConfig{ + UsePKCE: true, + ValidateIDToken: true, + }, + }) + + got := svc.parseSettings(map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }) + + require.False(t, got.OIDCConnectUsePKCE) + require.False(t, got.OIDCConnectValidateIDToken) +} + func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { cfg := &config.Config{ OIDC: config.OIDCConnectConfig{ @@ -163,6 +181,42 @@ func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTok } func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) { + cfg := &config.Config{ + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + JWKSURL: "https://issuer.example.com/jwks", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + UsePKCEExplicit: true, + ValidateIDToken: true, + ValidateIDTokenExplicit: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + } + + repo := &settingOIDCRepoStub{values: map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }} + svc := NewSettingService(repo, cfg) + + got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.UsePKCE) + require.True(t, got.ValidateIDToken) +} + +func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { cfg := &config.Config{ OIDC: config.OIDCConnectConfig{ Enabled: true, @@ -192,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) require.NoError(t, err) - require.True(t, got.UsePKCE) - require.True(t, got.ValidateIDToken) + require.False(t, got.UsePKCE) + require.False(t, got.ValidateIDToken) } diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql index fbaed62e..f59b2188 100644 --- a/backend/migrations/110_pending_auth_and_provider_default_grants.sql +++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql @@ -38,23 +38,22 @@ VALUES ('auth_source_default_email_balance', '0'), ('auth_source_default_email_concurrency', '5'), ('auth_source_default_email_subscriptions', '[]'), - ('auth_source_default_email_grant_on_signup', 'true'), + ('auth_source_default_email_grant_on_signup', 'false'), ('auth_source_default_email_grant_on_first_bind', 'false'), ('auth_source_default_linuxdo_balance', '0'), ('auth_source_default_linuxdo_concurrency', '5'), ('auth_source_default_linuxdo_subscriptions', '[]'), - ('auth_source_default_linuxdo_grant_on_signup', 'true'), + ('auth_source_default_linuxdo_grant_on_signup', 'false'), ('auth_source_default_linuxdo_grant_on_first_bind', 'false'), ('auth_source_default_oidc_balance', '0'), ('auth_source_default_oidc_concurrency', '5'), ('auth_source_default_oidc_subscriptions', '[]'), - ('auth_source_default_oidc_grant_on_signup', 'true'), + ('auth_source_default_oidc_grant_on_signup', 'false'), ('auth_source_default_oidc_grant_on_first_bind', 'false'), ('auth_source_default_wechat_balance', '0'), ('auth_source_default_wechat_concurrency', '5'), ('auth_source_default_wechat_subscriptions', '[]'), - ('auth_source_default_wechat_grant_on_signup', 'true'), + ('auth_source_default_wechat_grant_on_signup', 'false'), ('auth_source_default_wechat_grant_on_first_bind', 'false'), ('force_email_on_third_party_signup', 'false') ON CONFLICT (key) DO NOTHING; - diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql index 7a20f8eb..264da3c9 100644 --- a/backend/migrations/115_auth_identity_legacy_external_backfill.sql +++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql @@ -31,6 +31,41 @@ BEGIN END IF; EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.id, + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_username) AS provider_username, + BTRIM(uei.display_name) AS display_name, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + uei.created_at, + uei.updated_at + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_user_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_user_id +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_user_id + ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN legacy_subjects AS subjects + ON subjects.provider_subject = legacy.provider_user_id + AND subjects.distinct_user_count = 1 +) INSERT INTO auth_identities ( user_id, provider_type, @@ -52,11 +87,18 @@ SELECT 'display_name', legacy.display_name, 'migration', '115_auth_identity_legacy_external_backfill' ) -FROM ( +FROM canonical_legacy AS legacy +WHERE legacy.canonical_row_num = 1 +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; +$sql$; + + EXECUTE $sql$ +WITH legacy AS ( SELECT uei.id, uei.user_id, BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_union_id) AS provider_union_id, BTRIM(uei.provider_username) AS provider_username, BTRIM(uei.display_name) AS display_name, public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, @@ -65,13 +107,28 @@ FROM ( FROM user_external_identities AS uei JOIN users AS u ON u.id = uei.user_id WHERE u.deleted_at IS NULL - AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' - AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' -) AS legacy -ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; -$sql$; - - EXECUTE $sql$ + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_union_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_union_id +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_union_id + ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN legacy_subjects AS subjects + ON subjects.provider_subject = legacy.provider_union_id + AND subjects.distinct_user_count = 1 +) INSERT INTO auth_identities ( user_id, provider_type, @@ -96,27 +153,36 @@ SELECT 'display_name', legacy.display_name, 'migration', '115_auth_identity_legacy_external_backfill' ) -FROM ( - SELECT - uei.id, - uei.user_id, - BTRIM(uei.provider_user_id) AS provider_user_id, - BTRIM(uei.provider_union_id) AS provider_union_id, - BTRIM(uei.provider_username) AS provider_username, - BTRIM(uei.display_name) AS display_name, - public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, - uei.created_at, - uei.updated_at - FROM user_external_identities AS uei - JOIN users AS u ON u.id = uei.user_id - WHERE u.deleted_at IS NULL - AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' - AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' -) AS legacy +FROM canonical_legacy AS legacy +WHERE legacy.canonical_row_num = 1 ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; $sql$; EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_union_id) AS provider_union_id, + BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel, + BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id, + meta.metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + CROSS JOIN LATERAL ( + SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + ) AS meta + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_union_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_union_id +) INSERT INTO auth_identity_channels ( identity_id, provider_type, @@ -138,23 +204,10 @@ SELECT 'unionid', legacy.provider_union_id, 'migration', '115_auth_identity_legacy_external_backfill' ) -FROM ( - SELECT - uei.user_id, - BTRIM(uei.provider_user_id) AS provider_user_id, - BTRIM(uei.provider_union_id) AS provider_union_id, - BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel, - BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id, - meta.metadata_json - FROM user_external_identities AS uei - JOIN users AS u ON u.id = uei.user_id - CROSS JOIN LATERAL ( - SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json - ) AS meta - WHERE u.deleted_at IS NULL - AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' - AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' -) AS legacy +FROM legacy +JOIN legacy_subjects AS subjects + ON subjects.provider_subject = legacy.provider_union_id + AND subjects.distinct_user_count = 1 JOIN auth_identities AS ai ON ai.user_id = legacy.user_id AND ai.provider_type = 'wechat' diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql index 3983bb1a..81eb133c 100644 --- a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql +++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql @@ -74,6 +74,82 @@ $sql$; EXECUTE $sql$ INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_external_identity_conflict', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'legacy_user_id', legacy.user_id, + 'provider_type', legacy.provider_type, + 'provider_key', legacy.provider_key, + 'provider_subject', legacy.provider_subject, + 'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids, + 'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) +) AS legacy +JOIN ( + SELECT + provider_type, + provider_key, + provider_subject, + to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids + FROM ( + SELECT + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) + ) AS legacy_subjects + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) > 1 +) AS ambiguous + ON ambiguous.provider_type = legacy.provider_type + AND ambiguous.provider_key = legacy.provider_key + AND ambiguous.provider_subject = legacy.provider_subject +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) SELECT 'legacy_external_identity_conflict', 'legacy_external_identity:' || legacy.id::text, @@ -116,6 +192,39 @@ FROM ( (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') ) ) AS legacy +JOIN ( + SELECT + provider_type, + provider_key, + provider_subject + FROM ( + SELECT + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) + ) AS legacy_subjects + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) = 1 +) AS clear_subjects + ON clear_subjects.provider_type = legacy.provider_type + AND clear_subjects.provider_key = legacy.provider_key + AND clear_subjects.provider_subject = legacy.provider_subject JOIN auth_identities AS ai ON ai.provider_type = legacy.provider_type AND ai.provider_key = legacy.provider_key @@ -125,29 +234,7 @@ ON CONFLICT (report_type, report_key) DO NOTHING; $sql$; EXECUTE $sql$ -INSERT INTO auth_identities ( - user_id, - provider_type, - provider_key, - provider_subject, - verified_at, - metadata -) -SELECT - legacy.user_id, - legacy.provider_type, - legacy.provider_key, - legacy.provider_subject, - legacy.verified_at, - legacy.metadata_json || jsonb_build_object( - 'legacy_identity_id', legacy.id, - 'provider_user_id', legacy.provider_user_id, - 'provider_union_id', NULLIF(legacy.provider_union_id, ''), - 'provider_username', legacy.provider_username, - 'display_name', legacy.display_name, - 'migration', '116_auth_identity_legacy_external_safety_reports' - ) -FROM ( +WITH legacy AS ( SELECT uei.id, uei.user_id, @@ -175,12 +262,58 @@ FROM ( OR (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') ) -) AS legacy +), +clear_subjects AS ( + SELECT + provider_type, + provider_key, + provider_subject + FROM legacy + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) = 1 +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject + ORDER BY legacy.verified_at DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN clear_subjects + ON clear_subjects.provider_type = legacy.provider_type + AND clear_subjects.provider_key = legacy.provider_key + AND clear_subjects.provider_subject = legacy.provider_subject +) +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + legacy.user_id, + legacy.provider_type, + legacy.provider_key, + legacy.provider_subject, + legacy.verified_at, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'provider_user_id', legacy.provider_user_id, + 'provider_union_id', NULLIF(legacy.provider_union_id, ''), + 'provider_username', legacy.provider_username, + 'display_name', legacy.display_name, + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM canonical_legacy AS legacy LEFT JOIN auth_identities AS ai ON ai.provider_type = legacy.provider_type AND ai.provider_key = legacy.provider_key AND ai.provider_subject = legacy.provider_subject -WHERE ai.id IS NULL +WHERE legacy.canonical_row_num = 1 + AND ai.id IS NULL ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; $sql$; @@ -225,6 +358,19 @@ FROM ( AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' ) AS legacy +JOIN ( + SELECT + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' + GROUP BY BTRIM(COALESCE(uei.provider_union_id, '')) + HAVING COUNT(DISTINCT uei.user_id) = 1 +) AS clear_subjects + ON clear_subjects.provider_subject = legacy.provider_union_id JOIN auth_identities AS legacy_ai ON legacy_ai.user_id = legacy.user_id AND legacy_ai.provider_type = 'wechat' @@ -245,6 +391,33 @@ ON CONFLICT (report_type, report_key) DO NOTHING; $sql$; EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.user_id, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel, + BTRIM(COALESCE( + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id', + '' + )) AS channel_app_id + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +), +clear_subjects AS ( + SELECT + provider_union_id AS provider_subject + FROM legacy + GROUP BY provider_union_id + HAVING COUNT(DISTINCT user_id) = 1 +) INSERT INTO auth_identity_channels ( identity_id, provider_type, @@ -266,26 +439,9 @@ SELECT 'unionid', legacy.provider_union_id, 'migration', '116_auth_identity_legacy_external_safety_reports' ) -FROM ( - SELECT - uei.user_id, - BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, - BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, - public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, - BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel, - BTRIM(COALESCE( - public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id', - public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid', - public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id', - '' - )) AS channel_app_id - FROM user_external_identities AS uei - JOIN users AS u ON u.id = uei.user_id - WHERE u.deleted_at IS NULL - AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' - AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' - AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' -) AS legacy +FROM legacy +JOIN clear_subjects + ON clear_subjects.provider_subject = legacy.provider_union_id JOIN auth_identities AS legacy_ai ON legacy_ai.user_id = legacy.user_id AND legacy_ai.provider_type = 'wechat' diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql index 094b223c..4388285a 100644 --- a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql +++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql @@ -1,3 +1,68 @@ --- Intentionally left as a no-op. --- Legacy installs may have intentionally kept the original signup grant defaults, --- and we cannot distinguish those cases safely from untouched migration 110 rows. +-- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value. +-- Rows still matching the migration-110 default payload and timestamp window are treated as +-- untouched legacy defaults; any remaining legacy true values are reported for manual review. + +WITH migration_110 AS ( + SELECT applied_at + FROM schema_migrations + WHERE filename = '110_pending_auth_and_provider_default_grants.sql' +), +providers AS ( + SELECT provider_type + FROM ( + VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat') + ) AS providers(provider_type) +), +legacy_provider_defaults AS ( + SELECT providers.provider_type + FROM providers + CROSS JOIN migration_110 + JOIN settings balance + ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance' + JOIN settings concurrency + ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency' + JOIN settings subscriptions + ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions' + JOIN settings grant_on_signup + ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' + JOIN settings grant_on_first_bind + ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind' + WHERE balance.value = '0' + AND concurrency.value = '5' + AND subscriptions.value = '[]' + AND grant_on_signup.value = 'true' + AND grant_on_first_bind.value = 'false' + AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' +), +updated_signup_grants AS ( + UPDATE settings + SET + value = 'false', + updated_at = NOW() + FROM legacy_provider_defaults + WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup' + AND settings.value = 'true' + RETURNING legacy_provider_defaults.provider_type +) +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_auth_source_signup_grant_review', + providers.provider_type, + jsonb_build_object( + 'provider_type', providers.provider_type, + 'current_value', grant_on_signup.value, + 'auto_backfilled', FALSE, + 'reason', 'legacy_true_default_not_auto_backfilled' + ) +FROM providers +JOIN settings grant_on_signup + ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' +LEFT JOIN updated_signup_grants + ON updated_signup_grants.provider_type = providers.provider_type +WHERE grant_on_signup.value = 'true' + AND updated_signup_grants.provider_type IS NULL +ON CONFLICT (report_type, report_key) DO NOTHING;