diff --git a/src/careamics/lvae_training/calibration.py b/src/careamics/lvae_training/calibration.py index cc1c37a49..5319bb4cd 100644 --- a/src/careamics/lvae_training/calibration.py +++ b/src/careamics/lvae_training/calibration.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Optional import numpy as np import torch @@ -34,9 +34,6 @@ def __init__(self, num_bins: int = 15): self._bins = num_bins self._bin_boundaries = None - def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray: - return np.exp(logvar / 2) - def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray: """Compute the bin boundaries for `num_bins` bins and predicted std values.""" min_std = np.min(predict_std) @@ -104,65 +101,75 @@ def compute_stats( ) rmse_stderr = np.sqrt(stderr) if stderr is not None else None - bin_var = np.mean((std_ch[bin_mask] ** 2)) + bin_var = np.mean(std_ch[bin_mask] ** 2) stats_dict[ch_idx]["rmse"].append(bin_error) stats_dict[ch_idx]["rmse_err"].append(rmse_stderr) stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var)) stats_dict[ch_idx]["bin_count"].append(bin_size) + self.stats_dict = stats_dict return stats_dict + def get_calibrated_factor_for_stdev( + self, + pred: Optional[np.ndarray] = None, + pred_std: Optional[np.ndarray] = None, + target: Optional[np.ndarray] = None, + q_s: float = 0.00001, + q_e: float = 0.99999, + ) -> dict[str, float]: + """Calibrate the uncertainty by multiplying the predicted std with a scalar. -def get_calibrated_factor_for_stdev( - pred: Union[np.ndarray, torch.Tensor], - pred_std: Union[np.ndarray, torch.Tensor], - target: Union[np.ndarray, torch.Tensor], - q_s: float = 0.00001, - q_e: float = 0.99999, - num_bins: int = 30, -) -> dict[str, float]: - """Calibrate the uncertainty by multiplying the predicted std with a scalar. - - Parameters - ---------- - pred : Union[np.ndarray, torch.Tensor] - Predicted image, shape (n, h, w, c). - pred_std : Union[np.ndarray, torch.Tensor] - Predicted std, shape (n, h, w, c). - target : Union[np.ndarray, torch.Tensor] - Target image, shape (n, h, w, c). - q_s : float, optional - Start quantile, by default 0.00001. - q_e : float, optional - End quantile, by default 0.99999. - num_bins : int, optional - Number of bins to use for calibration, by default 30. - - Returns - ------- - dict[str, float] - Calibrated factor for each channel (slope + intercept). - """ - calib = Calibration(num_bins=num_bins) - stats_dict = calib.compute_stats(pred, pred_std, target) - outputs = {} - for ch_idx in stats_dict.keys(): - y = stats_dict[ch_idx]["rmse"] - x = stats_dict[ch_idx]["rmv"] - count = stats_dict[ch_idx]["bin_count"] - - first_idx = get_first_index(count, q_s) - last_idx = get_last_index(count, q_e) - x = x[first_idx:-last_idx] - y = y[first_idx:-last_idx] - slope, intercept, *_ = stats.linregress(x, y) - output = {"scalar": slope, "offset": intercept} - outputs[ch_idx] = output - return outputs + Parameters + ---------- + stats_dict : dict[int, dict[str, Union[np.ndarray, list]]] + Dictionary containing the stats for each channel. + q_s : float, optional + Start quantile, by default 0.00001. + q_e : float, optional + End quantile, by default 0.99999. + + Returns + ------- + dict[str, float] + Calibrated factor for each channel (slope + intercept). + """ + if not hasattr(self, "stats_dict"): + print("No stats found. Computing stats...") + if any(v is None for v in [pred, pred_std, target]): + raise ValueError("pred, pred_std, and target must be provided.") + self.stats_dict = self.compute_stats( + pred=pred, pred_std=pred_std, target=target + ) + outputs = {} + for ch_idx in self.stats_dict.keys(): + y = self.stats_dict[ch_idx]["rmse"] + x = self.stats_dict[ch_idx]["rmv"] + count = self.stats_dict[ch_idx]["bin_count"] + + first_idx = get_first_index(count, q_s) + last_idx = get_last_index(count, q_e) + x = x[first_idx:-last_idx] + y = y[first_idx:-last_idx] + slope, intercept, *_ = stats.linregress(x, y) + output = {"scalar": slope, "offset": intercept} + outputs[ch_idx] = output + factors = self.get_factors_array(factors_dict=outputs) + return outputs, factors + + def get_factors_array(self, factors_dict: list[dict]): + """Get the calibration factors as a numpy array.""" + calib_scalar = [factors_dict[i]["scalar"] for i in range(len(factors_dict))] + calib_scalar = np.array(calib_scalar).reshape(1, 1, 1, -1) + calib_offset = [ + factors_dict[i].get("offset", 0.0) for i in range(len(factors_dict)) + ] + calib_offset = np.array(calib_offset).reshape(1, 1, 1, -1) + return {"scalar": calib_scalar, "offset": calib_offset} def plot_calibration(ax, calibration_stats): - first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001) - last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999) + first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.0001) + last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.9999) ax.plot( calibration_stats[0]["rmv"][first_idx:-last_idx], calibration_stats[0]["rmse"][first_idx:-last_idx], @@ -170,15 +177,15 @@ def plot_calibration(ax, calibration_stats): label=r"$\hat{C}_0$: Ch1", ) - first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001) - last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999) + first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.0001) + last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.9999) ax.plot( calibration_stats[1]["rmv"][first_idx:-last_idx], calibration_stats[1]["rmse"][first_idx:-last_idx], "o", - label=r"$\hat{C}_1: : Ch2$", + label=r"$\hat{C}_1$: Ch2", ) - + # TODO add multichannel ax.set_xlabel("RMV") ax.set_ylabel("RMSE") ax.legend() diff --git a/src/careamics/lvae_training/dataset/lc_dataset.py b/src/careamics/lvae_training/dataset/lc_dataset.py index af710fe28..ac18e2c86 100644 --- a/src/careamics/lvae_training/dataset/lc_dataset.py +++ b/src/careamics/lvae_training/dataset/lc_dataset.py @@ -97,7 +97,8 @@ def reduce_data( ] self.N = len(t_list) - self.set_img_sz(self._img_sz, self._grid_sz) + # TODO where tf is self._img_sz defined? + self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz) print( f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}" ) diff --git a/src/careamics/lvae_training/dataset/multich_dataset.py b/src/careamics/lvae_training/dataset/multich_dataset.py index 87180ab3b..ae2a1c30c 100644 --- a/src/careamics/lvae_training/dataset/multich_dataset.py +++ b/src/careamics/lvae_training/dataset/multich_dataset.py @@ -359,8 +359,8 @@ def reduce_data( self._noise_data = self._noise_data[ t_list, h_start:h_end, w_start:w_end, : ].copy() - - self.set_img_sz(self._img_sz, self._grid_sz) + # TODO where tf is self._img_sz defined? + self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz) print( f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}" ) diff --git a/src/careamics/lvae_training/dataset/types.py b/src/careamics/lvae_training/dataset/types.py index 7ae1230b5..1eb92e786 100644 --- a/src/careamics/lvae_training/dataset/types.py +++ b/src/careamics/lvae_training/dataset/types.py @@ -3,7 +3,7 @@ class DataType(Enum): Elisa3DData = 0 - NicolaData = 1 + HTLIF24Data = 1 Pavia3SeqData = 2 TavernaSox2GolgiV2 = 3 Dao3ChannelWithInput = 4 diff --git a/src/careamics/lvae_training/eval_utils.py b/src/careamics/lvae_training/eval_utils.py index 0c16ae3a4..2705e60bb 100644 --- a/src/careamics/lvae_training/eval_utils.py +++ b/src/careamics/lvae_training/eval_utils.py @@ -7,23 +7,18 @@ """ import os -from typing import List, Literal, Union +from typing import Optional import matplotlib import matplotlib.pyplot as plt import numpy as np -from scipy import stats import torch -from torch import nn -from torch.utils.data import Dataset from matplotlib.gridspec import GridSpec -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset, Subset from tqdm import tqdm from careamics.lightning import VAEModule - -from careamics.models.lvae.utils import ModelType -from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR +from careamics.utils.metrics import scale_invariant_psnr class TilingMode: @@ -149,11 +144,10 @@ def plot_crops( tar, tar_hsnr, recon_img_list, - calibration_stats, + calibration_stats=None, num_samples=2, baseline_preds=None, ): - """ """ if baseline_preds is None: baseline_preds = [] if len(baseline_preds) > 0: @@ -164,15 +158,13 @@ def plot_crops( ) print("This happens when we want to predict the edges of the image.") return + color_ch_list = ["goldenrod", "cyan"] + color_pred = "red" + insetplot_xmax_value = 10000 + insetplot_xmin_value = -1000 + inset_min_labelsize = 10 + inset_rect = [0.05, 0.05, 0.4, 0.2] - # color_ch_list = ['goldenrod', 'cyan'] - # color_pred = 'red' - # insetplot_xmax_value = 10000 - # insetplot_xmin_value = -1000 - # inset_min_labelsize = 10 - # inset_rect = [0.05, 0.05, 0.4, 0.2] - - # Set plot attributes img_sz = 3 ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1) grid_factor = 5 @@ -191,7 +183,6 @@ def plot_crops( ) params = {"mathtext.default": "regular"} plt.rcParams.update(params) - # plot baselines for i in range(2, 2 + len(baseline_preds)): for col_idx in range(baseline_preds[0].shape[0]): @@ -471,52 +462,17 @@ def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val plt.colorbar(img_err, ax=ax) -# ------------------------------------------------------------------------------------------------ - - -def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256): - """ - Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index. - """ - print(f"Predicting for {idx}") - val_dset.set_img_sz(patch_size, 64) - - with torch.no_grad(): - # val_dset.enable_noise() - inp, tar = val_dset[idx] - # val_dset.disable_noise() - - inp = torch.Tensor(inp[None]) - tar = torch.Tensor(tar[None]) - inp = inp.cuda() - x_normalized = model.normalize_input(inp) - tar = tar.cuda() - tar_normalized = model.normalize_target(tar) - - recon_img_list = [] - for _ in range(mmse_count): - recon_normalized, td_data = model(x_normalized) - rec_loss, imgs = model.get_reconstruction_loss( - recon_normalized, - x_normalized, - tar_normalized, - return_predicted_img=True, - ) - imgs = model.unnormalize_target(imgs) - recon_img_list.append(imgs.cpu().numpy()[0]) - - recon_img_list = np.array(recon_img_list) - return inp, tar, recon_img_list +# ------------------------------------------------------------------------------------- -def get_dset_predictions( +def get_predictions( model: VAEModule, dset: Dataset, batch_size: int, - loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"], + tile_size: Optional[tuple[int, int]] = None, mmse_count: int = 1, num_workers: int = 4, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]: +) -> tuple[dict, dict, dict]: """Get patch-wise predictions from a model for the entire dataset. Parameters @@ -545,6 +501,55 @@ def get_dset_predictions( - losses: Reconstruction losses for the predictions. - psnr: PSNR values for the predictions. """ + if hasattr(dset, "dsets"): + multifile_stitched_predictions = {} + multifile_stitched_stds = {} + for d in dset.dsets: + stitched_predictions, stitched_stds = get_single_file_mmse( + model=model, + dset=d, + batch_size=batch_size, + tile_size=tile_size, + mmse_count=mmse_count, + num_workers=num_workers, + ) + # get filename without extension and path + filename = str(d._fpath).split("/")[-1].split(".")[0] + multifile_stitched_predictions[filename] = stitched_predictions + multifile_stitched_stds[filename] = stitched_stds + return ( + multifile_stitched_predictions, + multifile_stitched_stds, + ) + else: + stitched_predictions, stitched_stds = get_single_file_mmse( + model=model, + dset=dset, + batch_size=batch_size, + tile_size=tile_size, + mmse_count=mmse_count, + num_workers=num_workers, + ) + # get filename without extension and path + filename = str(dset._fpath).split("/")[-1].split(".")[0] + return ( + {filename: stitched_predictions}, + {filename: stitched_stds}, + ) + + +def get_single_file_predictions( + model: VAEModule, + dset: Dataset, + batch_size: int, + tile_size: Optional[tuple[int, int]] = None, + grid_size: Optional[int] = None, + num_workers: int = 4, +) -> tuple[np.ndarray, np.ndarray]: + """Get patch-wise predictions from a model for a single file dataset.""" + if tile_size and grid_size: + dset.set_img_sz(tile_size, grid_size) + dloader = DataLoader( dset, pin_memory=False, @@ -552,43 +557,64 @@ def get_dset_predictions( shuffle=False, batch_size=batch_size, ) + model.eval() + model.cuda() + tiles = [] + logvar_arr = [] + with torch.no_grad(): + for batch in tqdm(dloader, desc="Predicting tiles"): + inp, tar = batch + inp = inp.cuda() + tar = tar.cuda() + + # get model output + rec, _ = model(inp) - gauss_likelihood = model.gaussian_likelihood - nm_likelihood = model.noise_model_likelihood + # get reconstructed img + if model.model.predict_logvar is None: + rec_img = rec + logvar = torch.tensor([-1]) + else: + rec_img, logvar = torch.chunk(rec, chunks=2, dim=1) + logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ? + + tiles.append(rec_img.cpu().numpy()) + + tile_samples = np.concatenate(tiles, axis=0) + return stitch_predictions_new(tile_samples, dset) - predictions = [] - predictions_std = [] - losses = [] + +def get_single_file_mmse( + model: VAEModule, + dset: Dataset, + batch_size: int, + tile_size: Optional[tuple[int, int]] = None, + mmse_count: int = 1, + num_workers: int = 4, +) -> tuple[np.ndarray, np.ndarray]: + """Get patch-wise predictions from a model for a single file dataset.""" + dloader = DataLoader( + dset, + pin_memory=False, + num_workers=num_workers, + shuffle=False, + batch_size=batch_size, + ) + if tile_size: + dset.set_img_sz(tile_size, tile_size[-1] // 2) + model.eval() + model.cuda() + tile_mmse = [] + tile_stds = [] logvar_arr = [] - num_channels = dset[0][1].shape[0] - patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)] with torch.no_grad(): - for batch in tqdm(dloader, desc="Predicting patches"): + for batch in tqdm(dloader, desc="Predicting tiles"): inp, tar = batch inp = inp.cuda() tar = tar.cuda() rec_img_list = [] - for mmse_idx in range(mmse_count): - - # TODO: case of HDN left for future refactoring - # if model_type == ModelType.Denoiser: - # assert model.denoise_channel in [ - # "Ch1", - # "Ch2", - # "input", - # ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"' - - # x_normalized_new, tar_new = model.get_new_input_target( - # (inp, tar, *batch[2:]) - # ) - # rec, _ = model(x_normalized_new) - # rec_loss, imgs = model.get_reconstruction_loss( - # rec, - # tar, - # x_normalized_new, - # return_predicted_img=True, - # ) + for _ in range(mmse_count): # get model output rec, _ = model(inp) @@ -600,52 +626,21 @@ def get_dset_predictions( else: rec_img, logvar = torch.chunk(rec, chunks=2, dim=1) rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim - logvar_arr.append(logvar.cpu().numpy()) - - # compute reconstruction loss - # if loss_type == "musplit": - # rec_loss = get_reconstruction_loss( - # reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood - # ) - # elif loss_type == "denoisplit": - # rec_loss = get_reconstruction_loss( - # reconstruction=rec, target=tar, likelihood_obj=nm_likelihood - # ) - # elif loss_type == "denoisplit_musplit": - # rec_loss = reconstruction_loss_musplit_denoisplit( - # predictions=rec, - # targets=tar, - # gaussian_likelihood=gauss_likelihood, - # nm_likelihood=nm_likelihood, - # nm_weight=model.loss_parameters.denoisplit_weight, - # gaussian_weight=model.loss_parameters.musplit_weight, - # ) - # rec_loss = {"loss": rec_loss} # hacky, but ok for now - - # # store rec loss values for first pred - # if mmse_idx == 0: - # try: - # losses.append(rec_loss["loss"].cpu().numpy()) - # except: - # losses.append(rec_loss["loss"]) - - # update running PSNR - # for i in range(num_channels): - # patch_psnr_channels[i].update(rec_img[:, i], tar[:, i]) + logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ? # aggregate results samples = torch.cat(rec_img_list, dim=0) mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim - # mmse_std = torch.std(samples, dim=0) - predictions.append(mmse_imgs.cpu().numpy()) - # predictions_std.append(mmse_std.cpu().numpy()) - - # psnr = [x.get() for x in patch_psnr_channels] - return np.concatenate(predictions, axis=0) - # np.concatenate(predictions_std, axis=0), - # np.concatenate(logvar_arr), - # np.array(losses), - # psnr, # TODO revisit ! + std_imgs = torch.std(samples, dim=0) # std over MMSE dim + + tile_mmse.append(mmse_imgs.cpu().numpy()) + tile_stds.append(std_imgs.cpu().numpy()) + + tiles_arr = np.concatenate(tile_mmse, axis=0) + tile_stds = np.concatenate(tile_stds, axis=0) + stitched_predictions = stitch_predictions_new(tiles_arr, dset) + stitched_stds = stitch_predictions_new(tile_stds, dset) + return stitched_predictions, stitched_stds # ------------------------------------------------------------------------------------------ diff --git a/src/careamics/models/lvae/lvae.py b/src/careamics/models/lvae/lvae.py index 4b4a789cf..7f496ae38 100644 --- a/src/careamics/models/lvae/lvae.py +++ b/src/careamics/models/lvae/lvae.py @@ -6,7 +6,7 @@ """ from collections.abc import Iterable -from typing import Union +from typing import Optional, Union import numpy as np import torch @@ -834,3 +834,15 @@ def get_top_prior_param_shape(self, n_imgs: int = 1): # TODO check if model_3D_depth is needed ? top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w) return top_layer_shape + + def reset_for_inference(self, tile_size: Optional[tuple[int, int]] = None): + """Should be called if we want to predict for a different input/output size.""" + self.mode_pred = True + if tile_size is None: + tile_size = self.image_size + self.image_size = tile_size + for i in range(self.n_layers): + self.bottom_up_layers[i].output_expected_shape = ( + ts // 2 ** (i + 1) for ts in tile_size + ) + self.top_down_layers[i].latent_shape = tile_size diff --git a/src/careamics/models/lvae/stochastic.py b/src/careamics/models/lvae/stochastic.py index e7f899c9c..948fbfe2c 100644 --- a/src/careamics/models/lvae/stochastic.py +++ b/src/careamics/models/lvae/stochastic.py @@ -193,6 +193,7 @@ def compute_kl_metrics( z: torch.Tensor The sampled latent tensor. """ + kl_samplewise_restricted = None if mode_pred is False: # if not predicting if analytical_kl: kl_elementwise = kl_divergence(q, p) diff --git a/src/careamics/utils/metrics.py b/src/careamics/utils/metrics.py index 0900a9023..d057840b8 100644 --- a/src/careamics/utils/metrics.py +++ b/src/careamics/utils/metrics.py @@ -14,6 +14,31 @@ # TODO: does this add additional dependency? +# TODO revisit metric for notebook +def avg_range_invariant_psnr( + pred: np.ndarray, + target: np.ndarray, +) -> float: + """Compute the average range-invariant PSNR. + + Parameters + ---------- + pred : np.ndarray + Predicted images. + target : np.ndarray + Target images. + + Returns + ------- + float + Average range-invariant PSNR value. + """ + psnr_arr = [] + for i in range(pred.shape[0]): + psnr_arr.append(scale_invariant_psnr(pred[i], target[i])) + return np.mean(psnr_arr) + + def psnr(gt: np.ndarray, pred: np.ndarray, data_range: float) -> float: """ Peak Signal to Noise Ratio.