diff --git a/src/careamics/models/lvae/noise_models.py b/src/careamics/models/lvae/noise_models.py
index 33e6a36ae..e89814de3 100644
--- a/src/careamics/models/lvae/noise_models.py
+++ b/src/careamics/models/lvae/noise_models.py
@@ -265,7 +265,7 @@ class GaussianMixtureNoiseModel(nn.Module):
     # TODO training a NM relies on getting a clean data(N2V e.g,)
     def __init__(self, config: GaussianMixtureNMConfig) -> None:
         super().__init__()
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cpu")
 
         if config.path is not None:
             params = np.load(config.path)
@@ -319,10 +319,8 @@ def _set_model_mode(self, mode: str) -> None:
         """Move parameters to the device and set weights' requires_grad depending on the mode"""
         if mode == "train":
             self.weight.requires_grad = True
-            self.to_device(self.device)
         else:
             self.weight.requires_grad = False
-            self.to_device(torch.device("cpu"))
 
     def polynomial_regressor(
         self, weight_params: torch.Tensor, signals: torch.Tensor
@@ -548,6 +546,8 @@ def fit(
             Upper percentile for clipping. Default is 100.
         """
         self._set_model_mode(mode="train")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.to_device(device)
         optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
 
         sig_obs_pairs = self.get_signal_observation_pairs(
@@ -589,6 +589,7 @@ def fit(
             counter += 1
 
         self._set_model_mode(mode="prediction")
+        self.to_device(torch.device("cpu"))
         print("===================\n")
         return train_losses
 