Fix user profile writes on postgres conflicts
This commit is contained in:
@@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea
|
|||||||
|
|
||||||
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
return client.AnnouncementRead.Create().
|
err := client.AnnouncementRead.Create().
|
||||||
SetAnnouncementID(announcementID).
|
SetAnnouncementID(announcementID).
|
||||||
SetUserID(userID).
|
SetUserID(userID).
|
||||||
SetReadAt(readAt).
|
SetReadAt(readAt).
|
||||||
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
||||||
DoNothing().
|
DoNothing().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
|
if isSQLNoRowsError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
||||||
|
|||||||
@@ -392,6 +392,31 @@ func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_Reassi
|
|||||||
s.Require().Nil(reloadedFirst.IdentityID)
|
s.Require().Nil(reloadedFirst.IdentityID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() {
|
||||||
|
user := s.mustCreateUser("avatar-only-update")
|
||||||
|
|
||||||
|
model, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(model)
|
||||||
|
|
||||||
|
err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
|
||||||
|
_, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{
|
||||||
|
StorageProvider: "remote_url",
|
||||||
|
URL: "https://cdn.example.com/avatar.png",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.repo.Update(txCtx, model)
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(avatar)
|
||||||
|
s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
|
func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
|
||||||
user := s.mustCreateUser("avatar")
|
user := s.mustCreateUser("avatar")
|
||||||
|
|
||||||
|
|||||||
@@ -56,8 +56,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
txClient = tx.Client()
|
txClient = tx.Client()
|
||||||
} else {
|
} else {
|
||||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||||
txClient = r.client
|
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||||
|
txClient = existingTx.Client()
|
||||||
|
} else {
|
||||||
|
txClient = r.client
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := txClient.User.Create().
|
created, err := txClient.User.Create().
|
||||||
@@ -154,8 +158,12 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
txClient = tx.Client()
|
txClient = tx.Client()
|
||||||
} else {
|
} else {
|
||||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||||
txClient = r.client
|
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||||
|
txClient = existingTx.Client()
|
||||||
|
} else {
|
||||||
|
txClient = r.client
|
||||||
|
}
|
||||||
}
|
}
|
||||||
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
|
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -236,7 +244,9 @@ func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client
|
|||||||
).
|
).
|
||||||
DoNothing().
|
DoNothing().
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return err
|
if !isSQLNoRowsError(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := client.AuthIdentity.Query().
|
identity, err := client.AuthIdentity.Query().
|
||||||
@@ -304,7 +314,11 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
txClient = tx.Client()
|
txClient = tx.Client()
|
||||||
} else {
|
} else {
|
||||||
txClient = r.client
|
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||||
|
txClient = existingTx.Client()
|
||||||
|
} else {
|
||||||
|
txClient = r.client
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identityIDs, err := txClient.AuthIdentity.Query().
|
identityIDs, err := txClient.AuthIdentity.Query().
|
||||||
@@ -707,12 +721,16 @@ func userEmailLookupPredicate(email string) predicate.User {
|
|||||||
|
|
||||||
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
return client.UserAllowedGroup.Create().
|
err := client.UserAllowedGroup.Create().
|
||||||
SetUserID(userID).
|
SetUserID(userID).
|
||||||
SetGroupID(groupID).
|
SetGroupID(groupID).
|
||||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||||
DoNothing().
|
DoNothing().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
|
if isSQLNoRowsError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||||
@@ -812,6 +830,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
|
|||||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||||
DoNothing().
|
DoNothing().
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
|
if isSQLNoRowsError(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() {
|
|||||||
s.Require().Equal("updated", updated.Username)
|
s.Require().Equal("updated", updated.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() {
|
||||||
|
user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"})
|
||||||
|
|
||||||
|
identityCount, err := s.client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.UserIDEQ(user.ID),
|
||||||
|
authidentity.ProviderTypeEQ("email"),
|
||||||
|
authidentity.ProviderKeyEQ("email"),
|
||||||
|
authidentity.ProviderSubjectEQ("update-existing-identity@test.com"),
|
||||||
|
).
|
||||||
|
Count(s.ctx)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(1, identityCount)
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
got.Username = "updated"
|
||||||
|
s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows")
|
||||||
|
|
||||||
|
updated, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal("updated", updated.Username)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UserRepoSuite) TestDelete() {
|
func (s *UserRepoSuite) TestDelete() {
|
||||||
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
|
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
|
||||||
|
|
||||||
|
|||||||
@@ -277,7 +277,9 @@ func ensureBoundEmailAuthIdentityWithClient(
|
|||||||
).
|
).
|
||||||
DoNothing().
|
DoNothing().
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return err
|
if !isSQLNoRowsError(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := client.AuthIdentity.Query().
|
identity, err := client.AuthIdentity.Query().
|
||||||
|
|||||||
@@ -916,6 +916,11 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, s
|
|||||||
).
|
).
|
||||||
DoNothing().
|
DoNothing().
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
|
if isSQLNoRowsError(err) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|||||||
14
backend/internal/service/sql_errors.go
Normal file
14
backend/internal/service/sql_errors.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isSQLNoRowsError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user