Fix user profile writes on postgres conflicts
This commit is contained in:
@@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea
|
||||
|
||||
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return client.AnnouncementRead.Create().
|
||||
err := client.AnnouncementRead.Create().
|
||||
SetAnnouncementID(announcementID).
|
||||
SetUserID(userID).
|
||||
SetReadAt(readAt).
|
||||
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
if isSQLNoRowsError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
||||
|
||||
@@ -392,6 +392,31 @@ func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_Reassi
|
||||
s.Require().Nil(reloadedFirst.IdentityID)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() {
|
||||
user := s.mustCreateUser("avatar-only-update")
|
||||
|
||||
model, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(model)
|
||||
|
||||
err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
|
||||
_, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{
|
||||
StorageProvider: "remote_url",
|
||||
URL: "https://cdn.example.com/avatar.png",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repo.Update(txCtx, model)
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(avatar)
|
||||
s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
|
||||
user := s.mustCreateUser("avatar")
|
||||
|
||||
|
||||
@@ -56,8 +56,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
txClient = existingTx.Client()
|
||||
} else {
|
||||
txClient = r.client
|
||||
}
|
||||
}
|
||||
|
||||
created, err := txClient.User.Create().
|
||||
@@ -154,8 +158,12 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
txClient = existingTx.Client()
|
||||
} else {
|
||||
txClient = r.client
|
||||
}
|
||||
}
|
||||
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
|
||||
if err != nil {
|
||||
@@ -236,7 +244,9 @@ func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client
|
||||
).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
if !isSQLNoRowsError(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
@@ -304,7 +314,11 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
txClient = r.client
|
||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
txClient = existingTx.Client()
|
||||
} else {
|
||||
txClient = r.client
|
||||
}
|
||||
}
|
||||
|
||||
identityIDs, err := txClient.AuthIdentity.Query().
|
||||
@@ -707,12 +721,16 @@ func userEmailLookupPredicate(email string) predicate.User {
|
||||
|
||||
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return client.UserAllowedGroup.Create().
|
||||
err := client.UserAllowedGroup.Create().
|
||||
SetUserID(userID).
|
||||
SetGroupID(groupID).
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
if isSQLNoRowsError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
@@ -812,6 +830,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
if isSQLNoRowsError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() {
|
||||
s.Require().Equal("updated", updated.Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() {
|
||||
user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"})
|
||||
|
||||
identityCount, err := s.client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.UserIDEQ(user.ID),
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ("update-existing-identity@test.com"),
|
||||
).
|
||||
Count(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1, identityCount)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
got.Username = "updated"
|
||||
s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows")
|
||||
|
||||
updated, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("updated", updated.Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
|
||||
|
||||
|
||||
@@ -277,7 +277,9 @@ func ensureBoundEmailAuthIdentityWithClient(
|
||||
).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
if !isSQLNoRowsError(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
|
||||
@@ -916,6 +916,11 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, s
|
||||
).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
if isSQLNoRowsError(err) {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
14
backend/internal/service/sql_errors.go
Normal file
14
backend/internal/service/sql_errors.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isSQLNoRowsError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
|
||||
}
|
||||
Reference in New Issue
Block a user