From a3e829ec74dbd5747cd0a1e6c0f2134f26b3fe2d Mon Sep 17 00:00:00 2001 From: veegalinova Date: Fri, 17 Jan 2025 15:06:52 +0100 Subject: [PATCH] fix: prevent noise model from moving to CPU during mode changes --- src/careamics/models/lvae/noise_models.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/careamics/models/lvae/noise_models.py b/src/careamics/models/lvae/noise_models.py index 33e6a36ae..e89814de3 100644 --- a/src/careamics/models/lvae/noise_models.py +++ b/src/careamics/models/lvae/noise_models.py @@ -265,7 +265,7 @@ class GaussianMixtureNoiseModel(nn.Module): # TODO training a NM relies on getting a clean data(N2V e.g,) def __init__(self, config: GaussianMixtureNMConfig) -> None: super().__init__() - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device("cpu") if config.path is not None: params = np.load(config.path) @@ -319,10 +319,8 @@ def _set_model_mode(self, mode: str) -> None: """Move parameters to the device and set weights' requires_grad depending on the mode""" if mode == "train": self.weight.requires_grad = True - self.to_device(self.device) else: self.weight.requires_grad = False - self.to_device(torch.device("cpu")) def polynomial_regressor( self, weight_params: torch.Tensor, signals: torch.Tensor @@ -548,6 +546,8 @@ def fit( Upper percentile for clipping. Default is 100. """ self._set_model_mode(mode="train") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to_device(device) optimizer = torch.optim.Adam([self.weight], lr=learning_rate) sig_obs_pairs = self.get_signal_observation_pairs( @@ -589,6 +589,7 @@ def fit( counter += 1 self._set_model_mode(mode="prediction") + self.to_device(torch.device("cpu")) print("===================\n") return train_losses