Fix user profile writes on postgres conflicts

This commit is contained in:
IanShaw027
2026-04-21 10:13:28 -07:00
parent 0f4a8d7be8
commit 525a320424
7 changed files with 104 additions and 9 deletions

View File

@@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
client := clientFromContext(ctx, r.client)
return client.AnnouncementRead.Create().
err := client.AnnouncementRead.Create().
SetAnnouncementID(announcementID).
SetUserID(userID).
SetReadAt(readAt).
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
DoNothing().
Exec(ctx)
if isSQLNoRowsError(err) {
return nil
}
return err
}
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {

View File

@@ -392,6 +392,31 @@ func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_Reassi
s.Require().Nil(reloadedFirst.IdentityID)
}
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() {
user := s.mustCreateUser("avatar-only-update")
model, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(model)
err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
_, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{
StorageProvider: "remote_url",
URL: "https://cdn.example.com/avatar.png",
})
if err != nil {
return err
}
return s.repo.Update(txCtx, model)
})
s.Require().NoError(err)
avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(avatar)
s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL)
}
func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
user := s.mustCreateUser("avatar")

View File

@@ -56,9 +56,13 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
// 已处于外部事务中ErrTxStarted复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
created, err := txClient.User.Create().
SetEmail(userIn.Email).
@@ -154,9 +158,13 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
// 已处于外部事务中ErrTxStarted复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
@@ -236,8 +244,10 @@ func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client
).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return err
}
}
identity, err := client.AuthIdentity.Query().
Where(
@@ -303,9 +313,13 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
identityIDs, err := txClient.AuthIdentity.Query().
Where(authidentity.UserIDEQ(id)).
@@ -707,12 +721,16 @@ func userEmailLookupPredicate(email string) predicate.User {
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
return client.UserAllowedGroup.Create().
err := client.UserAllowedGroup.Create().
SetUserID(userID).
SetGroupID(groupID).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx)
if isSQLNoRowsError(err) {
return nil
}
return err
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
@@ -812,6 +830,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
if isSQLNoRowsError(err) {
return nil
}
return err
}
}

View File

@@ -160,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() {
s.Require().Equal("updated", updated.Username)
}
func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() {
user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"})
identityCount, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("update-existing-identity@test.com"),
).
Count(s.ctx)
s.Require().NoError(err)
s.Require().Equal(1, identityCount)
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
got.Username = "updated"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows")
updated, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal("updated", updated.Username)
}
func (s *UserRepoSuite) TestDelete() {
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})

View File

@@ -277,8 +277,10 @@ func ensureBoundEmailAuthIdentityWithClient(
).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return err
}
}
identity, err := client.AuthIdentity.Query().
Where(

View File

@@ -916,6 +916,11 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, s
).
DoNothing().
Exec(ctx); err != nil {
if isSQLNoRowsError(err) {
err = nil
}
}
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return nil, false
}

View File

@@ -0,0 +1,14 @@
package service
import (
"database/sql"
"errors"
"strings"
)
func isSQLNoRowsError(err error) bool {
if err == nil {
return false
}
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
}