diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 6d6564e8..21ed2bc6 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -10,6 +10,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -309,6 +310,14 @@ func cloneOAuthMetadata(values map[string]any) map[string]any { return cloned } +func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any { + merged := cloneOAuthMetadata(base) + for key, value := range overlay { + merged[key] = value + } + return merged +} + func normalizeAdoptedOAuthDisplayName(value string) string { value = strings.TrimSpace(value) if len([]rune(value)) > 100 { @@ -558,6 +567,10 @@ func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { } func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") { + return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID) + } + client := tx.Client() identity, err := client.AuthIdentity.Query(). Where( @@ -588,14 +601,149 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio return create.Save(ctx) } +func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + client := tx.Client() + providerType := strings.TrimSpace(session.ProviderType) + providerKey := strings.TrimSpace(session.ProviderKey) + providerSubject := strings.TrimSpace(session.ProviderSubject) + channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel")) + channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id")) + channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject")) + metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if identity != nil && identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + + var legacyOpenIDIdentity *dbent.AuthIdentity + if channelSubject != "" && channelSubject != providerSubject { + legacyOpenIDIdentity, err = client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(channelSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + + switch { + case identity != nil: + update := client.AuthIdentity.UpdateOneID(identity.ID). + SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + case legacyOpenIDIdentity != nil: + update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID). + SetProviderSubject(providerSubject). + SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + default: + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, err + } + } + + if channel == "" || channelAppID == "" || channelSubject == "" { + return identity, nil + } + + channelRecord, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(channelSubject), + ). + WithIdentity(). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + + channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata) + if channelRecord == nil { + if _, err := client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channel). + SetChannelAppID(channelAppID). + SetChannelSubject(channelSubject). + SetMetadata(channelMetadata). + Save(ctx); err != nil { + return nil, err + } + return identity, nil + } + + _, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). + SetIdentityID(identity.ID). + SetMetadata(channelMetadata). + Save(ctx) + if err != nil { + return nil, err + } + return identity, nil +} + +func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any { + if channel == nil { + return map[string]any{} + } + return cloneOAuthMetadata(channel.Metadata) +} + func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool { if session == nil || decision == nil { return false } - if strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user") { + switch strings.ToLower(strings.TrimSpace(session.Intent)) { + case "bind_current_user", "login", "adopt_existing_user_by_email": return true + default: + return decision.AdoptDisplayName || decision.AdoptAvatar } - return decision.AdoptDisplayName || decision.AdoptAvatar } func applyPendingOAuthBinding( diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index ae506e52..89accd60 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -372,7 +372,7 @@ func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testi require.Nil(t, storedSession.ConsumedAt) } -func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *testing.T) { +func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -420,21 +420,22 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *tes require.Equal(t, http.StatusOK, recorder.Code) - identityCount, err := client.AuthIdentity.Query(). + identity, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("linuxdo"), authidentity.ProviderKeyEQ("linuxdo"), authidentity.ProviderSubjectEQ("login-false-123"), ). - Count(ctx) + Only(ctx) require.NoError(t, err) - require.Zero(t, identityCount) + require.Equal(t, userEntity.ID, identity.UserID) decision, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). Only(ctx) require.NoError(t, err) - require.Nil(t, decision.IdentityID) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) require.False(t, decision.AdoptDisplayName) require.False(t, decision.AdoptAvatar) diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index f0755f1f..b6d47670 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -242,7 +242,18 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) return } + if existingIdentityUser == nil { + existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + } if existingIdentityUser != nil { + if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") if err != nil { redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) @@ -511,6 +522,91 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return nil } +func (h *AuthHandler) findWeChatUserByLegacyOpenID( + ctx context.Context, + identity service.PendingAuthIdentityKey, + cfg wechatOAuthConfig, + openid string, +) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + openid = strings.TrimSpace(openid) + channel := strings.TrimSpace(cfg.mode) + channelAppID := strings.TrimSpace(cfg.appID) + if openid != "" && channel != "" && channelAppID != "" { + record, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(openid), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil { + return record.Edges.Identity.Edges.User, nil + } + } + + if openid == "" { + return nil, nil + } + + record, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderSubjectEQ(openid), + ). + WithUser(). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + return record.Edges.User, nil +} + +func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding( + ctx context.Context, + userID int64, + identity service.PendingAuthIdentityKey, + upstreamClaims map[string]any, +) error { + client := h.entClient() + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + tx, err := client.Tx(ctx) + if err != nil { + return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + _, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims), + }, userID) + if err != nil { + return err + } + return tx.Commit() +} + func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { mode, err := resolveWeChatOAuthMode(rawMode, c) if err != nil { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 1ff80e1b..dd022fb9 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -15,6 +15,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" @@ -563,6 +564,19 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "union-456", channel.Metadata["unionid"]) + decision, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). Only(ctx) @@ -579,6 +593,116 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing require.NotNil(t, consumed.ConsumedAt) } +func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("openid-123"). + SetMetadata(map[string]any{"openid": "openid-123"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + openIDIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("openid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, openIDIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper()