# Fix: Load checkpoint on the correct device (#146)
### Description

Although CAREamics is primarily intended to run on a GPU, users might want
to use the CPU for some tasks, such as prediction with a trained network.
Loading a checkpoint on a CPU-only machine currently fails when the model
was trained on a GPU: #145.

This PR fixes that issue by detecting the available device before loading
the checkpoint.

- **What**: Pass a `map_location` device to the `torch.load` call (see the
  sketch below).
- **Why**: Fixes #145 so that CPU-only users can load a pre-trained network.
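
For reference, here is a minimal, self-contained sketch of the pattern applied
in this PR (plain PyTorch only; the helper name and toy checkpoint path are
illustrative, not CAREamics code):

```python
# Standalone sketch of the map_location pattern; the helper name and the toy
# checkpoint path are hypothetical and not part of CAREamics.
from pathlib import Path

import torch


def load_checkpoint_on_available_device(path: Path) -> dict:
    """Load a checkpoint onto the GPU if one is available, otherwise the CPU."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # map_location remaps every tensor storage in the checkpoint to `device`,
    # so a GPU-trained checkpoint can be deserialized on a CPU-only machine.
    return torch.load(path, map_location=device)


if __name__ == "__main__":
    # Round trip: save a small state dict, then reload it on the local device.
    path = Path("toy_checkpoint.pth")
    torch.save({"state_dict": {"weight": torch.randn(3, 3)}}, path)
    checkpoint = load_checkpoint_on_available_device(path)
    print(checkpoint["state_dict"]["weight"].device)
```

Without `map_location`, `torch.load` tries to restore each tensor on the device
it was saved from, which raises a `RuntimeError` on a machine without CUDA.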

### Changes Made

- **Modified**: `model_io_utils.py`.

### Related Issues

- Fixes #145.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] Pre-commit passes
jdeschamps authored Jun 19, 2024
1 parent 4227a19 commit 2a63226
Showing 1 changed file with 5 additions and 2 deletions: `src/careamics/model_io/model_io_utils.py`
```diff
@@ -3,7 +3,7 @@
 from pathlib import Path
 from typing import Tuple, Union
 
-from torch import load
+import torch
 
 from careamics.config import Configuration
 from careamics.lightning_module import CAREamicsModule
@@ -64,7 +64,10 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
         If the checkpoint file does not contain hyper parameters (configuration).
     """
     # load checkpoint
-    checkpoint: dict = load(path)
+    # here we might run into issues between devices
+    # see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    checkpoint: dict = torch.load(path, map_location=device)
 
     # attempt to load configuration
     try:
```
