diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index a5b9cd7f..7d2d3b70 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -208,20 +208,23 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount account.Status = *req.Status } - if err := s.accountRepo.Update(ctx, account); err != nil { - return nil, fmt.Errorf("update account: %w", err) - } - - // 更新分组绑定 + // 先验证分组是否存在(在任何写操作之前) if req.GroupIDs != nil { - // 验证分组是否存在 for _, groupID := range *req.GroupIDs { _, err := s.groupRepo.GetByID(ctx, groupID) if err != nil { return nil, fmt.Errorf("get group: %w", err) } } + } + // 执行更新 + if err := s.accountRepo.Update(ctx, account); err != nil { + return nil, fmt.Errorf("update account: %w", err) + } + + // 绑定分组 + if req.GroupIDs != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil { return nil, fmt.Errorf("bind groups: %w", err) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index f1eb0fc6..db207ce5 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -652,11 +652,20 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.Status = input.Status } + // 先验证分组是否存在(在任何写操作之前) + if input.GroupIDs != nil { + for _, groupID := range *input.GroupIDs { + if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { + return nil, fmt.Errorf("get group: %w", err) + } + } + } + if err := s.accountRepo.Update(ctx, account); err != nil { return nil, err } - // 更新分组绑定 + // 绑定分组 if input.GroupIDs != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil { return nil, err