diff --git a/src/careamics/model_io/model_io_utils.py b/src/careamics/model_io/model_io_utils.py
index 2beb25ec2..987b2fb0b 100644
--- a/src/careamics/model_io/model_io_utils.py
+++ b/src/careamics/model_io/model_io_utils.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 from typing import Tuple, Union
 
-from torch import load
+import torch
 
 from careamics.config import Configuration
 from careamics.lightning_module import CAREamicsModule
@@ -64,7 +64,10 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configura
         If the checkpoint file does not contain hyper parameters (configuration).
     """
     # load checkpoint
-    checkpoint: dict = load(path)
+    # here we might run into issues between devices
+    # see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    checkpoint: dict = torch.load(path, map_location=device)
 
     # attempt to load configuration
     try:
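
A minimal sketch (outside the patch) of the failure mode this change guards against: without map_location, torch.load tries to restore each tensor to the device it was saved from, so a checkpoint written on a CUDA machine fails to load on a CPU-only host. The checkpoint path below is hypothetical.

import torch

# Pick the best available device, mirroring the patched _load_checkpoint.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Without map_location, a CUDA-saved checkpoint raises a RuntimeError on a
# CPU-only machine ("Attempting to deserialize object on a CUDA device
# but torch.cuda.is_available() is False"):
# checkpoint = torch.load("model.ckpt")  # hypothetical path; may raise

# With map_location, tensor storages are remapped to the chosen device:
checkpoint = torch.load("model.ckpt", map_location=device)

# Passing map_location="cpu" instead would remap everything to CPU
# unconditionally, which is fully device-agnostic.
checkpoint_cpu = torch.load("model.ckpt", map_location="cpu")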