From 8f03c885111f602252c8b1e969a0318792898c01 Mon Sep 17 00:00:00 2001
From: Vera Galinova <32124316+veegalinova@users.noreply.github.com>
Date: Thu, 16 Jan 2025 14:03:50 +0100
Subject: [PATCH] Feature: Sampling from noise model, noise model refactoring
 (#340)

### Description

- **What**: Added a new method `sample_observation_from_signal` to the `GaussianMixtureNoiseModel` class, refactored the `GaussianMixtureNoiseModel` initialization and tensor device allocation, and added a noise model plotting function.
- **Why**: The `sample_observation_from_signal` method is needed for the AI4Life project I am working on. The new tensor device allocation logic for `GaussianMixtureNoiseModel` was necessary because of a weight-copying bug during training.
- **How**:
  1. The `sample_observation_from_signal` method of the `GaussianMixtureNoiseModel` class creates a noisy image from a clean input "signal" image using the learned noise model statistics. It predicts the means, standard deviations, and component probabilities of the Gaussian mixture for every pixel in "signal", selects a Gaussian component per pixel according to the predicted probabilities, and samples from the selected Gaussian with the predicted mean and standard deviation (see the sketch after this list).
  2. The `GaussianMixtureNoiseModel` parameters are now moved to the GPU before training and back to the CPU after training finishes. The `weight` parameter's `requires_grad` is set to `True` before training begins, and the weight is detached after training is completed.
  3. Added a slightly refactored version of the `plot_probability_distribution` function to the utils folder and added the matplotlib dependency.
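To illustrate point 1, here is a standalone sketch of the per-pixel component selection and sampling. The array names and shapes are illustrative stand-ins, not the PR's exact internals:

```python
import numpy as np

rng = np.random.default_rng(0)
h, w, n_gaussian = 2, 3, 3

# Stand-in per-pixel GMM parameters; in the PR these come from
# get_gaussian_parameters. Shapes are (n_gaussian, h, w).
means = rng.normal(size=(n_gaussian, h, w))
stds = rng.uniform(0.1, 1.0, size=(n_gaussian, h, w))
alphas = rng.dirichlet(np.ones(n_gaussian), size=(h, w)).transpose(2, 0, 1)

# Pick one component per pixel by inverting the CDF of the mixing weights.
u = rng.random((1, h, w))
component = np.argmax(u < np.cumsum(alphas, axis=0), axis=0, keepdims=True)

# Sample from the selected Gaussian at every pixel.
mu = np.take_along_axis(means, component, axis=0).squeeze(0)
sigma = np.take_along_axis(stds, component, axis=0).squeeze(0)
observation = rng.normal(mu, sigma)
```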
### Changes Made

- **Added**:
  1. `sample_observation_from_signal` method of `GaussianMixtureNoiseModel`
  2. `_set_model_mode` functionality to move parameters between GPU and CPU for training
  3. A new file `utils/plotting.py` with a `plot_noise_model_probability_distribution` function for noise model visualization
  4. The `matplotlib` dependency
- **Modified**:
  1. Refactored the `GaussianMixtureNoiseModel` class initialization
  2. Added type annotations and missing docstrings, and adjusted some names to follow Python naming conventions
  3. Refactored the `create_histogram` function

### Breaking changes

1. Added the `matplotlib` dependency to the package.
2. The `create_histogram` function's output is now slightly numerically different from before, because the rewrite fixes a logical error in the original code.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---
 pyproject.toml                            |   1 +
 src/careamics/models/lvae/likelihoods.py  |   2 +
 src/careamics/models/lvae/noise_models.py | 496 ++++++++++++----------
 src/careamics/utils/plotting.py           |  78 ++++
 tests/models/lvae/test_noise_model.py     |  95 ++++-
 5 files changed, 446 insertions(+), 226 deletions(-)
 create mode 100644 src/careamics/utils/plotting.py

diff --git a/pyproject.toml b/pyproject.toml
index 8d2a6af52..0649284fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,6 +55,7 @@ dependencies = [
     'scikit-image<=0.25.0',
     'zarr<3.0.0',
     'pillow<=11.1.0',
+    'matplotlib<=3.9.0'
 ]

 [project.optional-dependencies]
diff --git a/src/careamics/models/lvae/likelihoods.py b/src/careamics/models/lvae/likelihoods.py
index 51c5fbef2..a38b0dbb3 100644
--- a/src/careamics/models/lvae/likelihoods.py
+++ b/src/careamics/models/lvae/likelihoods.py
@@ -324,6 +324,8 @@ def _set_params_to_same_device_as(
         if self.data_mean.device != correct_device_tensor.device:
             self.data_mean = self.data_mean.to(correct_device_tensor.device)
             self.data_std = self.data_std.to(correct_device_tensor.device)
+        if correct_device_tensor.device != self.noiseModel.device:
+            self.noiseModel.to_device(correct_device_tensor.device)

     def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
         return x, None
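The two lines added above let the likelihood pull the noise model onto whatever device the prediction tensor lives on. A minimal sketch of the device plumbing using the classes touched by this PR (it assumes the config defaults exercised by the new tests, i.e. `n_gaussian=1` with default signal bounds):

```python
import torch

from careamics.config import GaussianMixtureNMConfig
from careamics.models.lvae.noise_models import (
    GaussianMixtureNoiseModel,
    MultiChannelNoiseModel,
)

# Two-channel noise model built from default configs.
configs = [
    GaussianMixtureNMConfig(model_type="GaussianMixtureNoiseModel", n_gaussian=1)
    for _ in range(2)
]
multi_nm = MultiChannelNoiseModel([GaussianMixtureNoiseModel(c) for c in configs])

# to_device recurses into each per-channel model so weights and buffers
# move together; it falls back to the CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
multi_nm.to_device(device)
assert multi_nm.nmodel_0.device == device
```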
diff --git a/src/careamics/models/lvae/noise_models.py b/src/careamics/models/lvae/noise_models.py
index 6bd1044c9..33e6a36ae 100644
--- a/src/careamics/models/lvae/noise_models.py
+++ b/src/careamics/models/lvae/noise_models.py
@@ -3,6 +3,7 @@
 import os
 from typing import TYPE_CHECKING, Optional

+from numpy.typing import NDArray
 import numpy as np
 import torch
 import torch.nn as nn
@@ -13,63 +14,59 @@
 # TODO this module shouldn't be in lvae folder


-def create_histogram(bins, min_val, max_val, observation, signal):
+def create_histogram(
+    bins: int, min_val: float, max_val: float, observation: NDArray, signal: NDArray
+) -> NDArray:
     """
     Creates a 2D histogram from 'observation' and 'signal'.

     Parameters
     ----------
-    bins: int
-        The number of bins in x and y. The total number of 2D bins is 'bins'**2.
-    min_val: float
-        the lower bound of the lowest bin in x and y.
-    max_val: float
-        the highest bound of the highest bin in x and y.
-    observation: numpy array
-        A 3D numpy array that is interpretted as a stack of 2D images.
-        The number of images has to be divisible by the number of images in 'signal'.
-        It is assumed that n subsequent images in observation belong to one image image in 'signal'.
-    signal: numpy array
-        A 3D numpy array that is interpretted as a stack of 2D images.
+    bins : int
+        Number of bins in x and y.
+    min_val : float
+        Lower bound of the lowest bin in x and y.
+    max_val : float
+        Upper bound of the highest bin in x and y.
+    observation : np.ndarray
+        3D numpy array (stack of 2D images).
+        Observation.shape[0] must be divisible by signal.shape[0].
+        Assumes that n subsequent images in observation belong to one image in 'signal'.
+    signal : np.ndarray
+        3D numpy array (stack of 2D images).

     Returns
     -------
-    histogram: numpy array
+    histogram : np.ndarray
         A 3D array:
-        'histogram[0,...]' holds the normalized 2D counts.
-        Each row sums to 1, describing p(x_i|s_i).
-        'histogram[1,...]' holds the lower boundaries of each bin in y.
-        'histogram[2,...]' holds the upper boundaries of each bin in y.
-        The values for x can be obtained by transposing 'histogram[1,...]' and 'histogram[2,...]'.
+        - histogram[0]: Normalized 2D counts. Each row sums to 1, describing p(x_i|s_i).
+        - histogram[1]: Lower boundaries of bins along y.
+        - histogram[2]: Upper boundaries of bins along y.
+        The values for x can be obtained by transposing 'histogram[1]' and 'histogram[2]'.
     """
-    # TODO refactor this function
-    img_factor = int(observation.shape[0] / signal.shape[0])
     histogram = np.zeros((3, bins, bins))
-    ra = [min_val, max_val]
-
-    for i in range(observation.shape[0]):
-        observation_ = observation[i].copy().ravel()
-
-        signal_ = (signal[i // img_factor].copy()).ravel()
-        a = np.histogram2d(signal_, observation_, bins=bins, range=[ra, ra])
-        histogram[0] = histogram[0] + a[0] + 1e-30  # This is for numerical stability
-
-    for i in range(bins):
-        if (
-            np.sum(histogram[0, i, :]) > 1e-20
-        ):  # We exclude empty rows from normalization
-            histogram[0, i, :] /= np.sum(
-                histogram[0, i, :]
-            )  # we normalize each non-empty row
-
-    for i in range(bins):
-        histogram[1, :, i] = a[1][
-            :-1
-        ]  # The lower boundaries of each bin in y are stored in dimension 1
-        histogram[2, :, i] = a[1][
-            1:
-        ]  # The upper boundaries of each bin in y are stored in dimension 2
-        # The accordent numbers for x are just transopsed.
+
+    value_range = [min_val, max_val]
+
+    # Compute mapping factor between observation and signal samples
+    obs_to_signal_shape_factor = int(observation.shape[0] / signal.shape[0])
+
+    # Flatten arrays and align signal values
+    signal_indices = np.arange(observation.shape[0]) // obs_to_signal_shape_factor
+    signal_values = signal[signal_indices].ravel()
+    observation_values = observation.ravel()
+
+    count_histogram, signal_edges, _ = np.histogram2d(
+        signal_values, observation_values, bins=bins, range=[value_range, value_range]
+    )
+
+    # Normalize rows to obtain probabilities
+    row_sums = count_histogram.sum(axis=1, keepdims=True)
+    count_histogram /= np.clip(row_sums, a_min=1e-20, a_max=None)
+
+    histogram[0] = count_histogram
+    histogram[1] = signal_edges[:-1][..., np.newaxis]
+    histogram[2] = signal_edges[1:][..., np.newaxis]

     return histogram
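The refactored `create_histogram` replaces the per-image loop with a single vectorized `histogram2d` call and normalizes rows without adding the spurious `1e-30` offset to every bin — the source of the numerical difference called out in the breaking changes. A standalone sketch of the same normalization, with made-up data:

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.uniform(0, 255, (2, 64, 64))               # stack of clean frames
observation = signal + rng.normal(0, 10, signal.shape)  # one noisy copy each

counts, s_edges, _ = np.histogram2d(
    signal.ravel(), observation.ravel(), bins=32, range=[[0, 255], [0, 255]]
)

# Each row i approximates p(x | s in bin i); empty rows stay at zero because
# the clipped denominator only guards against division by zero.
row_sums = counts.sum(axis=1, keepdims=True)
p_x_given_s = counts / np.clip(row_sums, 1e-20, None)
assert np.allclose(p_x_given_s.sum(axis=1)[row_sums[:, 0] > 0], 1.0)
```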
@@ -111,8 +108,11 @@ def noise_model_factory(
             # TODO train a new model. Config should always be provided?
             if nm.model_type == "GaussianMixtureNoiseModel":
                 # TODO one model for each channel all make this choise inside the model?
-                trained_nm = train_gm_noise_model(nm)
-                noise_models.append(trained_nm)
+                # trained_nm = train_gm_noise_model(nm)
+                # noise_models.append(trained_nm)
+                raise NotImplementedError(
+                    "GaussianMixtureNoiseModel model training is not implemented."
+                )
             else:
                 raise NotImplementedError(
                     f"Model {nm.model_type} is not implemented"
@@ -163,6 +163,8 @@ def __init__(self, nmodels: list[GaussianMixtureNoiseModel]):
             List of noise models, one for each output channel.
         """
         super().__init__()
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         for i, nmodel in enumerate(nmodels):  # TODO refactor this !!!
             if nmodel is not None:
                 self.add_module(
@@ -176,6 +178,13 @@ def __init__(self, nmodels: list[GaussianMixtureNoiseModel]):

         print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")

+    def to_device(self, device: torch.device):
+        self.device = device
+        self.to(device)
+        for ch_idx in range(self._nm_cnt):
+            nmodel = getattr(self, f"nmodel_{ch_idx}")
+            nmodel.to_device(device)
+
     def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
         """Compute the likelihood of observations given signals for each channel.

@@ -212,28 +221,6 @@ def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:

         return torch.cat(ll_list, dim=1)

-# TODO: is this needed?
-def fastShuffle(series, num):
-    """_summary_.
-
-    Parameters
-    ----------
-    series : _type_
-        _description_
-    num : _type_
-        _description_
-
-    Returns
-    -------
-    _type_
-        _description_
-    """
-    length = series.shape[0]
-    for _ in range(num):
-        series = series[np.random.permutation(length), :]
-    return series
-
-
 class GaussianMixtureNoiseModel(nn.Module):
     """Define a noise model parameterized as a mixture of gaussians.

@@ -276,166 +263,178 @@ class GaussianMixtureNoiseModel(nn.Module):
     """

     # TODO training a NM relies on getting a clean data(N2V e.g,)
-    def __init__(self, config: GaussianMixtureNMConfig):
+    def __init__(self, config: GaussianMixtureNMConfig) -> None:
         super().__init__()
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        if config.path is None:
-            self.mode = "train"
-            # TODO this is (probably) to train a nm. We leave it for later refactoring
-            weight = config.weight
-            n_gaussian = config.n_gaussian
-            n_coeff = config.n_coeff
-            min_signal = torch.Tensor([config.min_signal])
-            max_signal = torch.Tensor([config.max_signal])
-            # TODO min_sigma cant be None ?
-            self.min_sigma = config.min_sigma
-            if weight is None:
-                weight = torch.randn(n_gaussian * 3, n_coeff)
-                weight[n_gaussian : 2 * n_gaussian, 1] = (
-                    torch.log(max_signal - min_signal).float().to(self.device)
-                )
-                weight.requires_grad = True
-
-            self.n_gaussian = weight.shape[0] // 3
-            self.n_coeff = weight.shape[1]
-            self.weight = weight
-            self.min_signal = torch.Tensor([min_signal]).to(self.device)
-            self.max_signal = torch.Tensor([max_signal]).to(self.device)
-            self.tol = torch.tensor([1e-10]).to(self.device)
-            # TODO refactor to train on CPU!
-        else:
+
+        if config.path is not None:
             params = np.load(config.path)
-            self.mode = "inference"  # TODO better name?
+        else:
+            params = config.model_dump(exclude_none=True)
+
+        min_sigma = torch.tensor(params["min_sigma"])
+        min_signal = torch.tensor(params["min_signal"])
+        max_signal = torch.tensor(params["max_signal"])
+        self.register_buffer("min_signal", min_signal)
+        self.register_buffer("max_signal", max_signal)
+        self.register_buffer("min_sigma", min_sigma)
+        self.register_buffer("tolerance", torch.tensor([1e-10]))
+
+        if "trained_weight" in params:
+            weight = torch.tensor(params["trained_weight"])
+        elif "weight" in params and params["weight"] is not None:
+            weight = torch.tensor(params["weight"])
+        else:
+            weight = self._initialize_weights(
+                params["n_gaussian"], params["n_coeff"], max_signal, min_signal
+            )

-            self.min_signal = torch.Tensor(params["min_signal"])
-            self.max_signal = torch.Tensor(params["max_signal"])
+        self.n_gaussian = weight.shape[0] // 3
+        self.n_coeff = weight.shape[1]

-            self.weight = torch.Tensor(params["trained_weight"])
-            self.min_sigma = params["min_sigma"].item()
-            self.n_gaussian = self.weight.shape[0] // 3  # TODO why // 3 ?
-            self.n_coeff = self.weight.shape[1]
-            self.tol = torch.Tensor([1e-10])
-            self.min_signal = torch.Tensor([self.min_signal])
-            self.max_signal = torch.Tensor([self.max_signal])
+        self.register_parameter("weight", nn.Parameter(weight))
+        self._set_model_mode(mode="prediction")

         print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")
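A note on the refactored initialization: `weight` stacks three groups of `n_gaussian` rows — mean, variance-related, and mixing-weight coefficients — each with `n_coeff` polynomial coefficients, which is why the code divides `weight.shape[0]` by 3. A hypothetical sketch mirroring `_initialize_weights`:

```python
import torch

n_gaussian, n_coeff = 2, 2
min_signal = torch.tensor([0.0])
max_signal = torch.tensor([255.0])

# Rows [0, n) hold mean coefficients, [n, 2n) variance-related coefficients,
# and [2n, 3n) mixing-weight coefficients.
weight = torch.randn(n_gaussian * 3, n_coeff)

# Seed the variance-related rows from the signal range so the initial sigmas
# are on the scale of the data (as the PR's code does).
weight[n_gaussian : 2 * n_gaussian, 1] = torch.log(max_signal - min_signal).float()

assert weight.shape[0] // 3 == n_gaussian  # why the code divides by 3
```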
-    def polynomialRegressor(self, weightParams, signals):
-        """Combines `weightParams` and signal `signals` to regress for the gaussian parameter values.
+    def _initialize_weights(
+        self,
+        n_gaussian: int,
+        n_coeff: int,
+        max_signal: torch.Tensor,
+        min_signal: torch.Tensor,
+    ) -> torch.Tensor:
+        """Create random weight initialization."""
+        weight = torch.randn(n_gaussian * 3, n_coeff)
+        weight[n_gaussian : 2 * n_gaussian, 1] = torch.log(
+            max_signal - min_signal
+        ).float()
+        return weight
+
+    def to_device(self, device: torch.device):
+        self.device = device
+        self.to(device)
+
+    def _set_model_mode(self, mode: str) -> None:
+        """Move parameters to the device and set weights' requires_grad depending on the mode"""
+        if mode == "train":
+            self.weight.requires_grad = True
+            self.to_device(self.device)
+        else:
+            self.weight.requires_grad = False
+            self.to_device(torch.device("cpu"))
+
+    def polynomial_regressor(
+        self, weight_params: torch.Tensor, signals: torch.Tensor
+    ) -> torch.Tensor:
+        """Combines `weight_params` and signal `signals` to regress for the gaussian parameter values.

         Parameters
         ----------
-        weightParams : torch.cuda.FloatTensor
+        weight_params : Tensor
             Corresponds to specific rows of the `self.weight`
-        signals : torch.cuda.FloatTensor
+        signals : Tensor
             Signals

         Returns
         -------
-        value : torch.cuda.FloatTensor
+        value : Tensor
             Corresponds to either of mean, standard deviation or weight, evaluated at `signals`
         """
-        value = 0
-        for i in range(weightParams.shape[0]):
-            value += weightParams[i] * (
+        value = torch.zeros_like(signals)
+        for i in range(weight_params.shape[0]):
+            value += weight_params[i] * (
                 ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i
             )
         return value
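Each Gaussian parameter is regressed as a polynomial in the signal normalized to [0, 1]. A minimal standalone equivalent of `polynomial_regressor` (written as a free function under assumed fixed signal bounds):

```python
import torch

def polynomial(signals: torch.Tensor, coeffs: torch.Tensor,
               s_min: float, s_max: float) -> torch.Tensor:
    # value(s) = sum_i coeffs[i] * ((s - s_min) / (s_max - s_min)) ** i
    s_norm = (signals - s_min) / (s_max - s_min)
    return sum(c * s_norm**i for i, c in enumerate(coeffs))

signals = torch.linspace(0.0, 255.0, 5)
mu = polynomial(signals, torch.tensor([1.0, 0.5, 0.25]), 0.0, 255.0)
print(mu)  # a smooth, signal-dependent parameter curve
```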
-    def normalDens(self, x, m_=0.0, std_=None):
-        """Evaluates the normal probability density at `x` given the mean `m` and standard deviation `std`.
+    def normal_density(
+        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Evaluates the normal probability density at `x` given the mean `mean` and standard deviation `std`.

         Parameters
         ----------
-        x: torch.cuda.FloatTensor
+        x: Tensor
             Observations
-        m_: torch.cuda.FloatTensor
+        mean: Tensor
             Mean
-        std_: torch.cuda.FloatTensor
+        std: Tensor
             Standard-deviation

         Returns
         -------
-        tmp: torch.cuda.FloatTensor
-            Normal probability density of `x` given `m_` and `std_`
-
+        tmp: Tensor
+            Normal probability density of `x` given `mean` and `std`
         """
-        tmp = -((x - m_) ** 2)
-        tmp = tmp / (2.0 * std_ * std_)
+        tmp = -((x - mean) ** 2)
+        tmp = tmp / (2.0 * std * std)
         tmp = torch.exp(tmp)
-        tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
-        # print(tmp.min().item(), tmp.mean().item(), tmp.max().item(), tmp.shape)
+        tmp = tmp / torch.sqrt((2.0 * np.pi) * std * std)
         return tmp

-    def likelihood(self, observations, signals):
-        """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.
+    def likelihood(
+        self, observations: torch.Tensor, signals: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.

         Parameters
         ----------
-        observations : torch.cuda.FloatTensor
+        observations : Tensor
             Noisy observations
-        signals : torch.cuda.FloatTensor
+        signals : Tensor
             Underlying signals

         Returns
         -------
-        value :p + self.tol
+        value : torch.Tensor
             Likelihood of observations given the signals and the GMM noise model
         """
-        if self.mode != "train":
-            signals = signals.cpu()
-            observations = observations.cpu()
-            self.weight = self.weight.to(signals.device)
-            self.min_signal = self.min_signal.to(signals.device)
-            self.max_signal = self.max_signal.to(signals.device)
-            self.tol = self.tol.to(signals.device)
-
-        gaussianParameters = self.getGaussianParameters(signals)
+        gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
         p = 0
         for gaussian in range(self.n_gaussian):
             p += (
-                self.normalDens(
+                self.normal_density(
                     observations,
-                    gaussianParameters[gaussian],
-                    gaussianParameters[self.n_gaussian + gaussian],
+                    gaussian_parameters[gaussian],
+                    gaussian_parameters[self.n_gaussian + gaussian],
                 )
-                * gaussianParameters[2 * self.n_gaussian + gaussian]
+                * gaussian_parameters[2 * self.n_gaussian + gaussian]
             )
-        return p + self.tol
+        return p + self.tolerance

-    def getGaussianParameters(self, signals):
-        """Returns the noise model for given signals
+    def get_gaussian_parameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
+        """
+        Returns the noise model for given signals

         Parameters
         ----------
-        signals : torch.cuda.FloatTensor
+        signals : Tensor
             Underlying signals

         Returns
         -------
-        noiseModel: list of torch.cuda.FloatTensor
+        noise_model: list of Tensor
             Contains a list of `mu`, `sigma` and `alpha` for the `signals`
         """
-        noiseModel = []
+        noise_model = []
         mu = []
         sigma = []
         alpha = []
         kernels = self.weight.shape[0] // 3
         for num in range(kernels):
-            mu.append(self.polynomialRegressor(self.weight[num, :], signals))
-            # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W))
+            mu.append(self.polynomial_regressor(self.weight[num, :], signals))
             expval = torch.exp(self.weight[kernels + num, :])
-            # self.maxval = max(self.maxval, expval.max().item())
-            sigmaTemp = self.polynomialRegressor(expval, signals)
-            sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
-            sigma.append(torch.sqrt(sigmaTemp))
+            sigma_temp = self.polynomial_regressor(expval, signals)
+            sigma_temp = torch.clamp(sigma_temp, min=self.min_sigma)
+            sigma.append(torch.sqrt(sigma_temp))

             expval = torch.exp(
-                self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
-                + self.tol
+                self.polynomial_regressor(self.weight[2 * kernels + num, :], signals)
+                + self.tolerance
             )
             alpha.append(expval)
@@ -459,15 +458,30 @@ def getGaussianParameters(self, signals):
             mu[ker] = mu[ker] - sum_means + signals

         for i in range(kernels):
-            noiseModel.append(mu[i])
+            noise_model.append(mu[i])
         for j in range(kernels):
-            noiseModel.append(sigma[j])
+            noise_model.append(sigma[j])
         for k in range(kernels):
-            noiseModel.append(alpha[k])
+            noise_model.append(alpha[k])
+
+        return noise_model

-        return noiseModel
+    @staticmethod
+    def _fast_shuffle(series: torch.Tensor, num: int) -> torch.Tensor:
+        """Shuffle the inputs randomly num times"""
+        length = series.shape[0]
+        for _ in range(num):
+            idx = torch.randperm(length)
+            series = series[idx, :]
+        return series
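The `mu[ker] = mu[ker] - sum_means + signals` step above re-centers the component means so that the mixture's expected value equals the clean signal. A small numerical check of that identity, using made-up, per-pixel-normalized alphas:

```python
import torch

signals = torch.tensor([10.0, 20.0])
alpha = torch.tensor([[0.3, 0.5], [0.7, 0.5]])  # normalized per pixel
mu = torch.tensor([[9.0, 18.0], [12.0, 23.0]])

# Shift every component mean by the same per-pixel offset.
mixture_mean = (alpha * mu).sum(dim=0)
mu_centered = mu - mixture_mean + signals  # broadcast over components

# The calibrated mixture is unbiased: E[x | s] == s.
assert torch.allclose((alpha * mu_centered).sum(dim=0), signals)
```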
-    def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
+    def get_signal_observation_pairs(
+        self,
+        signal: NDArray,
+        observation: NDArray,
+        lower_clip: float,
+        upper_clip: float,
+    ) -> torch.Tensor:
         """Returns the Signal-Observation pixel intensities as a two-column array

         Parameters
@@ -476,19 +490,18 @@ def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
             Clean Signal Data
         observation: numpy array
             Noisy observation Data
-        lowerClip: float
+        lower_clip: float
             Lower percentile bound for clipping.
-        upperClip: float
+        upper_clip: float
             Upper percentile bound for clipping.

         Returns
         -------
-        noiseModel: list of torch floats
+        noise_model: list of torch floats
             Contains a list of `mu`, `sigma` and `alpha` for the `signals`
-
         """
-        lb = np.percentile(signal, lowerClip)
-        ub = np.percentile(signal, upperClip)
+        lb = np.percentile(signal, lower_clip)
+        ub = np.percentile(signal, upper_clip)
         stepsize = observation[0].size
         n_observations = observation.shape[0]
         n_signals = signal.shape[0]
@@ -501,19 +514,20 @@ def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
         sig_obs_pairs = sig_obs_pairs[
             (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
         ]
-        return fastShuffle(sig_obs_pairs, 2)
+        sig_obs_pairs = sig_obs_pairs.astype(np.float32)
+        sig_obs_pairs = torch.from_numpy(sig_obs_pairs)
+        return self._fast_shuffle(sig_obs_pairs, 2)
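For intuition, a standalone numpy sketch of what `get_signal_observation_pairs` builds before shuffling (shapes and clipping percentiles are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.uniform(0, 255, (2, 8, 8))      # n_signals clean frames
observation = np.repeat(signal, 2, axis=0)   # 2 noisy frames per clean frame
observation = observation + rng.normal(0, 5, observation.shape)

# Two-column array of (signal, observation) pixel pairs.
factor = observation.shape[0] // signal.shape[0]
pairs = np.stack(
    [np.repeat(signal, factor, axis=0).ravel(), observation.ravel()], axis=1
)

# Percentile clipping on the signal column, as in get_signal_observation_pairs.
lb, ub = np.percentile(signal, 0.5), np.percentile(signal, 99.5)
pairs = pairs[(pairs[:, 0] > lb) & (pairs[:, 0] < ub)]
```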
- - """ - sig_obs_pairs = self.getSignalObservationPairs( - signal, observation, lowerClip, upperClip - ) - counter = 0 + self._set_model_mode(mode="train") optimizer = torch.optim.Adam([self.weight], lr=learning_rate) - loss_arr = [] + sig_obs_pairs = self.get_signal_observation_pairs( + signal, observation, lower_clip, upper_clip + ) + + train_losses = [] + counter = 0 for t in range(n_epochs): - if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]: + if (counter + 1) * batch_size >= sig_obs_pairs.shape[0]: counter = 0 - sig_obs_pairs = fastShuffle(sig_obs_pairs, 1) + sig_obs_pairs = self._fast_shuffle(sig_obs_pairs, 1) batch_vectors = sig_obs_pairs[ - counter * batchSize : (counter + 1) * batchSize, : + counter * batch_size : (counter + 1) * batch_size, : ] - observations = batch_vectors[:, 1].astype(np.float32) - signals = batch_vectors[:, 0].astype(np.float32) - observations = ( - torch.from_numpy(observations.astype(np.float32)) - .float() - .to(self.device) - ) - signals = torch.from_numpy(signals).float().to(self.device) + observations = batch_vectors[:, 1].to(self.device) + signals = batch_vectors[:, 0].to(self.device) p = self.likelihood(observations, signals) - jointLoss = torch.mean(-torch.log(p)) - loss_arr.append(jointLoss.item()) + joint_loss = torch.mean(-torch.log(p)) + train_losses.append(joint_loss.item()) + if self.weight.isnan().any() or self.weight.isinf().any(): print( "NaN or Inf detected in the weights. Aborting training at epoch: ", @@ -575,19 +580,76 @@ def fit( break if t % 100 == 0: - print(t, np.mean(loss_arr)) + last_losses = train_losses[-100:] + print(t, np.mean(last_losses)) optimizer.zero_grad() - jointLoss.backward() + joint_loss.backward() optimizer.step() counter += 1 - self.trained_weight = self.weight.cpu().detach().numpy() - self.min_signal = self.min_signal.cpu().detach().numpy() - self.max_signal = self.max_signal.cpu().detach().numpy() + self._set_model_mode(mode="prediction") print("===================\n") + return train_losses + + def sample_observation_from_signal(self, signal: NDArray) -> NDArray: + """ + Sample an instance of observation based on an input signal using a + learned Gaussian Mixture Model. For each pixel in the input signal, + samples a corresponding noisy pixel. + + Parameters + ---------- + signal: numpy array + Clean 2D signal data. + + Returns + ------- + observation: numpy array + An instance of noisy observation data based on the input signal. + """ + assert len(signal.shape) == 2, "Only 2D inputs are supported." 
+    def sample_observation_from_signal(self, signal: NDArray) -> NDArray:
+        """
+        Sample an instance of observation based on an input signal using a
+        learned Gaussian Mixture Model. For each pixel in the input signal,
+        samples a corresponding noisy pixel.
+
+        Parameters
+        ----------
+        signal: numpy array
+            Clean 2D signal data.
+
+        Returns
+        -------
+        observation: numpy array
+            An instance of noisy observation data based on the input signal.
+        """
+        assert len(signal.shape) == 2, "Only 2D inputs are supported."
+
+        signal_tensor = torch.from_numpy(signal).to(torch.float32)
+        height, width = signal_tensor.shape
+
+        with torch.no_grad():
+            # Get gaussian parameters for each pixel
+            gaussian_params = self.get_gaussian_parameters(signal_tensor)
+            means = np.array(gaussian_params[: self.n_gaussian])
+            stds = np.array(gaussian_params[self.n_gaussian : self.n_gaussian * 2])
+            alphas = np.array(gaussian_params[self.n_gaussian * 2 :])
+
+            if self.n_gaussian == 1:
+                # Single gaussian case
+                observation = np.random.normal(
+                    loc=means[0], scale=stds[0], size=(height, width)
+                )
+            else:
+                # Multiple gaussians: sample a component for each pixel
+                uniform = np.random.rand(1, height, width)
+                # Compute cumulative probabilities for component selection
+                cumulative_alphas = np.cumsum(
+                    alphas, axis=0
+                )  # Shape: (n_gaussian, height, width)
+                selected_component = np.argmax(
+                    uniform < cumulative_alphas, axis=0, keepdims=True
+                )
+
+                # For every pixel, choose the corresponding gaussian
+                # and get the learned mu and sigma
+                selected_mus = np.take_along_axis(means, selected_component, axis=0)
+                selected_stds = np.take_along_axis(stds, selected_component, axis=0)
+                selected_mus = selected_mus.squeeze(0)
+                selected_stds = selected_stds.squeeze(0)
+
+                # Sample from the normal distribution with learned mu and sigma
+                observation = np.random.normal(
+                    selected_mus, selected_stds, size=(height, width)
+                )
+        return observation

-    def save(self, path: str, name: str):
+    def save(self, path: str, name: str) -> None:
         """Save the trained parameters on the noise model.

         Parameters
@@ -600,9 +662,9 @@ def save(self, path: str, name: str):
         os.makedirs(path, exist_ok=True)
         np.savez(
             os.path.join(path, name),
-            trained_weight=self.trained_weight,
-            min_signal=self.min_signal,
-            max_signal=self.max_signal,
+            trained_weight=self.weight.numpy(),
+            min_signal=self.min_signal.numpy(),
+            max_signal=self.max_signal.numpy(),
             min_sigma=self.min_sigma,
         )
         print("The trained parameters (" + name + ") is saved at location: " + path)
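Continuing the training snippet above, a sketch of the save/load round trip enabled by the refactored `__init__` (directory and file names are illustrative): `save()` writes an `.npz` with `trained_weight`, `min_signal`, `max_signal`, and `min_sigma`, and a config whose `path` points at that file restores the model in prediction mode.

```python
noise_model.save("noise_models", "gmm_noise_model.npz")

# A config with `path` set loads the trained parameters from the .npz file.
restored = GaussianMixtureNoiseModel(
    GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        path="noise_models/gmm_noise_model.npz",
    )
)
```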
+ """ + min_signal = noise_model.min_signal.item() + max_signal = noise_model.max_signal.item() + bin_size = (max_signal - min_signal) / number_of_bins + + query_signal_normalized = signalBinIndex / number_of_bins + query_signal = query_signal_normalized * (max_signal - min_signal) + min_signal + query_signal += bin_size / 2 + query_signal = torch.tensor(query_signal) + + query_observations = torch.arange(min_signal, max_signal, bin_size) + query_observations += bin_size / 2 + + likelihoods = noise_model.likelihood( + observations=query_observations, signals=query_signal + ).numpy() + + plt.figure(figsize=(12, 5)) + if channel: + plt.suptitle(f"Noise model for channel {channel}") + else: + plt.suptitle("Noise model") + + plt.subplot(1, 2, 1) + plt.xlabel("Observation Bin") + plt.ylabel("Signal Bin") + plt.imshow(histogram**0.25, cmap="gray") + plt.axhline(y=signalBinIndex + 0.5, linewidth=5, color="blue", alpha=0.5) + + plt.subplot(1, 2, 2) + plt.plot( + query_observations, + likelihoods, + label="GMM : " + " signal = " + str(np.round(query_signal, 2)), + marker=".", + color="red", + linewidth=2, + ) + plt.xlabel("Observations (x) for signal s = " + str(query_signal)) + plt.ylabel("Probability Density") + plt.title("Probability Distribution P(x|s) at signal =" + str(query_signal)) + plt.legend() diff --git a/tests/models/lvae/test_noise_model.py b/tests/models/lvae/test_noise_model.py index d3d5820ad..87afb59aa 100644 --- a/tests/models/lvae/test_noise_model.py +++ b/tests/models/lvae/test_noise_model.py @@ -3,8 +3,10 @@ import numpy as np import pytest import torch +from scipy.stats import wasserstein_distance from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig +from careamics.models.lvae.likelihoods import NoiseModelLikelihood from careamics.models.lvae.noise_models import ( GaussianMixtureNoiseModel, MultiChannelNoiseModel, @@ -141,18 +143,93 @@ def test_multi_channel_noise_model_likelihood( assert likelihood.shape == inp_shape -@pytest.mark.skip(reason="Need to refac noise model to be able to train on CPU") -def test_gm_noise_model_training(tmp_path): - x = np.random.rand(3) - y = np.random.rand(3) +@pytest.mark.parametrize( + "image_size, max_value, noise_scale", + [ + ([5, 128, 128], 255, 0.1), + ([5, 128, 128], 255, 0.5), + ], +) +def test_gm_noise_model_training(image_size, max_value, noise_scale): + gen = np.random.default_rng(42) + signal_normalized = gen.uniform(0, 1, image_size) + noise = gen.normal(0, noise_scale, image_size) + observation_normalized = signal_normalized + noise + signal = signal_normalized * max_value + observation = observation_normalized * max_value nm_config = GaussianMixtureNMConfig( - model_type="GaussianMixtureNoiseModel", signal=x, observation=y + model_type="GaussianMixtureNoiseModel", + n_gaussian=1, + min_signal=signal.min(), + max_signal=signal.max(), ) + noise_model = GaussianMixtureNoiseModel(nm_config) + training_losses = noise_model.fit( + signal=signal, observation=observation, n_epochs=500 + ) + initial_loss = training_losses[0] + last_loss = training_losses[-1] + # Check if model is training + assert initial_loss > last_loss + + # check if estimated mean and std of a noisy sample are close to real ones + signal_tensor = torch.from_numpy(signal).to(torch.float32) + mus, sigmas, _ = noise_model.get_gaussian_parameters(signal_tensor) + + # learned mean should be close to the mean of the signal + learned_mu = mus.mean() / max_value + real_mu = signal_normalized.mean() + assert np.allclose(learned_mu, real_mu, atol=1e-2) + + 
diff --git a/tests/models/lvae/test_noise_model.py b/tests/models/lvae/test_noise_model.py
index d3d5820ad..87afb59aa 100644
--- a/tests/models/lvae/test_noise_model.py
+++ b/tests/models/lvae/test_noise_model.py
@@ -3,8 +3,10 @@
 import numpy as np
 import pytest
 import torch
+from scipy.stats import wasserstein_distance

 from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
+from careamics.models.lvae.likelihoods import NoiseModelLikelihood
 from careamics.models.lvae.noise_models import (
     GaussianMixtureNoiseModel,
     MultiChannelNoiseModel,
@@ -141,18 +143,93 @@ def test_multi_channel_noise_model_likelihood(
     assert likelihood.shape == inp_shape


-@pytest.mark.skip(reason="Need to refac noise model to be able to train on CPU")
-def test_gm_noise_model_training(tmp_path):
-    x = np.random.rand(3)
-    y = np.random.rand(3)
+@pytest.mark.parametrize(
+    "image_size, max_value, noise_scale",
+    [
+        ([5, 128, 128], 255, 0.1),
+        ([5, 128, 128], 255, 0.5),
+    ],
+)
+def test_gm_noise_model_training(image_size, max_value, noise_scale):
+    gen = np.random.default_rng(42)
+    signal_normalized = gen.uniform(0, 1, image_size)
+    noise = gen.normal(0, noise_scale, image_size)
+    observation_normalized = signal_normalized + noise
+    signal = signal_normalized * max_value
+    observation = observation_normalized * max_value

     nm_config = GaussianMixtureNMConfig(
-        model_type="GaussianMixtureNoiseModel", signal=x, observation=y
+        model_type="GaussianMixtureNoiseModel",
+        n_gaussian=1,
+        min_signal=signal.min(),
+        max_signal=signal.max(),
     )
+    noise_model = GaussianMixtureNoiseModel(nm_config)
+    training_losses = noise_model.fit(
+        signal=signal, observation=observation, n_epochs=500
+    )
+    initial_loss = training_losses[0]
+    last_loss = training_losses[-1]
+    # Check if model is training
+    assert initial_loss > last_loss
+
+    # check if estimated mean and std of a noisy sample are close to real ones
+    signal_tensor = torch.from_numpy(signal).to(torch.float32)
+    mus, sigmas, _ = noise_model.get_gaussian_parameters(signal_tensor)
+
+    # learned mean should be close to the mean of the signal
+    learned_mu = mus.mean() / max_value
+    real_mu = signal_normalized.mean()
+    assert np.allclose(learned_mu, real_mu, atol=1e-2)
+
+    # learned sigma should be close to the noise sigma
+    learned_sigma = sigmas.mean() / max_value
+    noise_image = observation_normalized - signal_normalized
+    real_sigma = noise_image.std()
+    assert np.allclose(learned_sigma, real_sigma, atol=1e-2)
+
+
+@pytest.mark.parametrize("image_size, max_value", [([256, 256], 255)])
+def test_noise_model_sampling(image_size, max_value):
+    gen = np.random.default_rng(42)
+
+    signal = gen.uniform(0, 1, image_size)
+    observation = signal + gen.normal(0, 0.1, signal.shape)
+    signal = signal * max_value
+    observation = observation * max_value

+    nm_config = GaussianMixtureNMConfig(
+        model_type="GaussianMixtureNoiseModel",
+        n_gaussian=1,
+        min_sigma=100,
+        min_signal=signal.min(),
+        max_signal=signal.max(),
+    )
     noise_model = GaussianMixtureNoiseModel(nm_config)
+    noise_model.fit(signal=signal, observation=observation, n_epochs=200)
+
+    sampled_noise_data = noise_model.sample_observation_from_signal(signal)
+    assert sampled_noise_data.shape == signal.shape
+
+    real_noise = observation - signal
+    predicted_noise = sampled_noise_data - signal
+    real_noise = real_noise / max_value
+    predicted_noise = predicted_noise / max_value
+    noise_distribution_difference = wasserstein_distance(
+        real_noise.ravel(), predicted_noise.ravel()
+    )
+    assert noise_distribution_difference < 0.1
+

-    # Test training
-    output = noise_model.train(x, y, n_epochs=2)
-    assert output is not None
-    # TODO do something with output ?
+def test_noise_model_in_likelihood_call():
+    test_input = torch.rand(256, 256)
+    test_target = torch.rand(256, 256)
+
+    nm_config = GaussianMixtureNMConfig(
+        model_type="GaussianMixtureNoiseModel", n_gaussian=1
+    )
+    noise_model = GaussianMixtureNoiseModel(nm_config)
+    likelihood = NoiseModelLikelihood(
+        data_mean=test_input.mean(), data_std=test_input.std(), noise_model=noise_model
+    )
+    log_likelihood, _ = likelihood(test_input, test_target)
+    assert log_likelihood is not None