diff --git a/pytorch_pfn_extras/training/extensions/_snapshot.py b/pytorch_pfn_extras/training/extensions/_snapshot.py index fdc218b2..c4c2e7a7 100644 --- a/pytorch_pfn_extras/training/extensions/_snapshot.py +++ b/pytorch_pfn_extras/training/extensions/_snapshot.py @@ -513,11 +513,10 @@ def _add_cleanup_hook(self, writer: writing.Writer) -> None: if self._rank == self._saver_rank: super()._add_cleanup_hook(writer) - def __call__(self, manager: ExtensionsManagerProtocol) -> None: - if self.condition(): - # on distributed environments only the designed rank - # saves the snapshot - if self._rank == self._saver_rank: - self._make_snapshot(manager) - if self._size > 1: - torch.distributed.barrier() # type: ignore[no-untyped-call] + def _make_snapshot(self, manager: ExtensionsManagerProtocol) -> None: + # on distributed environments only the designed rank + # saves the snapshot + if self._rank == self._saver_rank: + super()._make_snapshot(manager) + if self._size > 1: + torch.distributed.barrier() # type: ignore