Skip to content

Commit

Permalink
Merge branch 'main' into jd/chore/rename_mean_std
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Jun 19, 2024
2 parents 5dc1e32 + 2a63226 commit 3436af5
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/careamics/model_io/model_io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from typing import Tuple, Union

from torch import load
import torch

from careamics.config import Configuration
from careamics.lightning_module import CAREamicsModule
Expand Down Expand Up @@ -64,7 +64,10 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configura
If the checkpoint file does not contain hyper parameters (configuration).
"""
# load checkpoint
checkpoint: dict = load(path)
# here we might run into issues between devices
# see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint: dict = torch.load(path, map_location=device)

# attempt to load configuration
try:
Expand Down

0 comments on commit 3436af5

Please sign in to comment.