diff --git a/.gitignore b/.gitignore index eb3fbdf9a..116e26420 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ .coverage coverage.xml lightning_logs +.mypy_cache \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 962892429..4c27447a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,6 +17,7 @@ repos: rev: v0.4.7 hooks: - id: ruff + exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*" args: [--fix, --target-version, py38] - repo: https://github.com/psf/black @@ -29,6 +30,7 @@ repos: hooks: - id: mypy files: "^src/" + exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*" additional_dependencies: - numpy - types-PyYAML @@ -39,6 +41,7 @@ repos: rev: v1.7.0 hooks: - id: numpydoc-validation + exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*" # # jupyter linting and formatting # - repo: https://github.com/nbQA-dev/nbQA diff --git a/examples/evaluate_LVAE.ipynb b/examples/evaluate_LVAE.ipynb new file mode 100644 index 000000000..a7c5d74e3 --- /dev/null +++ b/examples/evaluate_LVAE.ipynb @@ -0,0 +1,1356 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "display(HTML(\"\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "import sys\n", + "import os \n", + "\n", + "import numpy as np\n", + "import torch\n", + "import pickle\n", + "import ml_collections\n", + "import glob\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn as nn\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as patches\n", + "import matplotlib\n", + "from tqdm import tqdm\n", + "from copy import deepcopy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "DATA_ROOT = '/group/jug/federico/careamics_training/data'\n", + "OUT_ROOT = '/group/jug/federico/careamics_training/training'\n", + "CODE_ROOT = '/home/federico.carrara/'\n", + "DEBUG = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "sys.path.append(os.path.join(CODE_ROOT, 'Documents/projects/careamics/src'))\n", + "\n", + "from careamics.lvae_training.train_lvae import create_dataset\n", + "from careamics.models.lvae.utils import (\n", + " ModelType, LossType\n", + ")\n", + "from careamics.models.lvae import get_config\n", + "from careamics.lvae_training.data_utils import DataType, DataSplitType, GridAlignement, load_tiff\n", + "from careamics.lvae_training.metrics import (\n", + " PSNR, \n", + " RangeInvariantPsnr,\n", + " avg_psnr,\n", + " avg_range_inv_psnr,\n", + " avg_ssim,\n", + " compute_masked_psnr,\n", + " compute_multiscale_ssim\n", + ")\n", + "from careamics.lvae_training.train_utils import get_mean_std_dict_for_model\n", + "from careamics.lvae_training.lightning_module import LadderVAELight\n", + "from careamics.lvae_training.eval_utils import (\n", + " show_for_one, \n", + " get_plots_output_dir,\n", + " get_dset_predictions,\n", + " stitch_predictions,\n", + " Calibration,\n", + " get_calibrated_factor_for_stdev,\n", + " plot_calibration,\n", + " clean_ax,\n", + " plot_error\n", + ")\n", + "# from disentangle.analysis.lvae_utils import get_img_from_forward_output\n", + "# from disentangle.analysis.plot_utils import get_k_largest_indices,plot_imgs_from_idx\n", + "# from disentangle.analysis.critic_notebook_utils import get_mmse_dict, get_label_separated_loss\n", + "# from disentangle.sampler.random_sampler import RandomSampler\n", + "\n", + "torch.multiprocessing.set_sharing_strategy('file_system')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "def fix_seeds():\n", + " torch.manual_seed(0)\n", + " torch.cuda.manual_seed(0)\n", + " np.random.seed(0)\n", + " random.seed(0)\n", + " torch.backends.cudnn.deterministic = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_dir = os.path.join(OUT_ROOT, '2406/LVAE_denoiSplit/53')\n", + "assert os.path.exists(ckpt_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# def get_dtype(ckpt_fpath):\n", + "# if os.path.isdir(ckpt_fpath):\n", + "# ckpt_fpath = ckpt_fpath[:-1] if ckpt_fpath[-1] == '/' else ckpt_fpath\n", + "# elif os.path.isfile(ckpt_fpath):\n", + "# ckpt_fpath = os.path.dirname(ckpt_fpath)\n", + "# assert ckpt_fpath[-1] != '/'\n", + "# return int(ckpt_fpath.split('/')[-2].split('-')[0][1:])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# dtype = get_dtype(ckpt_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "### Set Evaluation Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Set parameters\n", + "mmse_count = 10\n", + "image_size_for_grid_centers = 32 # what we retain from inner padding/tiling\n", + "eval_patch_size = None # actual patch size --> if not specified data.image_size\n", + "data_t_list = None # list of indexes of the data to be used\n", + "model_type = ModelType.LadderVae\n", + "eval_datasplit_type = DataSplitType.Val \n", + "psnr_type = 'range_invariant' #'simple', 'range_invariant'\n", + "enable_calibration = True\n", + "which_ckpt = 'last' # 'best', 'last'\n", + "\n", + "save_comparative_plots = False\n", + "batch_size = 32\n", + "num_workers = 4\n", + "COMPUTE_LOSS = False\n", + "use_deterministic_grid = None # for training -> get one 64x64 patch at random (not from the grid)\n", + "\n", + "# threshold = None # 0.02\n", + "# compute_kl_loss = False\n", + "# evaluate_train = False # inspect training performance\n", + "# val_repeat_factor = None" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "### Load config " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "def get_model_checkpoint(ckpt_dir, mode='best'):\n", + " output = []\n", + " if mode == 'best':\n", + " for filename in glob.glob(ckpt_dir + \"/*_best.ckpt\"):\n", + " output.append(filename)\n", + " elif mode == 'last':\n", + " for filename in glob.glob(ckpt_dir + \"/*_last.ckpt\"):\n", + " output.append(filename)\n", + " else:\n", + " raise ValueError(f\"Mode can be either 'best' or 'last', while you selected {mode}.\")\n", + " assert len(output) == 1, '\\n'.join(output)\n", + " return output[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "def load_config(config_fpath):\n", + " if os.path.isdir(config_fpath):\n", + " config_fpath = os.path.join(config_fpath, 'config.pkl')\n", + " else:\n", + " assert config_fpath[-4:] == '.pkl', f'{config_fpath} is not a pickle file. Aborting'\n", + " with open(config_fpath, 'rb') as f:\n", + " config = pickle.load(f)\n", + " return config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "if os.path.isdir(ckpt_dir):\n", + " config = load_config(ckpt_dir)\n", + "else:\n", + " config = load_config(os.path.dirname(ckpt_dir))\n", + "\n", + "config = ml_collections.ConfigDict(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "print(config)" + ] + }, + { + "cell_type": "markdown", + "id": "16", + "metadata": {}, + "source": [ + "Changing config parameters should not be needed anymore, since only few parameters of the model are customizable now" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "training_image_size = None\n", + "training_grid_size = None\n", + "with config.unlocked():\n", + "# if 'test_fraction' not in config.training:\n", + "# config.training.test_fraction =0.0\n", + "\n", + " if 'datadir' not in config:\n", + " config.datadir = ''\n", + "\n", + "# if 'encoder' not in config.model:\n", + "# config.model.encoder = ml_collections.ConfigDict()\n", + "# assert 'decoder' not in config.model\n", + "# config.model.decoder = ml_collections.ConfigDict()\n", + " \n", + "# config.model.encoder.dropout = config.model.dropout\n", + "# config.model.decoder.dropout = config.model.dropout\n", + "# config.model.encoder.n_filters = config.model.n_filters\n", + "# config.model.decoder.n_filters = config.model.n_filters\n", + " \n", + "# if 'multiscale_retain_spatial_dims' not in config.model.decoder:\n", + "# config.model.decoder.multiscale_retain_spatial_dims = False\n", + " \n", + "# if 'res_block_kernel' not in config.model.encoder:\n", + "# config.model.encoder.res_block_kernel = 3\n", + "# assert 'res_block_kernel' not in config.model.decoder\n", + "# config.model.decoder.res_block_kernel = 3\n", + " \n", + "# if 'res_block_skip_padding' not in config.model.encoder:\n", + "# config.model.encoder.res_block_skip_padding = False\n", + "# assert 'res_block_skip_padding' not in config.model.decoder\n", + "# config.model.decoder.res_block_skip_padding = False\n", + " \n", + "# if 'skip_bottom_layers_count' in config.model:\n", + "# config.model.skip_bottom_layers_count = 0\n", + " \n", + "# if 'logvar_lowerbound' not in config.model:\n", + "# config.model.logvar_lowerbound = None\n", + " \n", + "# if 'train_aug_rotate' not in config.data:\n", + "# config.data.train_aug_rotate = False\n", + " \n", + "# if 'multiscale_lowres_separate_branch' not in config.model:\n", + "# config.model.multiscale_lowres_separate_branch = False\n", + " \n", + "# if 'multiscale_retain_spatial_dims' not in config.model:\n", + "# config.model.multiscale_retain_spatial_dims = False\n", + " \n", + "# config.data.train_aug_rotate=False\n", + " \n", + "# if 'randomized_channels' not in config.data:\n", + "# config.data.randomized_channels = False\n", + " \n", + " if 'predict_logvar' not in config.model:\n", + " config.model.predict_logvar = None\n", + " \n", + " # if 'batchnorm' in config.model and 'batchnorm' not in config.model.encoder:\n", + " # assert 'batchnorm' not in config.model.decoder\n", + " # config.model.decoder.batchnorm = config.model.batchnorm\n", + " # config.model.encoder.batchnorm = config.model.batchnorm\n", + " \n", + "# if 'conv2d_bias' not in config.model.decoder:\n", + "# config.model.decoder.conv2d_bias = True\n", + " \n", + " if eval_patch_size is not None:\n", + " training_image_size = config.data.image_size\n", + " config.data.image_size = eval_patch_size\n", + "\n", + " if image_size_for_grid_centers is not None:\n", + " training_grid_size = config.data.get('grid_size', \"grid_size not present\")\n", + " config.data.grid_size = image_size_for_grid_centers\n", + "\n", + "# if use_deterministic_grid is not None:\n", + "# config.data.deterministic_grid = use_deterministic_grid\n", + "\n", + "# if threshold is not None:\n", + "# config.data.threshold = threshold\n", + "\n", + "# if val_repeat_factor is not None:\n", + "# config.training.val_repeat_factor = val_repeat_factor\n", + "\n", + "# config.model.mode_pred = not compute_kl_loss\n", + " \n", + "# if 'skip_receptive_field_loss_tokens' not in config.loss:\n", + "# config.loss.skip_receptive_field_loss_tokens = []\n", + " \n", + "# if 'lowres_merge_type' not in config.model.encoder:\n", + "# config.model.encoder.lowres_merge_type = 0\n", + " \n", + "# if 'validtarget_random_fraction' in config.data:\n", + "# config.data.validtarget_random_fraction = None\n", + "\n", + "# if 'input_is_sum' not in config.data:\n", + "# config.data.input_is_sum = False\n", + "\n", + "# print(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "dtype = config.data.data_type\n", + "\n", + "if DEBUG:\n", + " if dtype == DataType.CustomSinosoid:\n", + " data_dir = f'{DATA_ROOT}/sinosoid/'\n", + " elif dtype == DataType.OptiMEM100_014:\n", + " data_dir = f'{DATA_ROOT}/microscopy/'\n", + "else:\n", + " if dtype in [DataType.CustomSinosoid, DataType.CustomSinosoidThreeCurve]:\n", + " data_dir = f'{DATA_ROOT}/sinosoid_without_test/sinosoid/'\n", + " elif dtype == DataType.OptiMEM100_014:\n", + " data_dir = f'{DATA_ROOT}/microscopy/'\n", + " elif dtype == DataType.Prevedel_EMBL:\n", + " data_dir = f'{DATA_ROOT}/Prevedel_EMBL/PKG_3P_dualcolor_stacks/NoAverage_NoRegistration/'\n", + " elif dtype == DataType.AllenCellMito:\n", + " data_dir = f'{DATA_ROOT}/allencell/2017_03_08_Struct_First_Pass_Seg/AICS-11/'\n", + " elif dtype == DataType.SeparateTiffData:\n", + " data_dir = f'{DATA_ROOT}/ventura_gigascience'\n", + " elif dtype == DataType.SemiSupBloodVesselsEMBL:\n", + " data_dir = f'{DATA_ROOT}/EMBL_halfsupervised/Demixing_3P'\n", + " elif dtype == DataType.Pavia2VanillaSplitting:\n", + " data_dir = f'{DATA_ROOT}/pavia2'\n", + " elif dtype == DataType.ExpansionMicroscopyMitoTub:\n", + " data_dir = f'{DATA_ROOT}/expansion_microscopy_Nick/'\n", + " elif dtype == DataType.ShroffMitoEr:\n", + " data_dir = f'{DATA_ROOT}/shrofflab/'\n", + " elif dtype == DataType.HTIba1Ki67:\n", + " data_dir = f'{DATA_ROOT}/Stefania/20230327_Ki67_and_Iba1_trainingdata/'\n", + " elif dtype == DataType.BioSR_MRC:\n", + " data_dir = f'{DATA_ROOT}/BioSR/'\n", + " elif dtype == DataType.ExpMicroscopyV2:\n", + " data_dir = f'{DATA_ROOT}/expansion_microscopy_v2/'\n", + " elif dtype == DataType.TavernaSox2GolgiV2:\n", + " data_dir = f'{DATA_ROOT}/TavernaSox2Golgi/acquisition2/'\n", + " elif dtype == DataType.Pavia3SeqData:\n", + " data_dir = f'{DATA_ROOT}/pavia3_sequential/'\n", + " elif dtype == DataType.NicolaData:\n", + " data_dir = f'{DATA_ROOT}/nikola_data/raw'\n", + " \n", + "print(data_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "### Load data and model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "padding_kwargs = {'mode': 'constant',}\n", + "padding_kwargs['constant_values'] = config.data.get('padding_value', 0)\n", + "\n", + "dloader_kwargs = {\n", + " 'overlapping_padding_kwargs': padding_kwargs, \n", + " 'grid_alignment': GridAlignement.Center\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "train_dset, val_dset = create_dataset(\n", + " config, \n", + " data_dir, \n", + " eval_datasplit_type=eval_datasplit_type,\n", + " kwargs_dict=dloader_kwargs\n", + ")\n", + "data_mean, data_std = train_dset.get_mean_std()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "# create dataset without poisson noise as ground truth\n", + "new_config = deepcopy(ml_collections.ConfigDict(config))\n", + "if 'poisson_noise_factor' in new_config.data:\n", + " new_config.data.poisson_noise_factor = -1\n", + "if 'enable_gaussian_noise' in new_config.data:\n", + " new_config.data.enable_gaussian_noise = False \n", + " \n", + "_, highsnr_val_dset = create_dataset(\n", + " new_config, \n", + " data_dir, \n", + " eval_datasplit_type=eval_datasplit_type,\n", + " kwargs_dict=dloader_kwargs\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "with config.unlocked():\n", + " if training_image_size is not None:\n", + " config.data.image_size = training_image_size\n", + " \n", + "mean_dict, std_dict = get_mean_std_dict_for_model(config, train_dset)\n", + " \n", + "model = LadderVAELight(\n", + " config, \n", + " mean_dict, \n", + " std_dict,\n", + " target_ch=config.data.num_channels\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "if os.path.isdir(ckpt_dir):\n", + " ckpt_fpath = get_model_checkpoint(ckpt_dir, mode=which_ckpt)\n", + "else:\n", + " assert os.path.isfile(ckpt_dir)\n", + " ckpt_fpath = ckpt_dir\n", + "\n", + "print('Loading checkpoint from', ckpt_fpath)\n", + "checkpoint = torch.load(ckpt_fpath)\n", + "\n", + "_ = model.load_state_dict(checkpoint['state_dict'], strict=False)\n", + "model.eval()\n", + "_= model.cuda()\n", + "\n", + "model.set_params_to_same_device_as(torch.Tensor(1).cuda())\n", + "\n", + "print('Loading from epoch', checkpoint['epoch'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "print(f'Model has {count_parameters(model)/1000_000:.3f}M parameters')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "if config.data.multiscale_lowres_count is not None and eval_patch_size is not None:\n", + " model.reset_for_different_output_size(eval_patch_size)" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": {}, + "source": [ + "### From here on we perform evaluation" + ] + }, + { + "cell_type": "markdown", + "id": "28", + "metadata": {}, + "source": [ + "Visualize Data: noisy & ground truth" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [], + "source": [ + "# Print input (first row) and target (second row) of the val_dset\n", + "idx = np.random.randint(len(val_dset))\n", + "inp_tmp, tar_tmp, *_ = val_dset[idx]\n", + "gt_inp_tmp, gt_tar_tmp, *_ = highsnr_val_dset[idx]\n", + "\n", + "# Noisy\n", + "ncols = len(tar_tmp)\n", + "nrows = 2\n", + "_, ax = plt.subplots(figsize=(4*ncols,4*nrows), ncols=ncols, nrows=nrows)\n", + "plt.suptitle(\"Noisy patches\")\n", + "for i in range(min(ncols, len(inp_tmp))):\n", + " ax[0,i].imshow(inp_tmp[i])\n", + "\n", + "for channel_id in range(ncols):\n", + " ax[1,channel_id].imshow(tar_tmp[channel_id])\n", + " \n", + "# Ground truth\n", + "ncols = len(gt_tar_tmp)\n", + "_, ax = plt.subplots(figsize=(4*ncols,4*nrows), ncols=ncols, nrows=nrows)\n", + "plt.suptitle(\"Ground Truth patches\")\n", + "for i in range(min(ncols, len(gt_inp_tmp))):\n", + " ax[0,i].imshow(gt_inp_tmp[i])\n", + "\n", + "for channel_id in range(ncols):\n", + " ax[1,channel_id].imshow(gt_tar_tmp[channel_id])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "if data_t_list is not None:\n", + " val_dset.reduce_data(t_list=data_t_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "def get_full_input_frame(idx, dset):\n", + " img_tuples, noise_tuples = dset._load_img(idx)\n", + " if len(noise_tuples) > 0:\n", + " factor = np.sqrt(2) if dset._input_is_sum else 1.0\n", + " img_tuples = [x + noise_tuples[0] * factor for x in img_tuples]\n", + "\n", + " inp = 0\n", + " for nch in img_tuples:\n", + " inp += nch/len(img_tuples)\n", + " h_start, w_start = dset._get_deterministic_hw(idx)\n", + " return inp, h_start, w_start" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [ + "index = np.random.randint(len(val_dset))\n", + "inp, tar = val_dset[index]\n", + "frame, h_start, w_start = get_full_input_frame(index, val_dset)\n", + "print(h_start, w_start)" + ] + }, + { + "cell_type": "markdown", + "id": "33", + "metadata": {}, + "source": [ + "#### Plot predictions against a baseline for specific indexes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "def get_hwt_start(idx):\n", + " h,w,t = val_dset.idx_manager.hwt_from_idx(idx, grid_size=64)\n", + " print(h,w,t)\n", + " pad = val_dset.per_side_overlap_pixelcount()\n", + " h = h - pad\n", + " w = w - pad\n", + " return h,w,t\n", + "\n", + "def get_crop_from_fulldset_prediction(full_dset_pred, idx, patch_size=256):\n", + " h,w,t = get_hwt_start(idx)\n", + " return np.swapaxes(full_dset_pred[t,h:h+patch_size,w:w+patch_size].astype(np.float32)[None], 0, 3)[...,0]\n", + "\n", + "if save_comparative_plots: # this is false...\n", + " assert eval_datasplit_type == DataSplitType.Test\n", + " # CCP vs Microtubules: 925, 659, 502\n", + " # hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G16_M3_Sk0/pred_disentangle_2402_D23-M3-S0-L0_67.tif')\n", + " hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G32_M5_Sk0/pred_disentangle_2403_D23-M3-S0-L0_29.tif')\n", + "\n", + " # ER vs Microtubule 853, 859, 332\n", + " # hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G16_M3_Sk0/pred_disentangle_2402_D23-M3-S0-L0_60.tif')\n", + "\n", + " # ER vs CCP 327, 479, 637, 568\n", + " # hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G16_M3_Sk0/pred_disentangle_2402_D23-M3-S0-L0_59.tif')\n", + "\n", + " # F-actin vs ER 797\n", + " # hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G32_M10_Sk0/pred_disentangle_2403_D23-M3-S0-L0_15.tif')\n", + "\n", + " idx = 10 #np.random.randint(len(val_dset))\n", + " patch_size = 500\n", + " mmse_count = 50\n", + " print(idx)\n", + " show_for_one(\n", + " idx, val_dset, \n", + " highsnr_val_dset, \n", + " model, \n", + " None, \n", + " mmse_count=mmse_count, \n", + " patch_size=patch_size, \n", + " baseline_preds=[\n", + " get_crop_from_fulldset_prediction(hdn_usplitdata, idx).astype(np.float32),\n", + " ],\n", + " num_samples=0\n", + " )\n", + "\n", + " plotsdir = get_plots_output_dir(\n", + " ckpt_dir, \n", + " patch_size, \n", + " mmse_count=mmse_count\n", + " )\n", + " \n", + " model_id = ckpt_dir.strip('/').split('/')[-1]\n", + " fname = f'patch_comparison_{idx}_{model_id}.png'\n", + " fpath = os.path.join(plotsdir, fname)\n", + " plt.savefig(fpath, dpi=200, bbox_inches='tight')\n", + " print(f'Saved to {fpath}')" + ] + }, + { + "cell_type": "markdown", + "id": "35", + "metadata": {}, + "source": [ + "#### Compute predictions and related metrics (PSNR) for the entire validation set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [], + "source": [ + "# patch-wise PSNR here\n", + "\n", + "pred_tiled, rec_loss, logvar_tiled, patch_psnr_tuple, pred_std_tiled = get_dset_predictions(\n", + " model, \n", + " val_dset,\n", + " batch_size,\n", + " num_workers=num_workers,\n", + " mmse_count=mmse_count,\n", + " model_type = model_type,\n", + ")\n", + "tmp = np.round([x.item() for x in patch_psnr_tuple],2)\n", + "print('Patch wise PSNR, as computed during training', tmp, np.mean(tmp))\n", + "print(f'Number of predicted tiles: {pred_tiled.shape[0]}, channels: {pred_tiled.shape[1]}, shape: {pred_tiled.shape[2:]}')\n", + "print(f'Reconstruction loss distrib: {np.quantile(rec_loss, [0,0.01,0.5, 0.9,0.99,0.999,1]).round(2)}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37", + "metadata": {}, + "outputs": [], + "source": [ + "# Print tiles in which the logvar is very low\n", + "idx_list = np.where(logvar_tiled.squeeze() < -6)[0]\n", + "if len(idx_list) > 0:\n", + " plt.imshow(val_dset[idx_list[0]][1][1])" + ] + }, + { + "cell_type": "markdown", + "id": "38", + "metadata": {}, + "source": [ + "Get full image predictions by stitching the predicted tiles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [], + "source": [ + "if pred_tiled.shape[-1] != val_dset.get_img_sz():\n", + " pad = (val_dset.get_img_sz() - pred_tiled.shape[-1] )//2\n", + " pred_tiled = np.pad(pred_tiled, ((0,0),(0,0),(pad,pad),(pad,pad)))\n", + "\n", + "# Stitch tiled predictions\n", + "pred = stitch_predictions(\n", + " pred_tiled, \n", + " val_dset, \n", + " smoothening_pixelcount=0\n", + ")\n", + "\n", + "# Stitch predicted tiled logvar\n", + "if len(np.unique(logvar_tiled)) == 1:\n", + " logvar = None\n", + "else:\n", + " logvar = stitch_predictions(logvar_tiled, val_dset, smoothening_pixelcount=0)\n", + "\n", + "# Stitch the std of the predictions (i.e., std computed on the mmse_count predictions)\n", + "pred_std = stitch_predictions(pred_std_tiled, val_dset, smoothening_pixelcount=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": {}, + "outputs": [], + "source": [ + "if 'target_idx_list' in config.data and config.data.target_idx_list is not None:\n", + " pred = pred[...,:len(config.data.target_idx_list)]\n", + " pred_std = pred_std[...,:len(config.data.target_idx_list)]" + ] + }, + { + "cell_type": "markdown", + "id": "41", + "metadata": {}, + "source": [ + "Ignore (and remove) the pixels which are present in the last few rows and columns (since not multiples of patch_size)\n", + "1. They don't come in the batches. So, in prediction, they are simply zeros. So they are being are ignored right now. \n", + "2. For the border pixels which are on the top and the left, overlapping yields worse performance. This is becuase, there is nothing to overlap on one side. So, they are essentially zero padded. This makes the performance worse. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42", + "metadata": {}, + "outputs": [], + "source": [ + "def get_ignored_pixels():\n", + " ignored_pixels = 1\n", + " while(pred[0, -ignored_pixels:, -ignored_pixels:,].std() == 0):\n", + " ignored_pixels+=1\n", + " ignored_pixels-=1\n", + " print(f'In {pred.shape}, last {ignored_pixels} many rows and columns are all zero.')\n", + " return ignored_pixels\n", + "\n", + "actual_ignored_pixels = get_ignored_pixels()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": {}, + "outputs": [], + "source": [ + "if config.data.data_type in [\n", + " DataType.OptiMEM100_014,\n", + " DataType.SemiSupBloodVesselsEMBL, \n", + " DataType.Pavia2VanillaSplitting,\n", + " DataType.ExpansionMicroscopyMitoTub,\n", + " DataType.ShroffMitoEr,\n", + " DataType.HTIba1Ki67\n", + "]:\n", + " ignored_last_pixels = 32 \n", + "elif config.data.data_type == DataType.BioSR_MRC:\n", + " ignored_last_pixels = 44\n", + " if val_dset.get_img_sz() == 128:\n", + " ignored_last_pixels = 108\n", + "elif config.data.data_type == DataType.NicolaData:\n", + " ignored_last_pixels = 8\n", + "else:\n", + " ignored_last_pixels = 0\n", + "\n", + "ignore_first_pixels = 0\n", + "# ignored_last_pixels = 160\n", + "assert actual_ignored_pixels <= ignored_last_pixels, f'Set ignored_last_pixels={actual_ignored_pixels}'\n", + "print(ignored_last_pixels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [], + "source": [ + "tar = val_dset._data\n", + "if 'target_idx_list' in config.data and config.data.target_idx_list is not None:\n", + " tar = tar[...,config.data.target_idx_list]\n", + "\n", + "def ignore_pixels(arr, patch_size):\n", + " if arr.shape[2] % patch_size:\n", + " if ignore_first_pixels:\n", + " arr = arr[:,ignore_first_pixels:,ignore_first_pixels:]\n", + " if ignored_last_pixels:\n", + " arr = arr[:,:-ignored_last_pixels,:-ignored_last_pixels]\n", + "\n", + " return arr\n", + "\n", + "pred = ignore_pixels(pred, val_dset.get_img_sz())\n", + "tar = ignore_pixels(tar, val_dset.get_img_sz())\n", + "if pred_std is not None:\n", + " pred_std = ignore_pixels(pred_std, val_dset.get_img_sz())\n", + " \n", + "print(pred.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "45", + "metadata": {}, + "source": [ + "#### Perform Calibration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46", + "metadata": {}, + "outputs": [], + "source": [ + "sep_mean, sep_std = model.data_mean, model.data_std\n", + "if isinstance(sep_mean, dict):\n", + " sep_mean = sep_mean['target']\n", + " sep_std = sep_std['target']\n", + "\n", + "if isinstance(sep_mean, int):\n", + " pass\n", + "else:\n", + " sep_mean = sep_mean.squeeze()[None,None,None]\n", + " sep_std = sep_std.squeeze()[None,None,None]\n", + " sep_mean = sep_mean.cpu().numpy() \n", + " sep_std = sep_std.cpu().numpy()\n", + "\n", + "tar_normalized = (tar - sep_mean)/ sep_std\n", + "\n", + "# Check if normalization is correct (i.e., not already applied on tar)\n", + "print(f\"Channelwise means: tar -> {tar.mean(axis=(0,1,2))}, normalized -> {tar_normalized.mean(axis=(0,1,2))}\")" + ] + }, + { + "cell_type": "markdown", + "id": "47", + "metadata": {}, + "source": [ + "Plot RMV vs. RMSE without Calibration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48", + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Recall the `pred_std` here is the pixel-wise std of the mmse_count many predictions\n", + "if enable_calibration:\n", + " calib = Calibration(\n", + " num_bins=30, \n", + " mode='pixelwise'\n", + " )\n", + " native_stats = calib.compute_stats(\n", + " pred=pred, \n", + " pred_logvar=pred_std, \n", + " target=tar_normalized\n", + " )\n", + " count = np.array(native_stats[0]['bin_count'])\n", + " count = count / count.sum()\n", + " # print(count.cumsum()[:-1])\n", + " plt.plot(native_stats[0]['rmv'][1:-1], native_stats[0]['rmse'][1:-1], 'o')\n", + " plt.title(\"RMV vs. RMSE plot - Not Calibrated\")\n", + " plt.xlabel('RMV'), plt.ylabel('RMSE')" + ] + }, + { + "cell_type": "markdown", + "id": "49", + "metadata": {}, + "source": [ + "Observe that the plot is far from resembling y = x!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50", + "metadata": {}, + "outputs": [], + "source": [ + "def get_calibration_fnames(ckpt_dir):\n", + " tokens = ckpt_dir.strip('/').split('/')\n", + " modelid = int(tokens[-1])\n", + " model_specs = tokens[-2].replace('-','')\n", + " monthyear = tokens[-3]\n", + " fname_factor = f'calibration_factor_{monthyear}_{model_specs}_{modelid}.npy'\n", + " fname_stats = f'calibration_stats_{monthyear}_{model_specs}_{modelid}.pkl.npy'\n", + " return {'stats': fname_stats, 'factor': fname_factor}\n", + "\n", + "def get_calibration_factor_fname(ckpt_dir):\n", + " return get_calibration_fnames(ckpt_dir)['factor']\n", + "\n", + "def get_calibration_stats_fname(ckpt_dir):\n", + " return get_calibration_fnames(ckpt_dir)['stats']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51", + "metadata": {}, + "outputs": [], + "source": [ + "if enable_calibration:\n", + " inp, _ = val_dset[0]\n", + " plotsdir = get_plots_output_dir(OUT_ROOT, inp.shape[1], mmse_count=mmse_count)\n", + " fname = get_calibration_factor_fname(ckpt_dir)\n", + " factor_fpath = os.path.join(plotsdir, fname)\n", + " \n", + " # Compute calibration factors\n", + " if eval_datasplit_type == DataSplitType.Val:\n", + " # Compute calibration factors for the channels\n", + " calib_factor0 = get_calibrated_factor_for_stdev(pred[...,0], np.log(pred_std[...,0]**2), tar_normalized[...,0], batch_size=8, lr=0.1)\n", + " calib_factor1 = get_calibrated_factor_for_stdev(pred[...,1], np.log(pred_std[...,1]**2), tar_normalized[...,1], batch_size=8, lr=0.1)\n", + " print(calib_factor0, calib_factor1)\n", + " calib_factor = np.array([calib_factor0, calib_factor1]).reshape(1,1,1,2)\n", + " np.save(factor_fpath, calib_factor)\n", + " print(f'Saved calibration factor fitted on validation set to {factor_fpath}')\n", + "\n", + " # Use pre-computed calibration factor\n", + " elif eval_datasplit_type == DataSplitType.Test:\n", + " print('Loading the calibration factor from the file', factor_fpath)\n", + " calib_factor = np.load(factor_fpath)\n", + "\n", + " # Given the calibration factor, plot RMV vs. RMSE\n", + " calib = Calibration(num_bins=30, mode='pixelwise')\n", + " pred_logvar = 2* np.log(pred_std * calib_factor)\n", + " stats = calib.compute_stats(\n", + " pred,\n", + " pred_logvar, \n", + " tar_normalized\n", + " )\n", + " _,ax = plt.subplots(figsize=(5,5))\n", + " plt.title(\"RMV vs. RMSE plot - Calibrated\")\n", + " plot_calibration(ax, stats)\n", + "\n", + "if eval_datasplit_type == DataSplitType.Test:\n", + " stats_fpath = os.path.join(plotsdir, get_calibration_stats_fname(ckpt_dir))\n", + " np.save(stats_fpath, stats)\n", + " print('Saved stats of Test set to ', stats_fpath)" + ] + }, + { + "cell_type": "markdown", + "id": "52", + "metadata": {}, + "source": [ + "A fancier Calibration Plot with multiple calibration factors:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53", + "metadata": {}, + "outputs": [], + "source": [ + "def get_last_index(bin_count, quantile):\n", + " cumsum = np.cumsum(bin_count)\n", + " normalized_cumsum = cumsum / cumsum[-1]\n", + " for i in range(1, len(normalized_cumsum)):\n", + " if normalized_cumsum[-i] < quantile:\n", + " return i - 1\n", + " return None\n", + "\n", + "\n", + "def get_first_index(bin_count, quantile):\n", + " cumsum = np.cumsum(bin_count)\n", + " normalized_cumsum = cumsum / cumsum[-1]\n", + " for i in range(len(normalized_cumsum)):\n", + " if normalized_cumsum[i] > quantile:\n", + " return i\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " calib_factors = [\n", + " np.load(os.path.join('/path/to/calibration/factors/dir/', fpath), allow_pickle=True) \n", + " for fpath in [\n", + " 'calibration_stats_1.pkl.npy',\n", + " 'calibration_stats_2.pkl.npy',\n", + " 'calibration_stats_3.pkl.npy', \n", + " ]\n", + " ]\n", + " labels = ['w=0.5', 'w=0.9', 'w=1']\n", + "except FileNotFoundError:\n", + " print('Calibration factors not found. Skipping the plot.')\n", + " calib_factors = []\n", + "\n", + "if len(calib_factors) > 0:\n", + " _,ax = plt.subplots(figsize=(5,2.5))\n", + " for i, calibration_stats in enumerate(calib_factors):\n", + " first_idx = get_first_index(calibration_stats[()][0]['bin_count'], 0.0001)\n", + " last_idx = get_last_index(calibration_stats[()][0]['bin_count'], 0.9999)\n", + " ax.plot(\n", + " calibration_stats[()][0]['rmv'][first_idx:-last_idx],\n", + " calibration_stats[()][0]['rmse'][first_idx:-last_idx],\n", + " '-+',\n", + " label=labels[i]\n", + " )\n", + "\n", + " ax.yaxis.grid(color='gray', linestyle='dashed')\n", + " ax.xaxis.grid(color='gray', linestyle='dashed')\n", + " ax.plot(np.arange(0,1.5, 0.01), np.arange(0,1.5, 0.01), 'k--')\n", + " ax.set_facecolor('xkcd:light grey')\n", + " plt.legend(loc='lower right')\n", + " # plt.xlim(0,3)\n", + " # plt.ylim(0,1.25)\n", + " plt.xlabel('RMV')\n", + " plt.ylabel('RMSE')\n", + " ax.set_axisbelow(True)\n", + "\n", + "\n", + " plotsdir = get_plots_output_dir(ckpt_dir, 0, mmse_count=0)\n", + " model_id = ckpt_dir.strip('/').split('/')[-1]\n", + " fname = f'calibration_plot_{model_id}.png'\n", + " fpath = os.path.join(plotsdir, fname)\n", + " # plt.savefig(fpath, dpi=200, bbox_inches='tight')\n", + " print(f'Saved to {fpath}')\n" + ] + }, + { + "cell_type": "markdown", + "id": "55", + "metadata": {}, + "source": [ + "#### Visually compare Targets and Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56", + "metadata": {}, + "outputs": [], + "source": [ + "# One random target vs predicted image (patch of shape [sz x sz])\n", + "ncols = tar.shape[-1]\n", + "_,ax = plt.subplots(figsize=(ncols*5, 2*5), nrows=2, ncols=ncols)\n", + "img_idx = 0\n", + "sz = 800\n", + "hs = np.random.randint(tar.shape[1] - sz)\n", + "ws = np.random.randint(tar.shape[2] - sz)\n", + "for i in range(ncols):\n", + " ax[i,0].set_title(f'Target Channel {i+1}')\n", + " ax[i,0].imshow(tar[0, hs:hs+sz, ws:ws+sz, i])\n", + " ax[i,1].set_title(f'Predicted Channel {i+1}')\n", + " ax[i,1].imshow(pred[0, hs:hs+sz, ws:ws+sz, i])\n", + "\n", + "# plt.subplots_adjust(wspace=0.1, hspace=0.1)\n", + "# clean_ax(ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57", + "metadata": {}, + "outputs": [], + "source": [ + "nrows = pred.shape[-1]\n", + "img_sz = 3\n", + "_,ax = plt.subplots(figsize=(4*img_sz,nrows*img_sz), ncols=4, nrows=nrows)\n", + "idx = np.random.randint(len(pred))\n", + "print(idx)\n", + "for ch_id in range(nrows):\n", + " ax[ch_id,0].set_title(f'Target Channel {ch_id+1}')\n", + " ax[ch_id,0].imshow(tar_normalized[idx,..., ch_id], cmap='magma')\n", + " ax[ch_id,1].set_title(f'Predicted Channel {ch_id+1}')\n", + " ax[ch_id,1].imshow(pred[idx,:,:,ch_id], cmap='magma')\n", + " plot_error(\n", + " tar_normalized[idx,...,ch_id], \n", + " pred[idx,:,:,ch_id], \n", + " cmap = matplotlib.cm.coolwarm, \n", + " ax = ax[ch_id,2], \n", + " max_val = None\n", + " )\n", + "\n", + " cropsz = 256\n", + " h_s = np.random.randint(0, tar_normalized.shape[1] - cropsz)\n", + " h_e = h_s + cropsz\n", + " w_s = np.random.randint(0, tar_normalized.shape[2] - cropsz)\n", + " w_e = w_s + cropsz\n", + "\n", + " plot_error(\n", + " tar_normalized[idx,h_s:h_e,w_s:w_e, ch_id], \n", + " pred[idx,h_s:h_e,w_s:w_e,ch_id], \n", + " cmap = matplotlib.cm.coolwarm, \n", + " ax = ax[ch_id,3], \n", + " max_val = None\n", + " )\n", + "\n", + " # Add rectangle to the region\n", + " rect = patches.Rectangle((w_s, h_s), w_e-w_s, h_e-h_s, linewidth=1, edgecolor='r', facecolor='none')\n", + " ax[ch_id,2].add_patch(rect)\n" + ] + }, + { + "cell_type": "markdown", + "id": "58", + "metadata": {}, + "source": [ + "#### Compute metrics between predicted data and high-SNR (ground truth) data" + ] + }, + { + "cell_type": "markdown", + "id": "59", + "metadata": {}, + "source": [ + "Prepare data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60", + "metadata": {}, + "outputs": [], + "source": [ + "# ch1_pred_unnorm = pred[...,0]*sep_std[...,0].cpu().numpy() + sep_mean[...,0].cpu().numpy()\n", + "# ch2_pred_unnorm = pred[...,1]*sep_std[...,1].cpu().numpy() + sep_mean[...,1].cpu().numpy()\n", + "pred_unnorm = []\n", + "for i in range(pred.shape[-1]):\n", + " if sep_std.shape[-1]==1:\n", + " temp_pred_unnorm = pred[...,i]*sep_std[...,0] + sep_mean[...,0]\n", + " else:\n", + " temp_pred_unnorm = pred[...,i]*sep_std[...,i] + sep_mean[...,i]\n", + " pred_unnorm.append(temp_pred_unnorm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61", + "metadata": {}, + "outputs": [], + "source": [ + "# Get & process high-SNR data from previously loaded dataset\n", + "highres_data = highsnr_val_dset._data\n", + "if highres_data is not None:\n", + " highres_data = ignore_pixels(highres_data, highsnr_val_dset.get_img_sz()).copy()\n", + " if data_t_list is not None:\n", + " highres_data = highres_data[data_t_list].copy()\n", + " \n", + " if 'target_idx_list' in config.data and config.data.target_idx_list is not None:\n", + " highres_data = highres_data[...,config.data.target_idx_list]" + ] + }, + { + "cell_type": "markdown", + "id": "62", + "metadata": {}, + "source": [ + "Compute metrics:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63", + "metadata": {}, + "outputs": [], + "source": [ + "if highres_data is not None:\n", + " print(f'{DataSplitType.name(eval_datasplit_type)}_P{eval_patch_size}_G{image_size_for_grid_centers}_M{mmse_count}_Sk{ignored_last_pixels}')\n", + " psnr_list = [avg_range_inv_psnr(highres_data[...,k], pred_unnorm[k]) for k in range(len(pred_unnorm))]\n", + " tar_tmp = (highres_data - sep_mean) /sep_std\n", + " # tar0_tmp = (highres_data[...,0] - sep_mean[...,0]) /sep_std[...,0]\n", + " ssim_list = compute_multiscale_ssim(tar_tmp, pred)\n", + " # ssim1_hres_mean, ssim1_hres_std = avg_ssim(highres_data[...,0], pred_unnorm[0])\n", + " # ssim2_hres_mean, ssim2_hres_std = avg_ssim(highres_data[...,1], pred_unnorm[1])\n", + " print('PSNR on Highres', ' '.join([str(x) for x in psnr_list]))\n", + " print('SSIM on Highres', ' '.join([str(np.round(x,3)) for x in ssim_list]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64", + "metadata": {}, + "outputs": [], + "source": [ + "rmse_arr = []\n", + "psnr_arr = []\n", + "rinv_psnr_arr = []\n", + "ssim_arr = []\n", + "for ch_id in range(pred.shape[-1]):\n", + " rmse =np.sqrt(((pred[...,ch_id] - tar_normalized[...,ch_id])**2).reshape(len(pred),-1).mean(axis=1))\n", + " rmse_arr.append(rmse)\n", + " psnr = avg_psnr(tar_normalized[...,ch_id].copy(), pred[...,ch_id].copy()) \n", + " rinv_psnr = avg_range_inv_psnr(tar_normalized[...,ch_id].copy(), pred[...,ch_id].copy())\n", + " ssim_mean, ssim_std = avg_ssim(tar[...,ch_id], pred_unnorm[ch_id])\n", + " psnr_arr.append(psnr)\n", + " rinv_psnr_arr.append(rinv_psnr)\n", + " ssim_arr.append((ssim_mean,ssim_std))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65", + "metadata": {}, + "outputs": [], + "source": [ + "print(f'{DataSplitType.name(eval_datasplit_type)}_P{eval_patch_size}_G{image_size_for_grid_centers}_M{mmse_count}_Sk{ignored_last_pixels}')\n", + "print('Rec Loss: ', np.round(rec_loss.mean(),3) )\n", + "print('RMSE: ', ' <--> '.join([str(np.mean(x).round(3)) for x in rmse_arr]))\n", + "print('PSNR: ', ' <--> '.join([str(x) for x in psnr_arr]))\n", + "print('RangeInvPSNR: ',' <--> '.join([str(x) for x in rinv_psnr_arr]))\n", + "print('SSIM: ',' <--> '.join([f'{round(x,3)}±{round(y,4)}' for (x,y) in ssim_arr]))\n", + "print()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "usplit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.-1" + }, + "vscode": { + "interpreter": { + "hash": "e959a19f8af3b4149ff22eb57702a46c14a8caae5a2647a6be0b1f60abdfa4c2" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 1c62e6c79..8a51f73e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ classifiers = [ dependencies = [ 'numpy<2.0.0', 'torch>=2.0.0', + 'torchvision', 'bioimageio.core>=0.6.0', 'tifffile', 'psutil', diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 89aeeeb5a..1368a32d9 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -561,7 +561,9 @@ def predict( `tile_size`. Test-time augmentation (TTA) can be switched off using the `tta_transforms` - parameter. + parameter. The TTA augmentation applies all possible flip and 90 degrees + rotations to the prediction input and averages the predictions. TTA augmentation + should not be used if you did not train with these augmentations. Note that if you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This diff --git a/src/careamics/dataset/dataset_utils/__init__.py b/src/careamics/dataset/dataset_utils/__init__.py index 69db09e87..b6a626aaf 100644 --- a/src/careamics/dataset/dataset_utils/__init__.py +++ b/src/careamics/dataset/dataset_utils/__init__.py @@ -10,12 +10,16 @@ "get_read_func", "read_zarr", "iterate_over_files", + "WelfordStatistics", ] -from .dataset_utils import compute_normalization_stats, reshape_array +from .dataset_utils import ( + reshape_array, +) from .file_utils import get_files_size, list_files, validate_source_target_files from .iterate_over_files import iterate_over_files from .read_tiff import read_tiff from .read_utils import get_read_func from .read_zarr import read_zarr +from .running_stats import WelfordStatistics, compute_normalization_stats diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py b/src/careamics/dataset/dataset_utils/dataset_utils.py index 6da4f122a..ebaed0d46 100644 --- a/src/careamics/dataset/dataset_utils/dataset_utils.py +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -99,25 +99,3 @@ def reshape_array(x: np.ndarray, axes: str) -> np.ndarray: _x = np.expand_dims(_x, new_axes.index("S") + 1) return _x - - -def compute_normalization_stats(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """ - Compute mean and standard deviation of an array. - - Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are - computed per channel. - - Parameters - ---------- - image : np.ndarray - Input array. - - Returns - ------- - Tuple[List[float], List[float]] - Lists of mean and standard deviation values per channel. - """ - # Define the list of axes excluding the channel axis - axes = tuple(np.delete(np.arange(image.ndim), 1)) - return np.mean(image, axis=axes), np.std(image, axis=axes) diff --git a/src/careamics/dataset/dataset_utils/running_stats.py b/src/careamics/dataset/dataset_utils/running_stats.py new file mode 100644 index 000000000..5ee40abd5 --- /dev/null +++ b/src/careamics/dataset/dataset_utils/running_stats.py @@ -0,0 +1,186 @@ +"""Computing data statistics.""" + +import numpy as np +from numpy.typing import NDArray + + +def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]: + """ + Compute mean and standard deviation of an array. + + Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are + computed per channel. + + Parameters + ---------- + image : NDArray + Input array. + + Returns + ------- + tuple of (list of floats, list of floats) + Lists of mean and standard deviation values per channel. + """ + # Define the list of axes excluding the channel axis + axes = tuple(np.delete(np.arange(image.ndim), 1)) + return np.mean(image, axis=axes), np.std(image, axis=axes) + + +def update_iterative_stats( + count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray +) -> tuple[NDArray, NDArray, NDArray]: + """Update the mean and variance of an array iteratively. + + Parameters + ---------- + count : NDArray + Number of elements in the array. + mean : NDArray + Mean of the array. + m2 : NDArray + Variance of the array. + new_values : NDArray + New values to add to the mean and variance. + + Returns + ------- + tuple[NDArray, NDArray, NDArray] + Updated count, mean, and variance. + """ + count += np.array([np.prod(channel.shape) for channel in new_values]) + # newvalues - oldMean + delta = [ + np.subtract(v.flatten(), [m] * len(v.flatten())) + for v, m in zip(new_values, mean) + ] + + mean += np.array([np.sum(d / c) for d, c in zip(delta, count)]) + # newvalues - newMeant + delta2 = [ + np.subtract(v.flatten(), [m] * len(v.flatten())) + for v, m in zip(new_values, mean) + ] + + m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)]) + + return (count, mean, m2) + + +def finalize_iterative_stats( + count: NDArray, mean: NDArray, m2: NDArray +) -> tuple[NDArray, NDArray]: + """Finalize the mean and variance computation. + + Parameters + ---------- + count : NDArray + Number of elements in the array. + mean : NDArray + Mean of the array. + m2 : NDArray + Variance of the array. + + Returns + ------- + tuple[NDArray, NDArray] + Final mean and standard deviation. + """ + std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)]) + if any(c < 2 for c in count): + return np.full(mean.shape, np.nan), np.full(std.shape, np.nan) + else: + return mean, std + + +class WelfordStatistics: + """Compute Welford statistics iteratively. + + The Welford algorithm is used to compute the mean and variance of an array + iteratively. Based on the implementation from: + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + """ + + def update(self, array: NDArray, sample_idx: int) -> None: + """Update the Welford statistics. + + Parameters + ---------- + array : NDArray + Input array. + sample_idx : int + Current sample number. + """ + self.sample_idx = sample_idx + sample_channels = np.array(np.split(array, array.shape[1], axis=1)) + + # Initialize the statistics + if self.sample_idx == 0: + # Compute the mean and standard deviation + self.mean, _ = compute_normalization_stats(array) + # Initialize the count and m2 with zero-valued arrays of shape (C,) + self.count, self.mean, self.m2 = update_iterative_stats( + count=np.zeros(array.shape[1]), + mean=self.mean, + m2=np.zeros(array.shape[1]), + new_values=sample_channels, + ) + else: + # Update the statistics + self.count, self.mean, self.m2 = update_iterative_stats( + count=self.count, mean=self.mean, m2=self.m2, new_values=sample_channels + ) + + self.sample_idx += 1 + + def finalize(self) -> tuple[NDArray, NDArray]: + """Finalize the Welford statistics. + + Returns + ------- + tuple or numpy arrays + Final mean and standard deviation. + """ + return finalize_iterative_stats(self.count, self.mean, self.m2) + + +# from multiprocessing import Value +# from typing import tuple + +# import numpy as np + + +# class RunningStats: +# """Calculates running mean and std.""" + +# def __init__(self) -> None: +# self.reset() + +# def reset(self) -> None: +# """Reset the running stats.""" +# self.avg_mean = Value("d", 0) +# self.avg_std = Value("d", 0) +# self.m2 = Value("d", 0) +# self.count = Value("i", 0) + +# def init(self, mean: float, std: float) -> None: +# """Initialize running stats.""" +# with self.avg_mean.get_lock(): +# self.avg_mean.value += mean +# with self.avg_std.get_lock(): +# self.avg_std.value = std + +# def compute_std(self) -> tuple[float, float]: +# """Compute std.""" +# if self.count.value >= 2: +# self.avg_std.value = np.sqrt(self.m2.value / self.count.value) + +# def update(self, value: float) -> None: +# """Update running stats.""" +# with self.count.get_lock(): +# self.count.value += 1 +# delta = value - self.avg_mean.value +# with self.avg_mean.get_lock(): +# self.avg_mean.value += delta / self.count.value +# delta2 = value - self.avg_mean.value +# with self.m2.get_lock(): +# self.m2.value += delta * delta2 diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 99beb2727..afa48c711 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -15,7 +15,8 @@ from careamics.transforms import Compose from ..utils.logging import get_logger -from .dataset_utils import compute_normalization_stats, iterate_over_files, read_tiff +from .dataset_utils import iterate_over_files, read_tiff +from .dataset_utils.running_stats import WelfordStatistics from .patching.patching import Stats from .patching.random_patching import extract_patches_random @@ -125,23 +126,20 @@ def _calculate_mean_and_std(self) -> tuple[Stats, Stats]: tuple of Stats and optional Stats Data classes containing the image and target statistics. """ - image_means = [] - image_stds = [] - target_means = [] - target_stds = [] num_samples = 0 + image_stats = WelfordStatistics() + if self.target_files is not None: + target_stats = WelfordStatistics() for sample, target in iterate_over_files( self.data_config, self.data_files, self.target_files, self.read_source_func ): - sample_mean, sample_std = compute_normalization_stats(sample) - image_means.append(sample_mean) - image_stds.append(sample_std) + # update the image statistics + image_stats.update(sample, num_samples) + # update the target statistics if target is available if target is not None: - target_mean, target_std = compute_normalization_stats(target) - target_means.append(target_mean) - target_stds.append(target_std) + target_stats.update(target, num_samples) num_samples += 1 @@ -149,15 +147,10 @@ def _calculate_mean_and_std(self) -> tuple[Stats, Stats]: raise ValueError("No samples found in the dataset.") # Average the means and stds per sample - image_means = np.mean(image_means, axis=0) - image_stds = np.sqrt(np.mean([std**2 for std in image_stds], axis=0)) - - logger.info(f"Calculated mean and std for {num_samples} images") - logger.info(f"Mean: {image_means}, std: {image_stds}") + image_means, image_stds = image_stats.finalize() if target is not None: - target_means = np.mean(target_means, axis=0) - target_stds = np.sqrt(np.mean([std**2 for std in target_stds], axis=0)) + target_means, target_stds = target_stats.finalize() return ( Stats(image_means, image_stds), diff --git a/src/careamics/dataset/patching/patching.py b/src/careamics/dataset/patching/patching.py index b152b7d7f..74cdd14d2 100644 --- a/src/careamics/dataset/patching/patching.py +++ b/src/careamics/dataset/patching/patching.py @@ -7,7 +7,8 @@ import numpy as np from ...utils.logging import get_logger -from ..dataset_utils import compute_normalization_stats, reshape_array +from ..dataset_utils import reshape_array +from ..dataset_utils.running_stats import compute_normalization_stats from .sequential_patching import extract_patches_sequential logger = get_logger(__name__) diff --git a/src/careamics/lvae_training/__init__.py b/src/careamics/lvae_training/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/careamics/lvae_training/data_modules.py b/src/careamics/lvae_training/data_modules.py new file mode 100644 index 000000000..f3ad56dec --- /dev/null +++ b/src/careamics/lvae_training/data_modules.py @@ -0,0 +1,1220 @@ +""" +A place for Datasets and Dataloaders. +""" + +import os +from typing import Tuple, Union + +# import albumentations as A +import ml_collections +import numpy as np +from skimage.transform import resize + +from .data_utils import ( + DataSplitType, + DataType, + GridAlignement, + GridIndexManager, + IndexSwitcher, + get_datasplit_tuples, + get_mrc_data, + load_tiff, +) + + +def get_train_val_data( + data_config, + fpath, + datasplit_type: DataSplitType, + val_fraction=None, + test_fraction=None, + allow_generation=None, + ignore_specific_datapoints=None, +): + """ + Load the data from the given path and split them in training, validation and test sets. + + Ensure that the shape of data should be N*H*W*C: N is number of data points. H,W are the image dimensions. + C is the number of channels. + """ + if data_config.data_type == DataType.SeparateTiffData: + fpath1 = os.path.join(fpath, data_config.ch1_fname) + fpath2 = os.path.join(fpath, data_config.ch2_fname) + fpaths = [fpath1, fpath2] + fpath0 = "" + if "ch_input_fname" in data_config: + fpath0 = os.path.join(fpath, data_config.ch_input_fname) + fpaths = [fpath0] + fpaths + + print( + f"Loading from {fpath} Channels: " + f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}" + ) + + data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3) + if data_config.data_type == DataType.PredictedTiffData: + assert len(data.shape) == 5 and data.shape[-1] == 1 + data = data[..., 0].copy() + # data = data[::3].copy() + # NOTE: This was not the correct way to do it. It is so because the noise present in the input was directly related + # to the noise present in the channels and so this is not the way we would get the data. + # We need to add the noise independently to the input and the target. + + # if data_config.get('poisson_noise_factor', False): + # data = np.random.poisson(data) + # if data_config.get('enable_gaussian_noise', False): + # synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1) + # print('Adding Gaussian noise with scale', synthetic_scale) + # noise = np.random.normal(0, synthetic_scale, data.shape) + # data = data + noise + + if datasplit_type == DataSplitType.All: + return data.astype(np.float32) + + train_idx, val_idx, test_idx = get_datasplit_tuples( + val_fraction, test_fraction, len(data), starting_test=True + ) + if datasplit_type == DataSplitType.Train: + return data[train_idx].astype(np.float32) + elif datasplit_type == DataSplitType.Val: + return data[val_idx].astype(np.float32) + elif datasplit_type == DataSplitType.Test: + return data[test_idx].astype(np.float32) + + elif data_config.data_type == DataType.BioSR_MRC: + num_channels = data_config.get("num_channels", 2) + fpaths = [] + data_list = [] + for i in range(num_channels): + fpath1 = os.path.join(fpath, data_config.get(f"ch{i + 1}_fname")) + fpaths.append(fpath1) + data = get_mrc_data(fpath1)[..., None] + data_list.append(data) + + dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/" + + msg = ",".join([x[len(dirname) :] for x in fpaths]) + print( + f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{DataSplitType.name(datasplit_type)}" + ) + N = data_list[0].shape[0] + for data in data_list: + N = min(N, data.shape[0]) + + cropped_data = [] + for data in data_list: + cropped_data.append(data[:N]) + + data = np.concatenate(cropped_data, axis=3) + + if datasplit_type == DataSplitType.All: + return data.astype(np.float32) + + train_idx, val_idx, test_idx = get_datasplit_tuples( + val_fraction, test_fraction, len(data), starting_test=True + ) + if datasplit_type == DataSplitType.Train: + return data[train_idx].astype(np.float32) + elif datasplit_type == DataSplitType.Val: + return data[val_idx].astype(np.float32) + elif datasplit_type == DataSplitType.Test: + return data[test_idx].astype(np.float32) + + +class MultiChDloader: + + def __init__( + self, + data_config: ml_collections.ConfigDict, + fpath: str, + datasplit_type: DataSplitType = None, + val_fraction: float = None, + test_fraction: float = None, + normalized_input=None, + enable_rotation_aug: bool = False, + enable_random_cropping: bool = False, + use_one_mu_std=None, + allow_generation: bool = False, + max_val: float = None, + grid_alignment=GridAlignement.LeftTop, + overlapping_padding_kwargs=None, + print_vars: bool = True, + ): + """ + Here, an image is split into grids of size img_sz. + Args: + repeat_factor: Since we are doing a random crop, repeat_factor is + given which can repeatedly sample from the same image. If self.N=12 + and repeat_factor is 5, then index upto 12*5 = 60 is allowed. + use_one_mu_std: If this is set to true, then one mean and stdev is used + for both channels. Otherwise, two different meean and stdev are used. + + """ + self._data_type = data_config.data_type + self._fpath = fpath + self._data = self.N = self._noise_data = None + + # Hardcoded params, not included in the config file. + + # by default, if the noise is present, add it to the input and target. + self._disable_noise = False # to add synthetic noise + self._train_index_switcher = None + # NOTE: Input is the sum of the different channels. It is not the average of the different channels. + self._input_is_sum = data_config.get("input_is_sum", False) + self._num_channels = data_config.get("num_channels", 2) + self._input_idx = data_config.get("input_idx", None) + self._tar_idx_list = data_config.get("target_idx_list", None) + + if datasplit_type == DataSplitType.Train: + self._datausage_fraction = 1.0 + # assert self._datausage_fraction == 1.0, 'Not supported. Use validtarget_random_fraction and training_validtarget_fraction to get the same effect' + self._validtarget_rand_fract = None + # self._validtarget_random_fraction_final = data_config.get('validtarget_random_fraction_final', None) + # self._validtarget_random_fraction_stepepoch = data_config.get('validtarget_random_fraction_stepepoch', None) + # self._idx_count = 0 + elif datasplit_type == DataSplitType.Val: + self._datausage_fraction = 1.0 + else: + self._datausage_fraction = 1.0 + + self.load_data( + data_config, + datasplit_type, + val_fraction=val_fraction, + test_fraction=test_fraction, + allow_generation=allow_generation, + ) + self._normalized_input = normalized_input + self._quantile = 1.0 + self._channelwise_quantile = False + self._background_quantile = 0.0 + self._clip_background_noise_to_zero = False + self._skip_normalization_using_mean = False + self._empty_patch_replacement_enabled = False + + self._background_values = None + + self._grid_alignment = grid_alignment + self._overlapping_padding_kwargs = overlapping_padding_kwargs + if self._grid_alignment == GridAlignement.LeftTop: + assert ( + self._overlapping_padding_kwargs is None + or data_config.multiscale_lowres_count is not None + ), "Padding is not used with this alignement style" + elif self._grid_alignment == GridAlignement.Center: + assert ( + self._overlapping_padding_kwargs is not None + ), "With Center grid alignment, padding is needed." + + self._is_train = datasplit_type == DataSplitType.Train + + # input = alpha * ch1 + (1-alpha)*ch2. + # alpha is sampled randomly between these two extremes + self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = ( + self._alpha_weighted_target + ) = None + + self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None + if self._is_train: + self._start_alpha_arr = None + self._end_alpha_arr = None + self._alpha_weighted_target = False + + self.set_img_sz( + data_config.image_size, + ( + data_config.grid_size + if "grid_size" in data_config + else data_config.image_size + ), + ) + + # if self._validtarget_rand_fract is not None: + # self._train_index_switcher = IndexSwitcher(self.idx_manager, data_config, self._img_sz) + # self._std_background_arr = None + + else: + self.set_img_sz( + data_config.image_size, + ( + data_config.grid_size + if "grid_size" in data_config + else data_config.image_size + ), + ) + + self._return_alpha = False + self._return_index = False + + # self._empty_patch_replacement_enabled = data_config.get("empty_patch_replacement_enabled", + # False) and self._is_train + # if self._empty_patch_replacement_enabled: + # self._empty_patch_replacement_channel_idx = data_config.empty_patch_replacement_channel_idx + # self._empty_patch_replacement_probab = data_config.empty_patch_replacement_probab + # data_frames = self._data[..., self._empty_patch_replacement_channel_idx] + # # NOTE: This is on the raw data. So, it must be called before removing the background. + # self._empty_patch_fetcher = EmptyPatchFetcher(self.idx_manager, + # self._img_sz, + # data_frames, + # max_val_threshold=data_config.empty_patch_max_val_threshold) + + self.rm_bkground_set_max_val_and_upperclip_data(max_val, datasplit_type) + + # For overlapping dloader, image_size and repeat_factors are not related. hence a different function. + + self._mean = None + self._std = None + self._use_one_mu_std = use_one_mu_std + # Hardcoded + self._target_separate_normalization = True + + self._enable_rotation = enable_rotation_aug + self._enable_random_cropping = enable_random_cropping + self._uncorrelated_channels = ( + data_config.get("uncorrelated_channels", False) and self._is_train + ) + assert self._is_train or self._uncorrelated_channels is False + assert ( + self._enable_random_cropping is True or self._uncorrelated_channels is False + ) + # Randomly rotate [-90,90] + + self._rotation_transform = None + if self._enable_rotation: + raise NotImplementedError( + "Augmentation by means of rotation is not supported yet." + ) + self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()]) + + if print_vars: + msg = self._init_msg() + print(msg) + + def disable_noise(self): + assert ( + self._poisson_noise_factor is None + ), "This is not supported. Poisson noise is added to the data itself and so the noise cannot be disabled." + self._disable_noise = True + + def enable_noise(self): + self._disable_noise = False + + def get_data_shape(self): + return self._data.shape + + def load_data( + self, + data_config, + datasplit_type, + val_fraction=None, + test_fraction=None, + allow_generation=None, + ): + self._data = get_train_val_data( + data_config, + self._fpath, + datasplit_type, + val_fraction=val_fraction, + test_fraction=test_fraction, + allow_generation=allow_generation, + ) + + old_shape = self._data.shape + if self._datausage_fraction < 1.0: + framepixelcount = np.prod(self._data.shape[1:3]) + pixelcount = int( + len(self._data) * framepixelcount * self._datausage_fraction + ) + frame_count = int(np.ceil(pixelcount / framepixelcount)) + last_frame_reduced_size, _ = IndexSwitcher.get_reduced_frame_size( + self._data.shape[:3], self._datausage_fraction + ) + self._data = self._data[:frame_count].copy() + if frame_count == 1: + self._data = self._data[ + :, :last_frame_reduced_size, :last_frame_reduced_size + ].copy() + print( + f"[{self.__class__.__name__}] New data shape: {self._data.shape} Old: {old_shape}" + ) + + msg = "" + if data_config.get("poisson_noise_factor", -1) > 0: + self._poisson_noise_factor = data_config.poisson_noise_factor + msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t" + self._data = ( + np.random.poisson(self._data / self._poisson_noise_factor) + * self._poisson_noise_factor + ) + + if data_config.get("enable_gaussian_noise", False): + synthetic_scale = data_config.get("synthetic_gaussian_scale", 0.1) + msg += f"Adding Gaussian noise with scale {synthetic_scale}" + # 0 => noise for input. 1: => noise for all targets. + shape = self._data.shape + self._noise_data = np.random.normal( + 0, synthetic_scale, (*shape[:-1], shape[-1] + 1) + ) + if data_config.get("input_has_dependant_noise", False): + msg += ". Moreover, input has dependent noise" + self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1) + print(msg) + + self.N = len(self._data) + assert ( + self._data.shape[-1] == self._num_channels + ), "Number of channels in data and config do not match." + + def save_background(self, channel_idx, frame_idx, background_value): + self._background_values[frame_idx, channel_idx] = background_value + + def get_background(self, channel_idx, frame_idx): + return self._background_values[frame_idx, channel_idx] + + def remove_background(self): + + self._background_values = np.zeros((self._data.shape[0], self._data.shape[-1])) + + if self._background_quantile == 0.0: + assert ( + self._clip_background_noise_to_zero is False + ), "This operation currently happens later in this function." + return + + if self._data.dtype in [np.uint16]: + # unsigned integer creates havoc + self._data = self._data.astype(np.int32) + + for ch in range(self._data.shape[-1]): + for idx in range(self._data.shape[0]): + qval = np.quantile(self._data[idx, ..., ch], self._background_quantile) + assert ( + np.abs(qval) > 20 + ), "We are truncating the qval to an integer which will only make sense if it is large enough" + # NOTE: Here, there can be an issue if you work with normalized data + qval = int(qval) + self.save_background(ch, idx, qval) + self._data[idx, ..., ch] -= qval + + if self._clip_background_noise_to_zero: + self._data[self._data < 0] = 0 + + def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type): + self.remove_background() + self.set_max_val(max_val, datasplit_type) + self.upperclip_data() + + def upperclip_data(self): + if isinstance(self.max_val, list): + chN = self._data.shape[-1] + assert chN == len(self.max_val) + for ch in range(chN): + ch_data = self._data[..., ch] + ch_q = self.max_val[ch] + ch_data[ch_data > ch_q] = ch_q + self._data[..., ch] = ch_data + else: + self._data[self._data > self.max_val] = self.max_val + + def compute_max_val(self): + if self._channelwise_quantile: + max_val_arr = [ + np.quantile(self._data[..., i], self._quantile) + for i in range(self._data.shape[-1]) + ] + return max_val_arr + else: + return np.quantile(self._data, self._quantile) + + def set_max_val(self, max_val, datasplit_type): + + if max_val is None: + assert datasplit_type == DataSplitType.Train + self.max_val = self.compute_max_val() + else: + assert max_val is not None + self.max_val = max_val + + def get_max_val(self): + return self.max_val + + def get_img_sz(self): + return self._img_sz + + def reduce_data( + self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None + ): + if t_list is None: + t_list = list(range(self._data.shape[0])) + if h_start is None: + h_start = 0 + if h_end is None: + h_end = self._data.shape[1] + if w_start is None: + w_start = 0 + if w_end is None: + w_end = self._data.shape[2] + + self._data = self._data[t_list, h_start:h_end, w_start:w_end, :].copy() + if self._noise_data is not None: + self._noise_data = self._noise_data[ + t_list, h_start:h_end, w_start:w_end, : + ].copy() + + self.N = len(t_list) + self.set_img_sz(self._img_sz, self._grid_sz) + print( + f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}" + ) + + def set_img_sz(self, image_size, grid_size): + """ + If one wants to change the image size on the go, then this can be used. + Args: + image_size: size of one patch + grid_size: frame is divided into square grids of this size. A patch centered on a grid having size `image_size` is returned. + """ + self._img_sz = image_size + self._grid_sz = grid_size + self.idx_manager = GridIndexManager( + self._data.shape, self._grid_sz, self._img_sz, self._grid_alignment + ) + self.set_repeat_factor() + + def set_repeat_factor(self): + if self._grid_sz > 1: + self._repeat_factor = self.idx_manager.grid_rows( + self._grid_sz + ) * self.idx_manager.grid_cols(self._grid_sz) + else: + self._repeat_factor = self.idx_manager.grid_rows( + self._img_sz + ) * self.idx_manager.grid_cols(self._img_sz) + + def _init_msg( + self, + ): + msg = ( + f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}" + ) + msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}" + # msg += f' NormInp:{self._normalized_input}' + # msg += f' SingleNorm:{self._use_one_mu_std}' + msg += f" Rot:{self._enable_rotation}" + msg += f" RandCrop:{self._enable_random_cropping}" + msg += f" Channel:{self._num_channels}" + # msg += f' Q:{self._quantile}' + if self._input_is_sum: + msg += f" SummedInput:{self._input_is_sum}" + + if self._empty_patch_replacement_enabled: + msg += f" ReplaceWithRandSample:{self._empty_patch_replacement_enabled}" + if self._uncorrelated_channels: + msg += f" Uncorr:{self._uncorrelated_channels}" + if self._empty_patch_replacement_enabled: + msg += f"-{self._empty_patch_replacement_channel_idx}-{self._empty_patch_replacement_probab}" + if self._background_quantile > 0.0: + msg += f" BckQ:{self._background_quantile}" + + if self._start_alpha_arr is not None: + msg += f" Alpha:[{self._start_alpha_arr},{self._end_alpha_arr}]" + return msg + + def _crop_imgs(self, index, *img_tuples: np.ndarray): + h, w = img_tuples[0].shape[-2:] + if self._img_sz is None: + return ( + *img_tuples, + {"h": [0, h], "w": [0, w], "hflip": False, "wflip": False}, + ) + + if self._enable_random_cropping: + h_start, w_start = self._get_random_hw(h, w) + else: + h_start, w_start = self._get_deterministic_hw(index) + + cropped_imgs = [] + for img in img_tuples: + img = self._crop_flip_img(img, h_start, w_start, False, False) + cropped_imgs.append(img) + + return ( + *tuple(cropped_imgs), + { + "h": [h_start, h_start + self._img_sz], + "w": [w_start, w_start + self._img_sz], + "hflip": False, + "wflip": False, + }, + ) + + def _crop_img(self, img: np.ndarray, h_start: int, w_start: int): + if self._grid_alignment == GridAlignement.LeftTop: + # In training, this is used. + # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine. + # The only benefit this if else loop provides is that it makes it easier to see what happens during training. + new_img = img[ + ..., h_start : h_start + self._img_sz, w_start : w_start + self._img_sz + ] + return new_img + elif self._grid_alignment == GridAlignement.Center: + # During evaluation, this is used. In this situation, we can have negative h_start, w_start. Or h_start +self._img_sz can be larger than frame + # In these situations, we need some sort of padding. This is not needed in the LeftTop alignement. + return self._crop_img_with_padding(img, h_start, w_start) + + def get_begin_end_padding(self, start_pos, max_len): + """ + The effect is that the image with size self._grid_sz is in the center of the patch with sufficient + padding on all four sides so that the final patch size is self._img_sz. + """ + pad_start = 0 + pad_end = 0 + if start_pos < 0: + pad_start = -1 * start_pos + + pad_end = max(0, start_pos + self._img_sz - max_len) + + return pad_start, pad_end + + def _crop_img_with_padding(self, img: np.ndarray, h_start: int, w_start: int): + _, H, W = img.shape + h_on_boundary = self.on_boundary(h_start, H) + w_on_boundary = self.on_boundary(w_start, W) + + assert h_start < H + assert w_start < W + + assert h_start + self._img_sz <= H or h_on_boundary + assert w_start + self._img_sz <= W or w_on_boundary + # max() is needed since h_start could be negative. + new_img = img[ + ..., + max(0, h_start) : h_start + self._img_sz, + max(0, w_start) : w_start + self._img_sz, + ] + padding = np.array([[0, 0], [0, 0], [0, 0]]) + + if h_on_boundary: + pad = self.get_begin_end_padding(h_start, H) + padding[1] = pad + if w_on_boundary: + pad = self.get_begin_end_padding(w_start, W) + padding[2] = pad + + if not np.all(padding == 0): + new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs) + + return new_img + + def _crop_flip_img( + self, img: np.ndarray, h_start: int, w_start: int, h_flip: bool, w_flip: bool + ): + new_img = self._crop_img(img, h_start, w_start) + if h_flip: + new_img = new_img[..., ::-1, :] + if w_flip: + new_img = new_img[..., :, ::-1] + + return new_img.astype(np.float32) + + def __len__(self): + return self.N * self._repeat_factor + + def _load_img( + self, index: Union[int, Tuple[int, int]] + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Returns the channels and also the respective noise channels. + """ + if isinstance(index, int) or isinstance(index, np.int64): + idx = index + else: + idx = index[0] + + imgs = self._data[self.idx_manager.get_t(idx)] + loaded_imgs = [imgs[None, ..., i] for i in range(imgs.shape[-1])] + noise = [] + if self._noise_data is not None and not self._disable_noise: + noise = [ + self._noise_data[self.idx_manager.get_t(idx)][None, ..., i] + for i in range(self._noise_data.shape[-1]) + ] + return tuple(loaded_imgs), tuple(noise) + + def get_mean_std(self): + return self._mean, self._std + + def set_mean_std(self, mean_val, std_val): + self._mean = mean_val + self._std = std_val + + def normalize_img(self, *img_tuples): + mean, std = self.get_mean_std() + mean = mean["target"] + std = std["target"] + mean = mean.squeeze() + std = std.squeeze() + normalized_imgs = [] + for i, img in enumerate(img_tuples): + img = (img - mean[i]) / std[i] + normalized_imgs.append(img) + return tuple(normalized_imgs) + + def get_grid_size(self): + return self._grid_sz + + def get_idx_manager(self): + return self.idx_manager + + def per_side_overlap_pixelcount(self): + return (self._img_sz - self._grid_sz) // 2 + + def on_boundary(self, cur_loc, frame_size): + return cur_loc + self._img_sz > frame_size or cur_loc < 0 + + def _get_deterministic_hw(self, index: Union[int, Tuple[int, int]]): + """ + It returns the top-left corner of the patch corresponding to index. + """ + if isinstance(index, int) or isinstance(index, np.int64): + idx = index + grid_size = self._grid_sz + else: + idx, grid_size = index + + h_start, w_start = self.idx_manager.get_deterministic_hw( + idx, grid_size=grid_size + ) + if self._grid_alignment == GridAlignement.LeftTop: + return h_start, w_start + elif self._grid_alignment == GridAlignement.Center: + pad = self.per_side_overlap_pixelcount() + return h_start - pad, w_start - pad + + def compute_individual_mean_std(self): + # numpy 1.19.2 has issues in computing for large arrays. https://github.com/numpy/numpy/issues/8869 + # mean = np.mean(self._data, axis=(0, 1, 2)) + # std = np.std(self._data, axis=(0, 1, 2)) + mean_arr = [] + std_arr = [] + for ch_idx in range(self._data.shape[-1]): + mean_ = ( + 0.0 + if self._skip_normalization_using_mean + else self._data[..., ch_idx].mean() + ) + if self._noise_data is not None: + std_ = ( + self._data[..., ch_idx] + self._noise_data[..., ch_idx + 1] + ).std() + else: + std_ = self._data[..., ch_idx].std() + + mean_arr.append(mean_) + std_arr.append(std_) + + mean = np.array(mean_arr) + std = np.array(std_arr) + + return mean[None, :, None, None], std[None, :, None, None] + + def compute_mean_std(self, allow_for_validation_data=False): + """ + Note that we must compute this only for training data. + """ + assert ( + self._is_train is True or allow_for_validation_data + ), "This is just allowed for training data" + assert self._use_one_mu_std is True, "This is the only supported case" + + if self._input_idx is not None: + assert ( + self._tar_idx_list is not None + ), "tar_idx_list must be set if input_idx is set." + assert self._noise_data is None, "This is not supported with noise" + assert ( + self._target_separate_normalization is True + ), "This is not supported with target_separate_normalization=False" + + mean, std = self.compute_individual_mean_std() + mean_dict = { + "input": mean[:, self._input_idx : self._input_idx + 1], + "target": mean[:, self._tar_idx_list], + } + std_dict = { + "input": std[:, self._input_idx : self._input_idx + 1], + "target": std[:, self._tar_idx_list], + } + return mean_dict, std_dict + + if self._input_is_sum: + assert self._noise_data is None, "This is not supported with noise" + mean = [ + np.mean(self._data[..., k : k + 1], keepdims=True) + for k in range(self._num_channels) + ] + mean = np.sum(mean, keepdims=True)[0] + std = np.linalg.norm( + [ + np.std(self._data[..., k : k + 1], keepdims=True) + for k in range(self._num_channels) + ], + keepdims=True, + )[0] + else: + mean = np.mean(self._data, keepdims=True).reshape(1, 1, 1, 1) + if self._noise_data is not None: + std = np.std( + self._data + self._noise_data[..., 1:], keepdims=True + ).reshape(1, 1, 1, 1) + else: + std = np.std(self._data, keepdims=True).reshape(1, 1, 1, 1) + + mean = np.repeat(mean, self._num_channels, axis=1) + std = np.repeat(std, self._num_channels, axis=1) + + if self._skip_normalization_using_mean: + mean = np.zeros_like(mean) + + mean_dict = {"input": mean} # , 'target':mean} + std_dict = {"input": std} # , 'target':std} + + if self._target_separate_normalization: + mean, std = self.compute_individual_mean_std() + + mean_dict["target"] = mean + std_dict["target"] = std + return mean_dict, std_dict + + def _get_random_hw(self, h: int, w: int): + """ + Random starting position for the crop for the img with index `index`. + """ + if h != self._img_sz: + h_start = np.random.choice(h - self._img_sz) + w_start = np.random.choice(w - self._img_sz) + else: + h_start = 0 + w_start = 0 + return h_start, w_start + + def _get_img(self, index: Union[int, Tuple[int, int]]): + """ + Loads an image. + Crops the image such that cropped image has content. + """ + img_tuples, noise_tuples = self._load_img(index) + cropped_img_tuples = self._crop_imgs(index, *img_tuples, *noise_tuples)[:-1] + cropped_noise_tuples = cropped_img_tuples[len(img_tuples) :] + cropped_img_tuples = cropped_img_tuples[: len(img_tuples)] + return cropped_img_tuples, cropped_noise_tuples + + def replace_with_empty_patch(self, img_tuples): + empty_index = self._empty_patch_fetcher.sample() + empty_img_tuples = self._get_img(empty_index) + final_img_tuples = [] + for tuple_idx in range(len(img_tuples)): + if tuple_idx == self._empty_patch_replacement_channel_idx: + final_img_tuples.append(empty_img_tuples[tuple_idx]) + else: + final_img_tuples.append(img_tuples[tuple_idx]) + return tuple(final_img_tuples) + + def get_mean_std_for_input(self): + mean, std = self.get_mean_std() + return mean["input"], std["input"] + + def _compute_target(self, img_tuples, alpha): + if self._tar_idx_list is not None and isinstance(self._tar_idx_list, int): + target = img_tuples[self._tar_idx_list] + else: + if self._tar_idx_list is not None: + assert isinstance(self._tar_idx_list, list) or isinstance( + self._tar_idx_list, tuple + ) + img_tuples = [img_tuples[i] for i in self._tar_idx_list] + + if self._alpha_weighted_target: + assert self._input_is_sum is False + target = [] + for i in range(len(img_tuples)): + target.append(img_tuples[i] * alpha[i]) + target = np.concatenate(target, axis=0) + else: + target = np.concatenate(img_tuples, axis=0) + return target + + def _compute_input_with_alpha(self, img_tuples, alpha_list): + # assert self._normalized_input is True, "normalization should happen here" + if self._input_idx is not None: + inp = img_tuples[self._input_idx] + else: + inp = 0 + for alpha, img in zip(alpha_list, img_tuples): + inp += img * alpha + + if self._normalized_input is False: + return inp.astype(np.float32) + + mean, std = self.get_mean_std_for_input() + mean = mean.squeeze() + std = std.squeeze() + if mean.size == 1: + mean = mean.reshape( + 1, + ) + std = std.reshape( + 1, + ) + + for i in range(len(mean)): + assert mean[0] == mean[i] + assert std[0] == std[i] + + inp = (inp - mean[0]) / std[0] + return inp.astype(np.float32) + + def _sample_alpha(self): + alpha_arr = [] + for i in range(self._num_channels): + alpha_pos = np.random.rand() + alpha = self._start_alpha_arr[i] + alpha_pos * ( + self._end_alpha_arr[i] - self._start_alpha_arr[i] + ) + alpha_arr.append(alpha) + return alpha_arr + + def _compute_input(self, img_tuples): + alpha = [1 / len(img_tuples) for _ in range(len(img_tuples))] + if self._start_alpha_arr is not None: + alpha = self._sample_alpha() + + inp = self._compute_input_with_alpha(img_tuples, alpha) + if self._input_is_sum: + inp = len(img_tuples) * inp + return inp, alpha + + def _get_index_from_valid_target_logic(self, index): + if self._validtarget_rand_fract is not None: + if np.random.rand() < self._validtarget_rand_fract: + index = self._train_index_switcher.get_valid_target_index() + else: + index = self._train_index_switcher.get_invalid_target_index() + return index + + def _rotate(self, img_tuples, noise_tuples): + return self._rotate2D(img_tuples, noise_tuples) + + def _rotate2D(self, img_tuples, noise_tuples): + img_kwargs = {} + for i, img in enumerate(img_tuples): + for k in range(len(img)): + img_kwargs[f"img{i}_{k}"] = img[k] + + noise_kwargs = {} + for i, nimg in enumerate(noise_tuples): + for k in range(len(nimg)): + noise_kwargs[f"noise{i}_{k}"] = nimg[k] + + keys = list(img_kwargs.keys()) + list(noise_kwargs.keys()) + self._rotation_transform.add_targets({k: "image" for k in keys}) + rot_dic = self._rotation_transform( + image=img_tuples[0][0], **img_kwargs, **noise_kwargs + ) + rotated_img_tuples = [] + for i, img in enumerate(img_tuples): + if len(img) == 1: + rotated_img_tuples.append(rot_dic[f"img{i}_0"][None]) + else: + rotated_img_tuples.append( + np.concatenate( + [rot_dic[f"img{i}_{k}"][None] for k in range(len(img))], axis=0 + ) + ) + + rotated_noise_tuples = [] + for i, nimg in enumerate(noise_tuples): + if len(nimg) == 1: + rotated_noise_tuples.append(rot_dic[f"noise{i}_0"][None]) + else: + rotated_noise_tuples.append( + np.concatenate( + [rot_dic[f"noise{i}_{k}"][None] for k in range(len(nimg))], + axis=0, + ) + ) + + return rotated_img_tuples, rotated_noise_tuples + + def get_uncorrelated_img_tuples(self, index): + img_tuples, noise_tuples = self._get_img(index) + assert len(noise_tuples) == 0 + img_tuples = [img_tuples[0]] + for ch_idx in range(1, len(img_tuples)): + new_index = np.random.randint(len(self)) + other_img_tuples, _ = self._get_img(new_index) + img_tuples.append(other_img_tuples[ch_idx]) + return img_tuples, noise_tuples + + def __getitem__( + self, index: Union[int, Tuple[int, int]] + ) -> Tuple[np.ndarray, np.ndarray]: + if self._train_index_switcher is not None: + index = self._get_index_from_valid_target_logic(index) + + if self._uncorrelated_channels: + img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index) + else: + img_tuples, noise_tuples = self._get_img(index) + + assert ( + self._empty_patch_replacement_enabled != True + ), "This is not supported with noise" + + if self._empty_patch_replacement_enabled: + if np.random.rand() < self._empty_patch_replacement_probab: + img_tuples = self.replace_with_empty_patch(img_tuples) + + if self._enable_rotation: + img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples) + + # add noise to input + if len(noise_tuples) > 0: + factor = np.sqrt(2) if self._input_is_sum else 1.0 + input_tuples = [x + noise_tuples[0] * factor for x in img_tuples] + else: + input_tuples = img_tuples + inp, alpha = self._compute_input(input_tuples) + + # add noise to target. + if len(noise_tuples) >= 1: + img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])] + + target = self._compute_target(img_tuples, alpha) + + output = [inp, target] + + if self._return_alpha: + output.append(alpha) + + if self._return_index: + output.append(index) + + if isinstance(index, int) or isinstance(index, np.int64): + return tuple(output) + + _, grid_size = index + output.append(grid_size) + return tuple(output) + + +class LCMultiChDloader(MultiChDloader): + + def __init__( + self, + data_config, + fpath: str, + datasplit_type: DataSplitType = None, + val_fraction=None, + test_fraction=None, + normalized_input=None, + enable_rotation_aug: bool = False, + use_one_mu_std=None, + num_scales: int = None, + enable_random_cropping=False, + padding_kwargs: dict = None, + allow_generation: bool = False, + lowres_supervision=None, + max_val=None, + grid_alignment=GridAlignement.LeftTop, + overlapping_padding_kwargs=None, + print_vars=True, + ): + """ + Args: + num_scales: The number of resolutions at which we want the input. Note that the target is formed at the + highest resolution. + """ + self._padding_kwargs = ( + padding_kwargs # mode=padding_mode, constant_values=constant_value + ) + if overlapping_padding_kwargs is not None: + assert ( + self._padding_kwargs == overlapping_padding_kwargs + ), "During evaluation, overlapping_padding_kwargs should be same as padding_args. \ + It should be so since we just use overlapping_padding_kwargs when it is not None" + + else: + overlapping_padding_kwargs = padding_kwargs + + super().__init__( + data_config, + fpath, + datasplit_type=datasplit_type, + val_fraction=val_fraction, + test_fraction=test_fraction, + normalized_input=normalized_input, + enable_rotation_aug=enable_rotation_aug, + enable_random_cropping=enable_random_cropping, + use_one_mu_std=use_one_mu_std, + allow_generation=allow_generation, + max_val=max_val, + grid_alignment=grid_alignment, + overlapping_padding_kwargs=overlapping_padding_kwargs, + print_vars=print_vars, + ) + self.num_scales = num_scales + assert self.num_scales is not None + self._scaled_data = [self._data] + self._scaled_noise_data = [self._noise_data] + + assert isinstance(self.num_scales, int) and self.num_scales >= 1 + self._lowres_supervision = lowres_supervision + assert isinstance(self._padding_kwargs, dict) + assert "mode" in self._padding_kwargs + + for _ in range(1, self.num_scales): + shape = self._scaled_data[-1].shape + assert len(shape) == 4 + new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3]) + ds_data = resize( + self._scaled_data[-1].astype(np.float32), new_shape + ).astype(self._scaled_data[-1].dtype) + # NOTE: These asserts are important. the resize method expects np.float32. otherwise, one gets weird results. + assert ( + ds_data.max() / self._scaled_data[-1].max() < 5 + ), "Downsampled image should not have very different values" + assert ( + ds_data.max() / self._scaled_data[-1].max() > 0.2 + ), "Downsampled image should not have very different values" + + self._scaled_data.append(ds_data) + # do the same for noise + if self._noise_data is not None: + noise_data = resize(self._scaled_noise_data[-1], new_shape) + self._scaled_noise_data.append(noise_data) + + def _init_msg(self): + msg = super()._init_msg() + msg += f" Pad:{self._padding_kwargs}" + return msg + + def _load_scaled_img( + self, scaled_index, index: Union[int, Tuple[int, int]] + ) -> Tuple[np.ndarray, np.ndarray]: + if isinstance(index, int): + idx = index + else: + idx, _ = index + imgs = self._scaled_data[scaled_index][idx % self.N] + imgs = tuple([imgs[None, :, :, i] for i in range(imgs.shape[-1])]) + if self._noise_data is not None: + noisedata = self._scaled_noise_data[scaled_index][idx % self.N] + noise = tuple( + [noisedata[None, :, :, i] for i in range(noisedata.shape[-1])] + ) + factor = np.sqrt(2) if self._input_is_sum else 1.0 + # since we are using this lowres images for just the input, we need to add the noise of the input. + assert self._lowres_supervision is None or self._lowres_supervision is False + imgs = tuple([img + noise[0] * factor for img in imgs]) + return imgs + + def _crop_img(self, img: np.ndarray, h_start: int, w_start: int): + """ + Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So, + the cropped image will be smaller than self._img_sz * self._img_sz + """ + return self._crop_img_with_padding(img, h_start, w_start) + + def _get_img(self, index: int): + """ + Returns the primary patch along with low resolution patches centered on the primary patch. + """ + img_tuples, noise_tuples = self._load_img(index) + assert self._img_sz is not None + h, w = img_tuples[0].shape[-2:] + if self._enable_random_cropping: + h_start, w_start = self._get_random_hw(h, w) + else: + h_start, w_start = self._get_deterministic_hw(index) + + cropped_img_tuples = [ + self._crop_flip_img(img, h_start, w_start, False, False) + for img in img_tuples + ] + cropped_noise_tuples = [ + self._crop_flip_img(noise, h_start, w_start, False, False) + for noise in noise_tuples + ] + h_center = h_start + self._img_sz // 2 + w_center = w_start + self._img_sz // 2 + allres_versions = { + i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples)) + } + for scale_idx in range(1, self.num_scales): + scaled_img_tuples = self._load_scaled_img(scale_idx, index) + + h_center = h_center // 2 + w_center = w_center // 2 + + h_start = h_center - self._img_sz // 2 + w_start = w_center - self._img_sz // 2 + + scaled_cropped_img_tuples = [ + self._crop_flip_img(img, h_start, w_start, False, False) + for img in scaled_img_tuples + ] + for ch_idx in range(len(img_tuples)): + allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx]) + + output_img_tuples = tuple( + [ + np.concatenate(allres_versions[ch_idx]) + for ch_idx in range(len(img_tuples)) + ] + ) + return output_img_tuples, cropped_noise_tuples + + def __getitem__(self, index: Union[int, Tuple[int, int]]): + if self._uncorrelated_channels: + img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index) + else: + img_tuples, noise_tuples = self._get_img(index) + + if self._enable_rotation: + img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples) + + assert self._lowres_supervision != True + # add noise to input + if len(noise_tuples) > 0: + factor = np.sqrt(2) if self._input_is_sum else 1.0 + input_tuples = [] + for x in img_tuples: + # NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution. + x[0] = x[0] + noise_tuples[0] * factor + input_tuples.append(x) + else: + input_tuples = img_tuples + + inp, alpha = self._compute_input(input_tuples) + # assert self._alpha_weighted_target in [False, None] + target_tuples = [img[:1] for img in img_tuples] + # add noise to target. + if len(noise_tuples) >= 1: + target_tuples = [ + x + noise for x, noise in zip(target_tuples, noise_tuples[1:]) + ] + + target = self._compute_target(target_tuples, alpha) + + output = [inp, target] + + if self._return_alpha: + output.append(alpha) + + if isinstance(index, int): + return tuple(output) + + _, grid_size = index + output.append(grid_size) + return tuple(output) diff --git a/src/careamics/lvae_training/data_utils.py b/src/careamics/lvae_training/data_utils.py new file mode 100644 index 000000000..975562519 --- /dev/null +++ b/src/careamics/lvae_training/data_utils.py @@ -0,0 +1,618 @@ +""" +Utility functions needed by dataloader & co. +""" + +from typing import List + +import numpy as np +from skimage.io import imread, imsave + +from careamics.models.lvae.utils import Enum + + +class DataType(Enum): + MNIST = 0 + Places365 = 1 + NotMNIST = 2 + OptiMEM100_014 = 3 + CustomSinosoid = 4 + Prevedel_EMBL = 5 + AllenCellMito = 6 + SeparateTiffData = 7 + CustomSinosoidThreeCurve = 8 + SemiSupBloodVesselsEMBL = 9 + Pavia2 = 10 + Pavia2VanillaSplitting = 11 + ExpansionMicroscopyMitoTub = 12 + ShroffMitoEr = 13 + HTIba1Ki67 = 14 + BSD68 = 15 + BioSR_MRC = 16 + TavernaSox2Golgi = 17 + Dao3Channel = 18 + ExpMicroscopyV2 = 19 + Dao3ChannelWithInput = 20 + TavernaSox2GolgiV2 = 21 + TwoDset = 22 + PredictedTiffData = 23 + Pavia3SeqData = 24 + # Here, we have 16 splitting tasks. + NicolaData = 25 + + +class DataSplitType(Enum): + All = 0 + Train = 1 + Val = 2 + Test = 3 + + +class GridAlignement(Enum): + """ + A patch is formed by padding the grid with content. If the grids are 'Center' aligned, then padding is to done equally on all 4 sides. + On the other hand, if grids are 'LeftTop' aligned, padding is to be done on the right and bottom end of the grid. + In the former case, one needs (patch_size - grid_size)//2 amount of content on the right end of the frame. + In the latter case, one needs patch_size - grid_size amount of content on the right end of the frame. + """ + + LeftTop = 0 + Center = 1 + + +def load_tiff(path): + """ + Returns a 4d numpy array: num_imgs*h*w*num_channels + """ + data = imread(path, plugin="tifffile") + return data + + +def save_tiff(path, data): + imsave(path, data, plugin="tifffile") + + +def load_tiffs(paths): + data = [load_tiff(path) for path in paths] + return np.concatenate(data, axis=0) + + +def split_in_half(s, e): + n = e - s + s1 = list(np.arange(n // 2)) + s2 = list(np.arange(n // 2, n)) + return [x + s for x in s1], [x + s for x in s2] + + +def adjust_for_imbalance_in_fraction_value( + val: List[int], + test: List[int], + val_fraction: float, + test_fraction: float, + total_size: int, +): + """ + here, val and test are divided almost equally. Here, we need to take into account their respective fractions + and pick elements rendomly from one array and put in the other array. + """ + if val_fraction == 0: + test += val + val = [] + elif test_fraction == 0: + val += test + test = [] + else: + diff_fraction = test_fraction - val_fraction + if diff_fraction > 0: + imb_count = int(diff_fraction * total_size / 2) + val = list(np.random.RandomState(seed=955).permutation(val)) + test += val[:imb_count] + val = val[imb_count:] + elif diff_fraction < 0: + imb_count = int(-1 * diff_fraction * total_size / 2) + test = list(np.random.RandomState(seed=955).permutation(test)) + val += test[:imb_count] + test = test[imb_count:] + return val, test + + +def get_datasplit_tuples( + val_fraction: float, + test_fraction: float, + total_size: int, + starting_test: bool = False, +): + if starting_test: + # test => val => train + test = list(range(0, int(total_size * test_fraction))) + val = list(range(test[-1] + 1, test[-1] + 1 + int(total_size * val_fraction))) + train = list(range(val[-1] + 1, total_size)) + else: + # {test,val}=> train + test_val_size = int((val_fraction + test_fraction) * total_size) + train = list(range(test_val_size, total_size)) + + if test_val_size == 0: + test = [] + val = [] + return train, val, test + + # Split the test and validation in chunks. + chunksize = max(1, min(3, test_val_size // 2)) + + nchunks = test_val_size // chunksize + + test = [] + val = [] + s = 0 + for i in range(nchunks): + if i % 2 == 0: + val += list(np.arange(s, s + chunksize)) + else: + test += list(np.arange(s, s + chunksize)) + s += chunksize + + if i % 2 == 0: + test += list(np.arange(s, test_val_size)) + else: + p1, p2 = split_in_half(s, test_val_size) + test += p1 + val += p2 + + val, test = adjust_for_imbalance_in_fraction_value( + val, test, val_fraction, test_fraction, total_size + ) + + return train, val, test + + +def get_mrc_data(fpath): + # HXWXN + _, data = read_mrc(fpath) + data = data[None] + data = np.swapaxes(data, 0, 3) + return data[..., 0] + + +class GridIndexManager: + + def __init__(self, data_shape, grid_size, patch_size, grid_alignement) -> None: + self._data_shape = data_shape + self._default_grid_size = grid_size + self.patch_size = patch_size + self.N = self._data_shape[0] + self._align = grid_alignement + + def get_data_shape(self): + return self._data_shape + + def use_default_grid(self, grid_size): + return grid_size is None or grid_size < 0 + + def grid_rows(self, grid_size): + if self._align == GridAlignement.LeftTop: + extra_pixels = self.patch_size - grid_size + elif self._align == GridAlignement.Center: + # Center is exclusively used during evaluation. In this case, we use the padding to handle edge cases. + # So, here, we will ideally like to cover all pixels and so extra_pixels is set to 0. + # If there was no padding, then it should be set to (self.patch_size - grid_size) // 2 + extra_pixels = 0 + + return (self._data_shape[-3] - extra_pixels) // grid_size + + def grid_cols(self, grid_size): + if self._align == GridAlignement.LeftTop: + extra_pixels = self.patch_size - grid_size + elif self._align == GridAlignement.Center: + extra_pixels = 0 + + return (self._data_shape[-2] - extra_pixels) // grid_size + + def grid_count(self, grid_size=None): + if self.use_default_grid(grid_size): + grid_size = self._default_grid_size + + return self.N * self.grid_rows(grid_size) * self.grid_cols(grid_size) + + def hwt_from_idx(self, index, grid_size=None): + t = self.get_t(index) + return (*self.get_deterministic_hw(index, grid_size=grid_size), t) + + def idx_from_hwt(self, h_start, w_start, t, grid_size=None): + """ + Given h,w,t (where h,w constitutes the top left corner of the patch), it returns the corresponding index. + """ + if grid_size is None: + grid_size = self._default_grid_size + + nth_row = h_start // grid_size + nth_col = w_start // grid_size + + index = self.grid_cols(grid_size) * nth_row + nth_col + return index * self._data_shape[0] + t + + def get_t(self, index): + return index % self.N + + def get_top_nbr_idx(self, index, grid_size=None): + if self.use_default_grid(grid_size): + grid_size = self._default_grid_size + + ncols = self.grid_cols(grid_size) + index -= ncols * self.N + if index < 0: + return None + + return index + + def get_bottom_nbr_idx(self, index, grid_size=None): + if self.use_default_grid(grid_size): + grid_size = self._default_grid_size + + ncols = self.grid_cols(grid_size) + index += ncols * self.N + if index > self.grid_count(grid_size=grid_size): + return None + + return index + + def get_left_nbr_idx(self, index, grid_size=None): + if self.on_left_boundary(index, grid_size=grid_size): + return None + + index -= self.N + return index + + def get_right_nbr_idx(self, index, grid_size=None): + if self.on_right_boundary(index, grid_size=grid_size): + return None + index += self.N + return index + + def on_left_boundary(self, index, grid_size=None): + if self.use_default_grid(grid_size): + grid_size = self._default_grid_size + + factor = index // self.N + ncols = self.grid_cols(grid_size) + + left_boundary = (factor // ncols) != (factor - 1) // ncols + return left_boundary + + def on_right_boundary(self, index, grid_size=None): + if self.use_default_grid(grid_size): + grid_size = self._default_grid_size + + factor = index // self.N + ncols = self.grid_cols(grid_size) + + right_boundary = (factor // ncols) != (factor + 1) // ncols + return right_boundary + + def on_top_boundary(self, index, grid_size=None): + if self.use_default_grid(grid_size): + grid_size = self._default_grid_size + + ncols = self.grid_cols(grid_size) + return index < self.N * ncols + + def on_bottom_boundary(self, index, grid_size=None): + if self.use_default_grid(grid_size): + grid_size = self._default_grid_size + + ncols = self.grid_cols(grid_size) + return index + self.N * ncols > self.grid_count(grid_size=grid_size) + + def on_boundary(self, idx, grid_size=None): + if self.on_left_boundary(idx, grid_size=grid_size): + return True + + if self.on_right_boundary(idx, grid_size=grid_size): + return True + + if self.on_top_boundary(idx, grid_size=grid_size): + return True + + if self.on_bottom_boundary(idx, grid_size=grid_size): + return True + return False + + def get_deterministic_hw(self, index: int, grid_size=None): + """ + Fixed starting position for the crop for the img with index `index`. + """ + if self.use_default_grid(grid_size): + grid_size = self._default_grid_size + + # _, h, w, _ = self._data_shape + # assert h == w + factor = index // self.N + ncols = self.grid_cols(grid_size) + + ith_row = factor // ncols + jth_col = factor % ncols + h_start = ith_row * grid_size + w_start = jth_col * grid_size + return h_start, w_start + + +class IndexSwitcher: + """ + The idea is to switch from valid indices for target to invalid indices for target. + If index in invalid for the target, then we return all zero vector as target. + This combines both logic: + 1. Using less amount of total data. + 2. Using less amount of target data but using full data. + """ + + def __init__(self, idx_manager, data_config, patch_size) -> None: + self.idx_manager = idx_manager + self._data_shape = self.idx_manager.get_data_shape() + self._training_validtarget_fraction = data_config.get( + "training_validtarget_fraction", 1.0 + ) + self._validtarget_ceilT = int( + np.ceil(self._data_shape[0] * self._training_validtarget_fraction) + ) + self._patch_size = patch_size + assert ( + data_config.deterministic_grid is True + ), "This only works when the dataset has deterministic grid. Needed randomness comes from this class." + assert ( + "grid_size" in data_config and data_config.grid_size == 1 + ), "We need a one to one mapping between index and h, w, t" + + self._h_validmax, self._w_validmax = self.get_reduced_frame_size( + self._data_shape[:3], self._training_validtarget_fraction + ) + if self._h_validmax < self._patch_size or self._w_validmax < self._patch_size: + print( + "WARNING: The valid target size is smaller than the patch size. This will result in all zero target. so, we are ignoring this frame for target." + ) + self._h_validmax = 0 + self._w_validmax = 0 + + print( + f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT-1}]. Index={self._validtarget_ceilT-1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}" + ) + + def get_valid_target_index(self): + """ + Returns an index which corresponds to a frame which is expected to have a target. + """ + _, h, w, _ = self._data_shape + framepixelcount = h * w + targetpixels = np.array( + [framepixelcount] * (self._validtarget_ceilT - 1) + + [self._h_validmax * self._w_validmax] + ) + targetpixels = targetpixels / np.sum(targetpixels) + t = np.random.choice(self._validtarget_ceilT, p=targetpixels) + # t = np.random.randint(0, self._validtarget_ceilT) if self._validtarget_ceilT >= 1 else 0 + h, w = self.get_valid_target_hw(t) + index = self.idx_manager.idx_from_hwt(h, w, t) + # print('Valid', index, h,w,t) + return index + + def get_invalid_target_index(self): + # if self._validtarget_ceilT == 0: + # TODO: There may not be enough data for this to work. The better way is to skip using 0 for invalid target. + # t = np.random.randint(1, self._data_shape[0]) + # elif self._validtarget_ceilT < self._data_shape[0]: + # t = np.random.randint(self._validtarget_ceilT, self._data_shape[0]) + # else: + # t = self._validtarget_ceilT - 1 + # 5 + # 1.2 => 2 + total_t, h, w, _ = self._data_shape + framepixelcount = h * w + available_h = h - self._h_validmax + if available_h < self._patch_size: + available_h = 0 + available_w = w - self._w_validmax + if available_w < self._patch_size: + available_w = 0 + + targetpixels = np.array( + [available_h * available_w] + + [framepixelcount] * (total_t - self._validtarget_ceilT) + ) + t_probab = targetpixels / np.sum(targetpixels) + t = np.random.choice( + np.arange(self._validtarget_ceilT - 1, total_t), p=t_probab + ) + + h, w = self.get_invalid_target_hw(t) + index = self.idx_manager.idx_from_hwt(h, w, t) + # print('Invalid', index, h,w,t) + return index + + def get_valid_target_hw(self, t): + """ + This is the opposite of get_invalid_target_hw. It returns a h,w which is valid for target. + This is only valid for single frame setup. + """ + if t == self._validtarget_ceilT - 1: + h = np.random.randint(0, self._h_validmax - self._patch_size) + w = np.random.randint(0, self._w_validmax - self._patch_size) + else: + h = np.random.randint(0, self._data_shape[1] - self._patch_size) + w = np.random.randint(0, self._data_shape[2] - self._patch_size) + return h, w + + def get_invalid_target_hw(self, t): + """ + This is the opposite of get_valid_target_hw. It returns a h,w which is not valid for target. + This is only valid for single frame setup. + """ + if t == self._validtarget_ceilT - 1: + h = np.random.randint( + self._h_validmax, self._data_shape[1] - self._patch_size + ) + w = np.random.randint( + self._w_validmax, self._data_shape[2] - self._patch_size + ) + else: + h = np.random.randint(0, self._data_shape[1] - self._patch_size) + w = np.random.randint(0, self._data_shape[2] - self._patch_size) + return h, w + + def _get_tidx(self, index): + if isinstance(index, int) or isinstance(index, np.int64): + idx = index + else: + idx = index[0] + return self.idx_manager.get_t(idx) + + def index_should_have_target(self, index): + tidx = self._get_tidx(index) + if tidx < self._validtarget_ceilT - 1: + return True + elif tidx > self._validtarget_ceilT - 1: + return False + else: + h, w, _ = self.idx_manager.hwt_from_idx(index) + return ( + h + self._patch_size < self._h_validmax + and w + self._patch_size < self._w_validmax + ) + + @staticmethod + def get_reduced_frame_size(data_shape_nhw, fraction): + n, h, w = data_shape_nhw + + framepixelcount = h * w + targetpixelcount = int(n * framepixelcount * fraction) + + # We are currently supporting this only when there is just one frame. + # if np.ceil(pixelcount / framepixelcount) > 1: + # return None, None + + lastframepixelcount = targetpixelcount % framepixelcount + assert data_shape_nhw[1] == data_shape_nhw[2] + if lastframepixelcount > 0: + new_size = int(np.sqrt(lastframepixelcount)) + return new_size, new_size + else: + assert ( + targetpixelcount / framepixelcount >= 1 + ), "This is not possible in euclidean space :D (so this is a bug)" + return h, w + + +rec_header_dtd = [ + ("nx", "i4"), # Number of columns + ("ny", "i4"), # Number of rows + ("nz", "i4"), # Number of sections + ("mode", "i4"), # Types of pixels in the image. Values used by IMOD: + # 0 = unsigned or signed bytes depending on flag in imodFlags + # 1 = signed short integers (16 bits) + # 2 = float (32 bits) + # 3 = short * 2, (used for complex data) + # 4 = float * 2, (used for complex data) + # 6 = unsigned 16-bit integers (non-standard) + # 16 = unsigned char * 3 (for rgb data, non-standard) + ("nxstart", "i4"), # Starting point of sub-image (not used in IMOD) + ("nystart", "i4"), + ("nzstart", "i4"), + ("mx", "i4"), # Grid size in X, Y and Z + ("my", "i4"), + ("mz", "i4"), + ("xlen", "f4"), # Cell size; pixel spacing = xlen/mx, ylen/my, zlen/mz + ("ylen", "f4"), + ("zlen", "f4"), + ("alpha", "f4"), # Cell angles - ignored by IMOD + ("beta", "f4"), + ("gamma", "f4"), + # These need to be set to 1, 2, and 3 for pixel spacing to be interpreted correctly + ("mapc", "i4"), # map column 1=x,2=y,3=z. + ("mapr", "i4"), # map row 1=x,2=y,3=z. + ("maps", "i4"), # map section 1=x,2=y,3=z. + # These need to be set for proper scaling of data + ("amin", "f4"), # Minimum pixel value + ("amax", "f4"), # Maximum pixel value + ("amean", "f4"), # Mean pixel value + ("ispg", "i4"), # space group number (ignored by IMOD) + ( + "next", + "i4", + ), # number of bytes in extended header (called nsymbt in MRC standard) + ("creatid", "i2"), # used to be an ID number, is 0 as of IMOD 4.2.23 + ("extra_data", "V30"), # (not used, first two bytes should be 0) + # These two values specify the structure of data in the extended header; their meaning depend on whether the + # extended header has the Agard format, a series of 4-byte integers then real numbers, or has data + # produced by SerialEM, a series of short integers. SerialEM stores a float as two shorts, s1 and s2, by: + # value = (sign of s1)*(|s1|*256 + (|s2| modulo 256)) * 2**((sign of s2) * (|s2|/256)) + ("nint", "i2"), + # Number of integers per section (Agard format) or number of bytes per section (SerialEM format) + ("nreal", "i2"), # Number of reals per section (Agard format) or bit + # Number of reals per section (Agard format) or bit + # flags for which types of short data (SerialEM format): + # 1 = tilt angle * 100 (2 bytes) + # 2 = piece coordinates for montage (6 bytes) + # 4 = Stage position * 25 (4 bytes) + # 8 = Magnification / 100 (2 bytes) + # 16 = Intensity * 25000 (2 bytes) + # 32 = Exposure dose in e-/A2, a float in 4 bytes + # 128, 512: Reserved for 4-byte items + # 64, 256, 1024: Reserved for 2-byte items + # If the number of bytes implied by these flags does + # not add up to the value in nint, then nint and nreal + # are interpreted as ints and reals per section + ("extra_data2", "V20"), # extra data (not used) + ("imodStamp", "i4"), # 1146047817 indicates that file was created by IMOD + ("imodFlags", "i4"), # Bit flags: 1 = bytes are stored as signed + # Explanation of type of data + ("idtype", "i2"), # ( 0 = mono, 1 = tilt, 2 = tilts, 3 = lina, 4 = lins) + ("lens", "i2"), + # ("nd1", "i2"), # for idtype = 1, nd1 = axis (1, 2, or 3) + # ("nd2", "i2"), + ("nphase", "i4"), + ("vd1", "i2"), # vd1 = 100. * tilt increment + ("vd2", "i2"), # vd2 = 100. * starting angle + # Current angles are used to rotate a model to match a new rotated image. The three values in each set are + # rotations about X, Y, and Z axes, applied in the order Z, Y, X. + ("triangles", "f4", 6), # 0,1,2 = original: 3,4,5 = current + ("xorg", "f4"), # Origin of image + ("yorg", "f4"), + ("zorg", "f4"), + ("cmap", "S4"), # Contains "MAP " + ( + "stamp", + "u1", + 4, + ), # First two bytes have 17 and 17 for big-endian or 68 and 65 for little-endian + ("rms", "f4"), # RMS deviation of densities from mean density + ("nlabl", "i4"), # Number of labels with useful data + ("labels", "S80", 10), # 10 labels of 80 charactors +] + + +def read_mrc(filename, filetype="image"): + + fd = open(filename, "rb") + header = np.fromfile(fd, dtype=rec_header_dtd, count=1) + + nx, ny, nz = header["nx"][0], header["ny"][0], header["nz"][0] + + if header[0][3] == 1: + data_type = "int16" + elif header[0][3] == 2: + data_type = "float32" + elif header[0][3] == 4: + data_type = "single" + nx = nx * 2 + elif header[0][3] == 6: + data_type = "uint16" + + data = np.ndarray(shape=(nx, ny, nz)) + imgrawdata = np.fromfile(fd, data_type) + fd.close() + + if filetype == "image": + for iz in range(nz): + data_2d = imgrawdata[nx * ny * iz : nx * ny * (iz + 1)] + data[:, :, iz] = data_2d.reshape(nx, ny, order="F") + else: + data = imgrawdata + + return header, data diff --git a/src/careamics/lvae_training/eval_utils.py b/src/careamics/lvae_training/eval_utils.py new file mode 100644 index 000000000..6a5bdd7b3 --- /dev/null +++ b/src/careamics/lvae_training/eval_utils.py @@ -0,0 +1,905 @@ +""" +This script provides methods to evaluate the performance of the LVAE model. +It includes functions to: + - make predictions, + - quantify the performance of the model + - create plots to visualize the results. +""" + +import math +import os +from typing import Dict, List, Literal, Union + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.gridspec import GridSpec +from torch.utils.data import DataLoader +from tqdm import tqdm + +from careamics.models.lvae.utils import ModelType + +from .metrics import RangeInvariantPsnr, RunningPSNR + + +# ------------------------------------------------------------------------------------------------ +# Function of plotting: TODO -> moved them to another file, plot_utils.py +def clean_ax(ax): + """ + Helper function to remove ticks from axes in plots. + """ + # 2D or 1D axes are of type np.ndarray + if isinstance(ax, np.ndarray): + for one_ax in ax: + clean_ax(one_ax) + return + + ax.set_yticklabels([]) + ax.set_xticklabels([]) + ax.tick_params(left=False, right=False, top=False, bottom=False) + + +def get_plots_output_dir( + saveplotsdir: str, patch_size: int, mmse_count: int = 50 +) -> str: + """ + Given the path to a root directory to save plots, patch size, and mmse count, + it returns the specific directory to save the plots. + """ + plotsrootdir = os.path.join( + saveplotsdir, f"plots/patch_{patch_size}_mmse_{mmse_count}" + ) + os.makedirs(plotsrootdir, exist_ok=True) + print(plotsrootdir) + return plotsrootdir + + +def get_psnr_str(tar_hsnr, pred, col_idx): + """ + Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`). + """ + return ( + f"{RangeInvariantPsnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}" + ) + + +def add_psnr_str(ax_, psnr): + """ + Add psnr string to the axes + """ + textstr = f"PSNR\n{psnr}" + props = dict(boxstyle="round", facecolor="gray", alpha=0.5) + # place a text box in upper left in axes coords + ax_.text( + 0.05, + 0.95, + textstr, + transform=ax_.transAxes, + fontsize=11, + verticalalignment="top", + bbox=props, + color="white", + ) + + +def get_last_index(bin_count, quantile): + cumsum = np.cumsum(bin_count) + normalized_cumsum = cumsum / cumsum[-1] + for i in range(1, len(normalized_cumsum)): + if normalized_cumsum[-i] < quantile: + return i - 1 + return None + + +def get_first_index(bin_count, quantile): + cumsum = np.cumsum(bin_count) + normalized_cumsum = cumsum / cumsum[-1] + for i in range(len(normalized_cumsum)): + if normalized_cumsum[i] > quantile: + return i + return None + + +def show_for_one( + idx, + val_dset, + highsnr_val_dset, + model, + calibration_stats, + mmse_count=5, + patch_size=256, + num_samples=2, + baseline_preds=None, +): + """ + Given an index, it plots the input, target, reconstructed images and the difference image. + Note the the difference image is computed with respect to a ground truth image, obtained from the high SNR dataset. + """ + highsnr_val_dset.set_img_sz(patch_size, 64) + highsnr_val_dset.disable_noise() + _, tar_hsnr = highsnr_val_dset[idx] + inp, tar, recon_img_list = get_predictions( + idx, val_dset, model, mmse_count=mmse_count, patch_size=patch_size + ) + plot_crops( + inp, + tar, + tar_hsnr, + recon_img_list, + calibration_stats, + num_samples=num_samples, + baseline_preds=baseline_preds, + ) + + +def plot_crops( + inp, + tar, + tar_hsnr, + recon_img_list, + calibration_stats, + num_samples=2, + baseline_preds=None, +): + """ """ + if baseline_preds is None: + baseline_preds = [] + if len(baseline_preds) > 0: + for i in range(len(baseline_preds)): + if baseline_preds[i].shape != tar_hsnr.shape: + print( + f"Baseline prediction {i} shape {baseline_preds[i].shape} does not match target shape {tar_hsnr.shape}" + ) + print("This happens when we want to predict the edges of the image.") + return + + # color_ch_list = ['goldenrod', 'cyan'] + # color_pred = 'red' + # insetplot_xmax_value = 10000 + # insetplot_xmin_value = -1000 + # inset_min_labelsize = 10 + # inset_rect = [0.05, 0.05, 0.4, 0.2] + + # Set plot attributes + img_sz = 3 + ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1) + grid_factor = 5 + grid_img_sz = img_sz * grid_factor + example_spacing = 1 + c0_extra = 1 + nimgs = 1 + fig_w = ncols * img_sz + 2 * c0_extra / grid_factor + fig_h = int(img_sz * ncols + (example_spacing * (nimgs - 1)) / grid_factor) + fig = plt.figure(figsize=(fig_w, fig_h)) + gs = GridSpec( + nrows=int(grid_factor * fig_h), + ncols=int(grid_factor * fig_w), + hspace=0.2, + wspace=0.2, + ) + params = {"mathtext.default": "regular"} + plt.rcParams.update(params) + + # plot baselines + for i in range(2, 2 + len(baseline_preds)): + for col_idx in range(baseline_preds[0].shape[0]): + ax_temp = fig.add_subplot( + gs[ + col_idx * grid_img_sz : grid_img_sz * (col_idx + 1), + i * grid_img_sz + c0_extra : (i + 1) * grid_img_sz + c0_extra, + ] + ) + print(tar_hsnr.shape, baseline_preds[i - 2].shape) + psnr = get_psnr_str(tar_hsnr, baseline_preds[i - 2], col_idx) + ax_temp.imshow(baseline_preds[i - 2][col_idx], cmap="magma") + add_psnr_str(ax_temp, psnr) + clean_ax(ax_temp) + + # plot samples + sample_start_idx = 2 + len(baseline_preds) + for i in range(sample_start_idx, ncols - 3): + for col_idx in range(recon_img_list.shape[1]): + ax_temp = fig.add_subplot( + gs[ + col_idx * grid_img_sz : grid_img_sz * (col_idx + 1), + i * grid_img_sz + c0_extra : (i + 1) * grid_img_sz + c0_extra, + ] + ) + psnr = get_psnr_str(tar_hsnr, recon_img_list[i - sample_start_idx], col_idx) + ax_temp.imshow(recon_img_list[i - sample_start_idx][col_idx], cmap="magma") + add_psnr_str(ax_temp, psnr) + clean_ax(ax_temp) + # inset_ax = add_pixel_kde(ax_temp, + # inset_rect, + # [tar_hsnr[col_idx], + # recon_img_list[i - sample_start_idx][col_idx]], + # inset_min_labelsize, + # label_list=['', ''], + # color_list=[color_ch_list[col_idx], color_pred], + # plot_xmax_value=insetplot_xmax_value, + # plot_xmin_value=insetplot_xmin_value) + + # inset_ax.set_xticks([]) + # inset_ax.set_yticks([]) + + # difference image + if num_samples > 1: + for col_idx in range(recon_img_list.shape[1]): + ax_temp = fig.add_subplot( + gs[ + col_idx * grid_img_sz : grid_img_sz * (col_idx + 1), + (ncols - 3) * grid_img_sz + + c0_extra : (ncols - 2) * grid_img_sz + + c0_extra, + ] + ) + ax_temp.imshow( + recon_img_list[1][col_idx] - recon_img_list[0][col_idx], cmap="coolwarm" + ) + clean_ax(ax_temp) + + for col_idx in range(recon_img_list.shape[1]): + # print(recon_img_list.shape) + ax_temp = fig.add_subplot( + gs[ + col_idx * grid_img_sz : grid_img_sz * (col_idx + 1), + c0_extra + + (ncols - 2) * grid_img_sz : (ncols - 1) * grid_img_sz + + c0_extra, + ] + ) + psnr = get_psnr_str(tar_hsnr, recon_img_list.mean(axis=0), col_idx) + ax_temp.imshow(recon_img_list.mean(axis=0)[col_idx], cmap="magma") + add_psnr_str(ax_temp, psnr) + # inset_ax = add_pixel_kde(ax_temp, + # inset_rect, + # [tar_hsnr[col_idx], + # recon_img_list.mean(axis=0)[col_idx]], + # inset_min_labelsize, + # label_list=['', ''], + # color_list=[color_ch_list[col_idx], color_pred], + # plot_xmax_value=insetplot_xmax_value, + # plot_xmin_value=insetplot_xmin_value) + # inset_ax.set_xticks([]) + # inset_ax.set_yticks([]) + + clean_ax(ax_temp) + + ax_temp = fig.add_subplot( + gs[ + col_idx * grid_img_sz : grid_img_sz * (col_idx + 1), + (ncols - 1) * grid_img_sz + + 2 * c0_extra : (ncols) * grid_img_sz + + 2 * c0_extra, + ] + ) + ax_temp.imshow(tar_hsnr[col_idx], cmap="magma") + if col_idx == 0: + legend_ch1_ax = ax_temp + if col_idx == 1: + legend_ch2_ax = ax_temp + + # inset_ax = add_pixel_kde(ax_temp, + # inset_rect, + # [tar_hsnr[col_idx], + # ], + # inset_min_labelsize, + # label_list=[''], + # color_list=[color_ch_list[col_idx]], + # plot_xmax_value=insetplot_xmax_value, + # plot_xmin_value=insetplot_xmin_value) + # inset_ax.set_xticks([]) + # inset_ax.set_yticks([]) + + clean_ax(ax_temp) + + ax_temp = fig.add_subplot( + gs[ + col_idx * grid_img_sz : grid_img_sz * (col_idx + 1), + grid_img_sz : 2 * grid_img_sz, + ] + ) + ax_temp.imshow(tar[0, col_idx].cpu().numpy(), cmap="magma") + # inset_ax = add_pixel_kde(ax_temp, + # inset_rect, + # [tar[0,col_idx].cpu().numpy(), + # ], + # inset_min_labelsize, + # label_list=[''], + # color_list=[color_ch_list[col_idx]], + # plot_kwargs_list=[{'linestyle':'--'}], + # plot_xmax_value=insetplot_xmax_value, + # plot_xmin_value=insetplot_xmin_value) + + # inset_ax.set_xticks([]) + # inset_ax.set_yticks([]) + + clean_ax(ax_temp) + + ax_temp = fig.add_subplot(gs[0:grid_img_sz, 0:grid_img_sz]) + ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma") + clean_ax(ax_temp) + + # line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-', label='$C_1$') + # line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-', label='$C_2$') + # line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-', label='Pred') + # line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='--', label='$C^N_1$') + # line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='--', label='$C^N_2$') + # legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred], loc='upper right', frameon=False, labelcolor='white', + # prop={'size': 11}) + # legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred], loc='upper right', frameon=False, labelcolor='white', + # prop={'size': 11}) + + if calibration_stats is not None: + smaller_offset = 4 + ax_temp = fig.add_subplot( + gs[ + grid_img_sz + 1 : 2 * grid_img_sz - smaller_offset + 1, + smaller_offset - 1 : grid_img_sz - 1, + ] + ) + plot_calibration(ax_temp, calibration_stats) + + +def plot_calibration(ax, calibration_stats): + """ + To plot calibration statistics (RMV vs RMSE). + """ + first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001) + last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999) + ax.plot( + calibration_stats[0]["rmv"][first_idx:-last_idx], + calibration_stats[0]["rmse"][first_idx:-last_idx], + "o", + label=r"$\hat{C}_0$", + ) + + first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001) + last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999) + ax.plot( + calibration_stats[1]["rmv"][first_idx:-last_idx], + calibration_stats[1]["rmse"][first_idx:-last_idx], + "o", + label=r"$\hat{C}_1$", + ) + + ax.set_xlabel("RMV") + ax.set_ylabel("RMSE") + ax.legend() + + +def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"): + """ + Adapted from https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib + + Function to offset the "center" of a colormap. Useful for + data with a negative min and positive max and you want the + middle of the colormap's dynamic range to be at zero. + + Input + ----- + cmap : The matplotlib colormap to be altered + start : Offset from lowest point in the colormap's range. + Defaults to 0.0 (no lower offset). Should be between + 0.0 and `midpoint`. + midpoint : The new center of the colormap. Defaults to + 0.5 (no shift). Should be between 0.0 and 1.0. In + general, this should be 1 - vmax / (vmax + abs(vmin)) + For example if your data range from -15.0 to +5.0 and + you want the center of the colormap at 0.0, `midpoint` + should be set to 1 - 5/(5 + 15)) or 0.75 + stop : Offset from highest point in the colormap's range. + Defaults to 1.0 (no upper offset). Should be between + `midpoint` and 1.0. + """ + cdict = {"red": [], "green": [], "blue": [], "alpha": []} + + # regular index to compute the colors + reg_index = np.linspace(start, stop, 257) + mid_idx = len(reg_index) // 2 + # shifted index to match the data + shift_index = np.hstack( + [ + np.linspace(0.0, midpoint, 128, endpoint=False), + np.linspace(midpoint, 1.0, 129, endpoint=True), + ] + ) + + for ri, si in zip(reg_index, shift_index): + r, g, b, a = cmap(ri) + a = np.abs(ri - reg_index[mid_idx]) / reg_index[mid_idx] + # print(a) + cdict["red"].append((si, r, r)) + cdict["green"].append((si, g, g)) + cdict["blue"].append((si, b, b)) + cdict["alpha"].append((si, a, a)) + + newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict) + matplotlib.colormaps.register(cmap=newcmap, force=True) + + return newcmap + + +def get_fractional_change(target, prediction, max_val=None): + """ + Get relative difference between target and prediction. + """ + if max_val is None: + max_val = target.max() + return (target - prediction) / max_val + + +def get_zero_centered_midval(error): + """ + When done this way, the midval ensures that the colorbar is centered at 0. (Don't know how, but it works ;)) + """ + vmax = error.max() + vmin = error.min() + midval = 1 - vmax / (vmax + abs(vmin)) + return midval + + +def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val=None): + """ + Plot the relative difference between target and prediction. + NOTE: The plot is overlapped to the prediction image (in gray scale). + NOTE: The colorbar is centered at 0. + """ + if ax is None: + _, ax = plt.subplots(figsize=(6, 6)) + + # Relative difference between target and prediction + rel_diff = get_fractional_change(target, prediction, max_val=max_val) + midval = get_zero_centered_midval(rel_diff) + shifted_cmap = shiftedColorMap( + cmap, start=0, midpoint=midval, stop=1.0, name="shiftedcmap" + ) + ax.imshow(prediction, cmap="gray") + img_err = ax.imshow(rel_diff, cmap=shifted_cmap, alpha=1) + plt.colorbar(img_err, ax=ax) + + +# ------------------------------------------------------------------------------------------------ + + +def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256): + """ + Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index. + """ + print(f"Predicting for {idx}") + val_dset.set_img_sz(patch_size, 64) + + with torch.no_grad(): + # val_dset.enable_noise() + inp, tar = val_dset[idx] + # val_dset.disable_noise() + + inp = torch.Tensor(inp[None]) + tar = torch.Tensor(tar[None]) + inp = inp.cuda() + x_normalized = model.normalize_input(inp) + tar = tar.cuda() + tar_normalized = model.normalize_target(tar) + + recon_img_list = [] + for _ in range(mmse_count): + recon_normalized, td_data = model(x_normalized) + rec_loss, imgs = model.get_reconstruction_loss( + recon_normalized, + x_normalized, + tar_normalized, + return_predicted_img=True, + ) + imgs = model.unnormalize_target(imgs) + recon_img_list.append(imgs.cpu().numpy()[0]) + + recon_img_list = np.array(recon_img_list) + return inp, tar, recon_img_list + + +def get_dset_predictions( + model, + dset, + batch_size: int, + model_type: ModelType = None, + mmse_count: int = 1, + num_workers: int = 4, +): + """ + Get predictions from a model for the entire dataset. + + Parameters + ---------- + mmse_count : int + Number of samples to generate for each input and then to average over for MMSE estimation. + """ + dloader = DataLoader( + dset, + pin_memory=False, + num_workers=num_workers, + shuffle=False, + batch_size=batch_size, + ) + likelihood = model.model.likelihood + predictions = [] + predictions_std = [] + losses = [] + logvar_arr = [] + patch_psnr_channels = [RunningPSNR() for _ in range(dset[0][1].shape[0])] + with torch.no_grad(): + for batch in tqdm(dloader): + inp, tar = batch[:2] + inp = inp.cuda() + tar = tar.cuda() + + recon_img_list = [] + for mmse_idx in range(mmse_count): + if model_type == ModelType.Denoiser: + assert model.denoise_channel in [ + "Ch1", + "Ch2", + "input", + ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"' + + x_normalized_new, tar_new = model.get_new_input_target( + (inp, tar, *batch[2:]) + ) + tar_normalized = model.normalize_target(tar_new) + recon_normalized, _ = model(x_normalized_new) + rec_loss, imgs = model.get_reconstruction_loss( + recon_normalized, + tar_normalized, + x_normalized_new, + return_predicted_img=True, + ) + else: + x_normalized = model.normalize_input(inp) + tar_normalized = model.normalize_target(tar) + recon_normalized, _ = model(x_normalized) + rec_loss, imgs = model.get_reconstruction_loss( + recon_normalized, tar_normalized, inp, return_predicted_img=True + ) + + if mmse_idx == 0: + q_dic = ( + likelihood.distr_params(recon_normalized) + if likelihood is not None + else {"logvar": None} + ) + if q_dic["logvar"] is not None: + logvar_arr.append(q_dic["logvar"].cpu().numpy()) + else: + logvar_arr.append(np.array([-1])) + + try: + losses.append(rec_loss["loss"].cpu().numpy()) + except: + losses.append(rec_loss["loss"]) + + for i in range(imgs.shape[1]): + patch_psnr_channels[i].update(imgs[:, i], tar_normalized[:, i]) + + recon_img_list.append(imgs.cpu()[None]) + + samples = torch.cat(recon_img_list, dim=0) + mmse_imgs = torch.mean(samples, dim=0) + mmse_std = torch.std(samples, dim=0) + predictions.append(mmse_imgs.cpu().numpy()) + predictions_std.append(mmse_std.cpu().numpy()) + + psnr = [x.get() for x in patch_psnr_channels] + return ( + np.concatenate(predictions, axis=0), + np.array(losses), + np.concatenate(logvar_arr), + psnr, + np.concatenate(predictions_std, axis=0), + ) + + +# ------------------------------------------------------------------------------------------ +### Classes and Functions used to stitch predictions +class PatchLocation: + """ + Encapsulates t_idx and spatial location. + """ + + def __init__(self, h_idx_range, w_idx_range, t_idx): + self.t = t_idx + self.h_start, self.h_end = h_idx_range + self.w_start, self.w_end = w_idx_range + + def __str__(self): + msg = f"T:{self.t} [{self.h_start}-{self.h_end}) [{self.w_start}-{self.w_end}) " + return msg + + +def _get_location(extra_padding, hwt, pred_h, pred_w): + h_start, w_start, t_idx = hwt + h_start -= extra_padding + h_end = h_start + pred_h + w_start -= extra_padding + w_end = w_start + pred_w + return PatchLocation((h_start, h_end), (w_start, w_end), t_idx) + + +def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w): + """ + For a given idx of the dataset, it returns where exactly in the dataset, does this prediction lies. + Note that this prediction also has padded pixels and so a subset of it will be used in the final prediction. + Which time frame, which spatial location (h_start, h_end, w_start,w_end) + Args: + dset: + dset_input_idx: + pred_h: + pred_w: + + Returns + ------- + """ + extra_padding = dset.per_side_overlap_pixelcount() + htw = dset.get_idx_manager().hwt_from_idx( + dset_input_idx, grid_size=dset.get_grid_size() + ) + return _get_location(extra_padding, htw, pred_h, pred_w) + + +def remove_pad(pred, loc, extra_padding, smoothening_pixelcount, frame_shape): + assert smoothening_pixelcount == 0 + if extra_padding - smoothening_pixelcount > 0: + h_s = extra_padding - smoothening_pixelcount + + # rows + h_N = frame_shape[0] + if loc.h_end > h_N: + assert loc.h_end - extra_padding + smoothening_pixelcount <= h_N + h_e = extra_padding - smoothening_pixelcount + + w_s = extra_padding - smoothening_pixelcount + + # columns + w_N = frame_shape[1] + if loc.w_end > w_N: + assert loc.w_end - extra_padding + smoothening_pixelcount <= w_N + + w_e = extra_padding - smoothening_pixelcount + + return pred[h_s:-h_e, w_s:-w_e] + + return pred + + +def update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount): + extra_padding = extra_padding - smoothening_pixelcount + loc.h_start += extra_padding + loc.w_start += extra_padding + loc.h_end -= extra_padding + loc.w_end -= extra_padding + return loc + + +def stitch_predictions(predictions, dset, smoothening_pixelcount=0): + """ + Args: + smoothening_pixelcount: number of pixels which can be interpolated + """ + assert smoothening_pixelcount >= 0 and isinstance(smoothening_pixelcount, int) + extra_padding = dset.per_side_overlap_pixelcount() + # if there are more channels, use all of them. + shape = list(dset.get_data_shape()) + shape[-1] = max(shape[-1], predictions.shape[1]) + + output = np.zeros(shape, dtype=predictions.dtype) + frame_shape = dset.get_data_shape()[1:3] + for dset_input_idx in range(predictions.shape[0]): + loc = get_location_from_idx( + dset, dset_input_idx, predictions.shape[-2], predictions.shape[-1] + ) + + mask = None + cropped_pred_list = [] + for ch_idx in range(predictions.shape[1]): + # class i + cropped_pred_i = remove_pad( + predictions[dset_input_idx, ch_idx], + loc, + extra_padding, + smoothening_pixelcount, + frame_shape, + ) + + if mask is None: + # NOTE: don't need to compute it for every patch. + assert ( + smoothening_pixelcount == 0 + ), "For smoothing,enable the get_smoothing_mask. It is disabled since I don't use it and it needs modification to work with non-square images" + mask = 1 + # mask = _get_smoothing_mask(cropped_pred_i.shape, smoothening_pixelcount, loc, frame_size) + + cropped_pred_list.append(cropped_pred_i) + + loc = update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount) + for ch_idx in range(predictions.shape[1]): + output[loc.t, loc.h_start : loc.h_end, loc.w_start : loc.w_end, ch_idx] += ( + cropped_pred_list[ch_idx] * mask + ) + + return output + + +# ------------------------------------------------------------------------------------------ + + +# ------------------------------------------------------------------------------------------ +### Classes and Functions used for Calibration +class Calibration: + + def __init__( + self, num_bins: int = 15, mode: Literal["pixelwise", "patchwise"] = "pixelwise" + ): + self._bins = num_bins + self._bin_boundaries = None + self._mode = mode + assert mode in ["pixelwise", "patchwise"] + self._boundary_mode = "uniform" + assert self._boundary_mode in ["quantile", "uniform"] + # self._bin_boundaries = {} + + def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray: + return np.exp(logvar / 2) + + def compute_bin_boundaries(self, predict_logvar: np.ndarray) -> np.ndarray: + """ + Compute the bin boundaries for `num_bins` bins and the given logvar values. + """ + if self._boundary_mode == "quantile": + boundaries = np.quantile( + self.logvar_to_std(predict_logvar), np.linspace(0, 1, self._bins + 1) + ) + return boundaries + else: + min_logvar = np.min(predict_logvar) + max_logvar = np.max(predict_logvar) + min_std = self.logvar_to_std(min_logvar) + max_std = self.logvar_to_std(max_logvar) + return np.linspace(min_std, max_std, self._bins + 1) + + def compute_stats( + self, pred: np.ndarray, pred_logvar: np.ndarray, target: np.ndarray + ) -> Dict[int, Dict[str, Union[np.ndarray, List]]]: + """ + It computes the bin-wise RMSE and RMV for each channel of the predicted image. + + Recall that: + - RMSE = np.sqrt((pred - target)**2 / num_pixels) + - RMV = np.sqrt(np.mean(pred_std**2)) + + ALGORITHM + - For each channel: + - Given the bin boundaries, assign pixels of `std_ch` array to a specific bin index. + - For each bin index: + - Compute the RMSE, RMV, and number of pixels for that bin. + + NOTE: each channel of the predicted image/logvar has its own stats. + + Args: + pred: np.ndarray, shape (n, h, w, c) + pred_logvar: np.ndarray, shape (n, h, w, c) + target: np.ndarray, shape (n, h, w, c) + """ + self._bin_boundaries = {} + stats = {} + for ch_idx in range(pred.shape[-1]): + stats[ch_idx] = { + "bin_count": [], + "rmv": [], + "rmse": [], + "bin_boundaries": None, + "bin_matrix": [], + } + pred_ch = pred[..., ch_idx] + logvar_ch = pred_logvar[..., ch_idx] + std_ch = self.logvar_to_std(logvar_ch) + target_ch = target[..., ch_idx] + if self._mode == "pixelwise": + boundaries = self.compute_bin_boundaries(logvar_ch) + stats[ch_idx]["bin_boundaries"] = boundaries + bin_matrix = np.digitize(std_ch.reshape(-1), boundaries) + bin_matrix = bin_matrix.reshape(std_ch.shape) + stats[ch_idx]["bin_matrix"] = bin_matrix + error = (pred_ch - target_ch) ** 2 + for bin_idx in range(self._bins): + bin_mask = bin_matrix == bin_idx + bin_error = error[bin_mask] + bin_size = np.sum(bin_mask) + bin_error = ( + np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None + ) # RMSE + bin_var = np.sqrt(np.mean(std_ch[bin_mask] ** 2)) # RMV + stats[ch_idx]["rmse"].append(bin_error) + stats[ch_idx]["rmv"].append(bin_var) + stats[ch_idx]["bin_count"].append(bin_size) + else: + raise NotImplementedError("Patchwise mode is not implemented yet.") + return stats + + +def nll(x, mean, logvar): + """ + Log of the probability density of the values x under the Normal + distribution with parameters mean and logvar. + + :param x: tensor of points, with shape (batch, channels, dim1, dim2) + :param mean: tensor with mean of distribution, shape + (batch, channels, dim1, dim2) + :param logvar: tensor with log-variance of distribution, shape has to be + either scalar or broadcastable + """ + var = torch.exp(logvar) + log_prob = -0.5 * ( + ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log() + ) + nll = -log_prob + return nll + + +def get_calibrated_factor_for_stdev( + pred: Union[np.ndarray, torch.Tensor], + pred_logvar: Union[np.ndarray, torch.Tensor], + target: Union[np.ndarray, torch.Tensor], + batch_size: int = 32, + epochs: int = 500, + lr: float = 0.01, +): + """ + Here, we calibrate the uncertainty by multiplying the predicted std (mmse estimate or predicted logvar) with a scalar. + We return the calibrated scalar. This needs to be multiplied with the std. + + NOTE: Why is the input logvar and not std? because the model typically predicts logvar and not std. + """ + # create a learnable scalar + scalar = torch.nn.Parameter(torch.tensor(2.0)) + optimizer = torch.optim.Adam([scalar], lr=lr) + + bar = tqdm(range(epochs)) + for _ in bar: + optimizer.zero_grad() + # Select a random batch of predictions + mask = np.random.randint(0, pred.shape[0], batch_size) + pred_batch = torch.Tensor(pred[mask]).cuda() + pred_logvar_batch = torch.Tensor(pred_logvar[mask]).cuda() + target_batch = torch.Tensor(target[mask]).cuda() + + loss = torch.mean( + nll(target_batch, pred_batch, pred_logvar_batch + torch.log(scalar)) + ) + loss.backward() + optimizer.step() + bar.set_description(f"nll: {loss.item()} scalar: {scalar.item()}") + + return np.sqrt(scalar.item()) + + +def plot_calibration(ax, calibration_stats): + first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001) + last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999) + ax.plot( + calibration_stats[0]["rmv"][first_idx:-last_idx], + calibration_stats[0]["rmse"][first_idx:-last_idx], + "o", + label=r"$\hat{C}_0$: Ch1", + ) + + first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001) + last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999) + ax.plot( + calibration_stats[1]["rmv"][first_idx:-last_idx], + calibration_stats[1]["rmse"][first_idx:-last_idx], + "o", + label=r"$\hat{C}_1: : Ch2$", + ) + + ax.set_xlabel("RMV") + ax.set_ylabel("RMSE") + ax.legend() diff --git a/src/careamics/lvae_training/get_config.py b/src/careamics/lvae_training/get_config.py new file mode 100644 index 000000000..c2353047c --- /dev/null +++ b/src/careamics/lvae_training/get_config.py @@ -0,0 +1,84 @@ +""" +Here there are functions to define a config file. +""" + +import os + +import ml_collections + +from careamics.lvae_training.data_utils import DataType +from careamics.models.lvae.utils import LossType + + +def _init_config(): + """ + Create a default config object with all the required fields. + """ + config = ml_collections.ConfigDict() + + config.data = ml_collections.ConfigDict() + + config.model = ml_collections.ConfigDict() + + config.loss = ml_collections.ConfigDict() + + config.training = ml_collections.ConfigDict() + + config.workdir = os.getcwd() + config.datadir = "" + return config + + +def get_config(): + config = _init_config() + + data = config.data + data.image_size = 128 # the patch size + # data.grid_size = 32 # the retained sub-patch when doing inner tiling + data.multiscale_lowres_count = ( + None # todo: this one will be an issue in current careamics + ) + data.num_channels = 2 # in careamics probably in lvae pydantic model + + model = config.model # all in lvae pydantic model + model.z_dims = [128, 128, 128, 128] + model.n_filters = 64 + model.dropout = 0.1 + model.nonlin = "elu" + model.enable_noise_model = True + model.analytical_kl = False + model.predict_logvar = None + + loss = config.loss # in algorithm config + loss.loss_type = LossType.Elbo # LossType.Elbo or LossType.DenoiSplitMuSplit + loss.kl_loss_formulation = "" # '', 'usplit', 'denoisplit' + + training = config.training + training.lr = 0.001 # in algorithm config + training.lr_scheduler_patience = 30 + training.batch_size = 32 # in data config + training.earlystop_patience = ( + 200 # in training config in the callbacks (early stopping) + ) + training.max_epochs = 400 # training config + training.pre_trained_ckpt_fpath = "" # this is through the careamics API + + # Set of attributes not to include in the PyDantic data model + training.num_workers = ( + 4 # this is in the data config, passed in the dataloader parameters + ) + training.grad_clip_norm_value = 0.5 # Taken from https://github.com/openai/vdvae/blob/main/hps.py#L38 # this maybe should be in a new trainer_parameters dict in the training config pydantic model + training.gradient_clip_algorithm = "value" + training.precision = 32 + data.data_type = DataType.BioSR_MRC + data.ch1_fname = "ER/GT_all.mrc" + data.ch2_fname = "Microtubules/GT_all.mrc" + model.noise_model_ch1_fpath = "/group/jug/ashesh/training_pre_eccv/noise_model/2402/429/GMMNoiseModel_ER-GT_all__6_4_Clip0.0-1.0_Sig0.125_UpNone_Norm0_bootstrap.npz" + model.noise_model_ch2_fpath = "/group/jug/ashesh/training_pre_eccv/noise_model/2402/434/GMMNoiseModel_Microtubules-GT_all__6_4_Clip0.0-1.0_Sig0.125_UpNone_Norm0_bootstrap.npz" + # Parameters to apply synthetic noise to data (e.g., used with BioSR data for denoiSplit) + data.poisson_noise_factor = 1000 + data.enable_gaussian_noise = True + data.synthetic_gaussian_scale = 4450 + data.input_has_dependant_noise = True + + return config diff --git a/src/careamics/lvae_training/lightning_module.py b/src/careamics/lvae_training/lightning_module.py new file mode 100644 index 000000000..6ead5a83c --- /dev/null +++ b/src/careamics/lvae_training/lightning_module.py @@ -0,0 +1,701 @@ +""" +Lightning Module for LadderVAE. +""" + +from typing import Any, Dict + +import ml_collections +import numpy as np +import pytorch_lightning as L +import torch +import torchvision.transforms.functional as F + +from careamics.models.lvae.likelihoods import LikelihoodModule +from careamics.models.lvae.lvae import LadderVAE +from careamics.models.lvae.utils import ( + LossType, + compute_batch_mean, + free_bits_kl, + torch_nanmean, +) + +from .metrics import RangeInvariantPsnr, RunningPSNR +from .train_utils import MetricMonitor + + +class LadderVAELight(L.LightningModule): + + def __init__( + self, + config: ml_collections.ConfigDict, + data_mean: Dict[str, torch.Tensor], + data_std: Dict[str, torch.Tensor], + target_ch: int, + ): + """ + Here we will do the following: + - initialize the model (from LadderVAE class) + - initialize the parameters related to the training and loss. + + NOTE: + Some of the model attributes are defined in the model object itself, while some others will be defined here. + Note that all the attributes related to the training and loss that were already defined in the model object + are redefined here as Lightning module attributes (e.g., self.some_attr = model.some_attr). + The attributes related to the model itself are treated as model attributes (e.g., self.model.some_attr). + + NOTE: HC stands for Hard Coded attribute. + """ + super().__init__() + + self.data_mean = data_mean + self.data_std = data_std + self.target_ch = target_ch + + # Initialize LVAE model + self.model = LadderVAE( + data_mean=data_mean, data_std=data_std, config=config, target_ch=target_ch + ) + + ##### Define attributes from config ##### + self.workdir = config.workdir + self._input_is_sum = False + self.kl_loss_formulation = config.loss.kl_loss_formulation + assert self.kl_loss_formulation in [ + None, + "", + "usplit", + "denoisplit", + "denoisplit_usplit", + ], f""" + Invalid kl_loss_formulation. {self.kl_loss_formulation}""" + + ##### Define loss attributes ##### + # Parameters already defined in the model object + self.loss_type = self.model.loss_type + self._denoisplit_w = self._usplit_w = None + if self.loss_type == LossType.DenoiSplitMuSplit: + self._usplit_w = 0 + self._denoisplit_w = 1 - self._usplit_w + assert self._denoisplit_w + self._usplit_w == 1 + self._restricted_kl = self.model._restricted_kl + + # General loss parameters + self.channel_1_w = 1 + self.channel_2_w = 1 + + # About Reconsruction Loss + self.reconstruction_mode = False + self.skip_nboundary_pixels_from_loss = None + self.reconstruction_weight = 1.0 + self._exclusion_loss_weight = 0 + self.ch1_recons_w = 1 + self.ch2_recons_w = 1 + self.enable_mixed_rec = False + self.mixed_rec_w_step = 0 + + # About KL Loss + self.kl_weight = 1.0 # HC + self.usplit_kl_weight = None # HC + self.free_bits = 1.0 # HC + self.kl_annealing = False # HC + self.kl_annealtime = self.kl_start = None + if self.kl_annealing: + self.kl_annealtime = 10 # HC + self.kl_start = -1 # HC + + ##### Define training attributes ##### + self.lr = config.training.lr + self.lr_scheduler_patience = config.training.lr_scheduler_patience + self.lr_scheduler_monitor = config.model.get("monitor", "val_loss") + self.lr_scheduler_mode = MetricMonitor(self.lr_scheduler_monitor).mode() + + # Initialize object for keeping track of PSNR for each output channel + self.channels_psnr = [RunningPSNR() for _ in range(self.model.target_ch)] + + def forward(self, x: Any) -> Any: + return self.model(x) + + def training_step( + self, batch: torch.Tensor, batch_idx: int, enable_logging: bool = True + ) -> Dict[str, torch.Tensor]: + + if self.current_epoch == 0 and batch_idx == 0: + self.log("val_psnr", 1.0, on_epoch=True) + + # Pre-processing of inputs + x, target = batch[:2] + self.set_params_to_same_device_as(x) + x_normalized = self.normalize_input(x) + if self.reconstruction_mode: # just for experimental purpose + target_normalized = x_normalized[:, :1].repeat(1, 2, 1, 1) + target = None + mask = None + else: + target_normalized = self.normalize_target(target) + mask = ~((target == 0).reshape(len(target), -1).all(dim=1)) + + # Forward pass + out, td_data = self.forward(x_normalized) + + if ( + self.model.encoder_no_padding_mode + and out.shape[-2:] != target_normalized.shape[-2:] + ): + target_normalized = F.center_crop(target_normalized, out.shape[-2:]) + + # Loss Computations + # mask = torch.isnan(target.reshape(len(x), -1)).all(dim=1) + recons_loss_dict, imgs = self.get_reconstruction_loss( + reconstruction=out, + target=target_normalized, + input=x_normalized, + splitting_mask=mask, + return_predicted_img=True, + ) + + # This `if` is not used by default config + if self.skip_nboundary_pixels_from_loss: + pad = self.skip_nboundary_pixels_from_loss + target_normalized = target_normalized[:, :, pad:-pad, pad:-pad] + + recons_loss = recons_loss_dict["loss"] * self.reconstruction_weight + + if torch.isnan(recons_loss).any(): + recons_loss = 0.0 + + if self.model.non_stochastic_version: + kl_loss = torch.Tensor([0.0]).cuda() + net_loss = recons_loss + else: + if self.loss_type == LossType.DenoiSplitMuSplit: + msg = f"For the loss type {LossType.name(self.loss_type)}, kl_loss_formulation must be denoisplit_usplit" + assert self.kl_loss_formulation == "denoisplit_usplit", msg + assert self._denoisplit_w is not None and self._usplit_w is not None + + kl_key_denoisplit = "kl_restricted" if self._restricted_kl else "kl" + # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class. + # The different naming comes from `top_down_pass()` method in the LadderVAE class. + denoisplit_kl = self.get_kl_divergence_loss( + topdown_layer_data_dict=td_data, kl_key=kl_key_denoisplit + ) + usplit_kl = self.get_kl_divergence_loss_usplit( + topdown_layer_data_dict=td_data + ) + kl_loss = ( + self._denoisplit_w * denoisplit_kl + self._usplit_w * usplit_kl + ) + kl_loss = self.kl_weight * kl_loss + + recons_loss = self.reconstruction_loss_musplit_denoisplit( + out, target_normalized + ) + # recons_loss = self._denoisplit_w * recons_loss_nm + self._usplit_w * recons_loss_gm + + elif self.kl_loss_formulation == "usplit": + kl_loss = self.get_kl_weight() * self.get_kl_divergence_loss_usplit( + td_data + ) + elif self.kl_loss_formulation in ["", "denoisplit"]: + kl_loss = self.get_kl_weight() * self.get_kl_divergence_loss(td_data) + net_loss = recons_loss + kl_loss + + # Logging + if enable_logging: + for i, x in enumerate(td_data["debug_qvar_max"]): + self.log(f"qvar_max:{i}", x.item(), on_epoch=True) + + self.log("reconstruction_loss", recons_loss_dict["loss"], on_epoch=True) + self.log("kl_loss", kl_loss, on_epoch=True) + self.log("training_loss", net_loss, on_epoch=True) + self.log("lr", self.lr, on_epoch=True) + if self.model._tethered_ch2_scalar is not None: + self.log( + "tethered_ch2_scalar", + self.model._tethered_ch2_scalar, + on_epoch=True, + ) + self.log( + "tethered_ch1_scalar", + self.model._tethered_ch1_scalar, + on_epoch=True, + ) + + # self.log('grad_norm_bottom_up', self.grad_norm_bottom_up, on_epoch=True) + # self.log('grad_norm_top_down', self.grad_norm_top_down, on_epoch=True) + + output = { + "loss": net_loss, + "reconstruction_loss": ( + recons_loss.detach() + if isinstance(recons_loss, torch.Tensor) + else recons_loss + ), + "kl_loss": kl_loss.detach(), + } + # https://github.com/openai/vdvae/blob/main/train.py#L26 + if torch.isnan(net_loss).any(): + return None + + return output + + def validation_step(self, batch: torch.Tensor, batch_idx: int): + # Pre-processing of inputs + x, target = batch[:2] + self.set_params_to_same_device_as(x) + x_normalized = self.normalize_input(x) + if self.reconstruction_mode: # only for experimental purpose + target_normalized = x_normalized[:, :1].repeat(1, 2, 1, 1) + target = None + mask = None + else: + target_normalized = self.normalize_target(target) + mask = ~((target == 0).reshape(len(target), -1).all(dim=1)) + + # Forward pass + out, _ = self.forward(x_normalized) + + if self.model.predict_logvar is not None: + out_mean, _ = out.chunk(2, dim=1) + else: + out_mean = out + + if ( + self.model.encoder_no_padding_mode + and out.shape[-2:] != target_normalized.shape[-2:] + ): + target_normalized = F.center_crop(target_normalized, out.shape[-2:]) + + if self.loss_type == LossType.DenoiSplitMuSplit: + recons_loss = self.reconstruction_loss_musplit_denoisplit( + out, target_normalized + ) + recons_loss_dict = {"loss": recons_loss} + recons_img = out_mean + else: + # Metrics computation + recons_loss_dict, recons_img = self.get_reconstruction_loss( + reconstruction=out_mean, + target=target_normalized, + input=x_normalized, + splitting_mask=mask, + return_predicted_img=True, + ) + + # This `if` is not used by default config + if self.skip_nboundary_pixels_from_loss: + pad = self.skip_nboundary_pixels_from_loss + target_normalized = target_normalized[:, :, pad:-pad, pad:-pad] + + channels_rinvpsnr = [] + for i in range(target_normalized.shape[1]): + self.channels_psnr[i].update(recons_img[:, i], target_normalized[:, i]) + psnr = RangeInvariantPsnr( + target_normalized[:, i].clone(), recons_img[:, i].clone() + ) + channels_rinvpsnr.append(psnr) + psnr = torch_nanmean(psnr).item() + self.log(f"val_psnr_l{i+1}", psnr, on_epoch=True) + + recons_loss = recons_loss_dict["loss"] + if torch.isnan(recons_loss).any(): + return + + self.log("val_loss", recons_loss, on_epoch=True) + # self.log('val_psnr', (val_psnr_l1 + val_psnr_l2) / 2, on_epoch=True) + + # if batch_idx == 0 and self.power_of_2(self.current_epoch): + # all_samples = [] + # for i in range(20): + # sample, _ = self(x_normalized[0:1, ...]) + # sample = self.likelihood.get_mean_lv(sample)[0] + # all_samples.append(sample[None]) + + # all_samples = torch.cat(all_samples, dim=0) + # all_samples = all_samples * self.data_std + self.data_mean + # all_samples = all_samples.cpu() + # img_mmse = torch.mean(all_samples, dim=0)[0] + # self.log_images_for_tensorboard(all_samples[:, 0, 0, ...], target[0, 0, ...], img_mmse[0], 'label1') + # self.log_images_for_tensorboard(all_samples[:, 0, 1, ...], target[0, 1, ...], img_mmse[1], 'label2') + + # return net_loss + + def on_validation_epoch_end(self): + psnr_arr = [] + for i in range(len(self.channels_psnr)): + psnr = self.channels_psnr[i].get() + if psnr is None: + psnr_arr = None + break + psnr_arr.append(psnr.cpu().numpy()) + self.channels_psnr[i].reset() + + if psnr_arr is not None: + psnr = np.mean(psnr_arr) + self.log("val_psnr", psnr, on_epoch=True) + else: + self.log("val_psnr", 0.0, on_epoch=True) + + if self.mixed_rec_w_step: + self.mixed_rec_w = max(self.mixed_rec_w - self.mixed_rec_w_step, 0.0) + self.log("mixed_rec_w", self.mixed_rec_w, on_epoch=True) + + def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any: + raise NotImplementedError("predict_step is not implemented") + + def configure_optimizers(self): + optimizer = torch.optim.Adamax(self.parameters(), lr=self.lr, weight_decay=0) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + self.lr_scheduler_mode, + patience=self.lr_scheduler_patience, + factor=0.5, + min_lr=1e-12, + verbose=True, + ) + + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": self.lr_scheduler_monitor, + } + + ##### REQUIRED Methods for Loss Computation ##### + def get_reconstruction_loss( + self, + reconstruction: torch.Tensor, + target: torch.Tensor, + input: torch.Tensor, + splitting_mask: torch.Tensor = None, + return_predicted_img: bool = False, + likelihood_obj: LikelihoodModule = None, + ) -> Dict[str, torch.Tensor]: + """ + Parameters + ---------- + reconstruction: torch.Tensor, + target: torch.Tensor + input: torch.Tensor + splitting_mask: torch.Tensor = None + A boolean tensor that indicates which items to keep for reconstruction loss computation. + If `None`, all the elements of the items are considered (i.e., the mask is all `True`). + return_predicted_img: bool = False + likelihood_obj: LikelihoodModule = None + """ + output = self._get_reconstruction_loss_vector( + reconstruction=reconstruction, + target=target, + input=input, + return_predicted_img=return_predicted_img, + likelihood_obj=likelihood_obj, + ) + loss_dict = output[0] if return_predicted_img else output + + if splitting_mask is None: + splitting_mask = torch.ones_like(loss_dict["loss"]).bool() + + # print(len(target) - (torch.isnan(loss_dict['loss'])).sum()) + + loss_dict["loss"] = loss_dict["loss"][splitting_mask].sum() / len( + reconstruction + ) + for i in range(1, 1 + target.shape[1]): + key = f"ch{i}_loss" + loss_dict[key] = loss_dict[key][splitting_mask].sum() / len(reconstruction) + + if "mixed_loss" in loss_dict: + loss_dict["mixed_loss"] = torch.mean(loss_dict["mixed_loss"]) + if return_predicted_img: + assert len(output) == 2 + return loss_dict, output[1] + else: + return loss_dict + + def _get_reconstruction_loss_vector( + self, + reconstruction: torch.Tensor, + target: torch.Tensor, + input: torch.Tensor, + return_predicted_img: bool = False, + likelihood_obj: LikelihoodModule = None, + ): + """ + Parameters + ---------- + return_predicted_img: bool + If set to `True`, the besides the loss, the reconstructed image is also returned. + Default is `False`. + """ + output = { + "loss": None, + "mixed_loss": None, + } + + for i in range(1, 1 + target.shape[1]): + output[f"ch{i}_loss"] = None + + if likelihood_obj is None: + likelihood_obj = self.model.likelihood + + # Log likelihood + ll, like_dict = likelihood_obj(reconstruction, target) + ll = self._get_weighted_likelihood(ll) + if ( + self.skip_nboundary_pixels_from_loss is not None + and self.skip_nboundary_pixels_from_loss > 0 + ): + pad = self.skip_nboundary_pixels_from_loss + ll = ll[:, :, pad:-pad, pad:-pad] + like_dict["params"]["mean"] = like_dict["params"]["mean"][ + :, :, pad:-pad, pad:-pad + ] + + # assert ll.shape[1] == 2, f"Change the code below to handle >2 channels first. ll.shape {ll.shape}" + output = {"loss": compute_batch_mean(-1 * ll)} + if ll.shape[1] > 1: + for i in range(1, 1 + target.shape[1]): + output[f"ch{i}_loss"] = compute_batch_mean(-ll[:, i - 1]) + else: + assert ll.shape[1] == 1 + output["ch1_loss"] = output["loss"] + output["ch2_loss"] = output["loss"] + + if ( + self.channel_1_w is not None + and self.channel_2_w is not None + and (self.channel_1_w != 1 or self.channel_2_w != 1) + ): + assert ll.shape[1] == 2, "Only 2 channels are supported for now." + output["loss"] = ( + self.channel_1_w * output["ch1_loss"] + + self.channel_2_w * output["ch2_loss"] + ) / (self.channel_1_w + self.channel_2_w) + + # This `if` is not used by default config + if self.enable_mixed_rec: + mixed_pred, mixed_logvar = self.get_mixed_prediction( + like_dict["params"]["mean"], + like_dict["params"]["logvar"], + self.data_mean, + self.data_std, + ) + if ( + self.model._multiscale_count is not None + and self.model._multiscale_count > 1 + ): + assert input.shape[1] == self.model._multiscale_count + input = input[:, :1] + + assert ( + input.shape == mixed_pred.shape + ), "No fucking room for vectorization induced bugs." + mixed_recons_ll = self.model.likelihood.log_likelihood( + input, {"mean": mixed_pred, "logvar": mixed_logvar} + ) + output["mixed_loss"] = compute_batch_mean(-1 * mixed_recons_ll) + + # This `if` is not used by default config + if self._exclusion_loss_weight: + raise NotImplementedError( + "Exclusion loss is not well defined here, so it should not be used." + ) + imgs = like_dict["params"]["mean"] + exclusion_loss = compute_exclusion_loss(imgs[:, :1], imgs[:, 1:]) + output["exclusion_loss"] = exclusion_loss + + if return_predicted_img: + return output, like_dict["params"]["mean"] + + return output + + def reconstruction_loss_musplit_denoisplit(self, out, target_normalized): + if self.model.predict_logvar is not None: + out_mean, _ = out.chunk(2, dim=1) + else: + out_mean = out + + recons_loss_nm = ( + -1 * self.model.likelihood_NM(out_mean, target_normalized)[0].mean() + ) + recons_loss_gm = -1 * self.model.likelihood_gm(out, target_normalized)[0].mean() + recons_loss = ( + self._denoisplit_w * recons_loss_nm + self._usplit_w * recons_loss_gm + ) + return recons_loss + + def _get_weighted_likelihood(self, ll): + """ + Each of the channels gets multiplied with a different weight. + """ + if self.ch1_recons_w == 1 and self.ch2_recons_w == 1: + return ll + + assert ll.shape[1] == 2, "This function is only for 2 channel images" + + mask1 = torch.zeros((len(ll), ll.shape[1], 1, 1), device=ll.device) + mask1[:, 0] = 1 + mask2 = torch.zeros((len(ll), ll.shape[1], 1, 1), device=ll.device) + mask2[:, 1] = 1 + + return ll * mask1 * self.ch1_recons_w + ll * mask2 * self.ch2_recons_w + + def get_kl_weight(self): + """ + KL loss can be weighted depending whether any annealing procedure is used. + This function computes the weight of the KL loss in case of annealing. + """ + if self.kl_annealing == True: + # calculate relative weight + kl_weight = (self.current_epoch - self.kl_start) * ( + 1.0 / self.kl_annealtime + ) + # clamp to [0,1] + kl_weight = min(max(0.0, kl_weight), 1.0) + + # if the final weight is given, then apply that weight on top of it + if self.kl_weight is not None: + kl_weight = kl_weight * self.kl_weight + elif self.kl_weight is not None: + return self.kl_weight + else: + kl_weight = 1.0 + return kl_weight + + def get_kl_divergence_loss_usplit( + self, topdown_layer_data_dict: Dict[str, torch.Tensor] + ) -> torch.Tensor: + """ """ + kl = torch.cat( + [kl_layer.unsqueeze(1) for kl_layer in topdown_layer_data_dict["kl"]], dim=1 + ) + # NOTE: kl.shape = (16,4) 16 is batch size. 4 is number of layers. + # Values are sum() and so are of the order 30000 + # Example values: 30626.6758, 31028.8145, 29509.8809, 29945.4922, 28919.1875, 29075.2988 + + nlayers = kl.shape[1] + for i in range(nlayers): + # topdown_layer_data_dict['z'][2].shape[-3:] = 128 * 32 * 32 + norm_factor = np.prod(topdown_layer_data_dict["z"][i].shape[-3:]) + # if self._restricted_kl: + # pow = np.power(2,min(i + 1, self._multiscale_count-1)) + # norm_factor /= pow * pow + + kl[:, i] = kl[:, i] / norm_factor + + kl_loss = free_bits_kl(kl, 0.0).mean() + return kl_loss + + def get_kl_divergence_loss(self, topdown_layer_data_dict, kl_key="kl"): + """ + kl[i] for each i has length batch_size + resulting kl shape: (batch_size, layers) + """ + kl = torch.cat( + [kl_layer.unsqueeze(1) for kl_layer in topdown_layer_data_dict[kl_key]], + dim=1, + ) + + # As compared to uSplit kl divergence, + # more by a factor of 4 just because we do sum and not mean. + kl_loss = free_bits_kl(kl, self.free_bits).sum() + # NOTE: at each hierarchy, it is more by a factor of 128/i**2). + # 128/(2*2) = 32 (bottommost layer) + # 128/(4*4) = 8 + # 128/(8*8) = 2 + # 128/(16*16) = 0.5 (topmost layer) + + # Normalize the KL-loss w.r.t. the latent space + kl_loss = kl_loss / np.prod(self.model.img_shape) + return kl_loss + + ##### UTILS Methods ##### + def normalize_input(self, x): + if self.model.normalized_input: + return x + return (x - self.data_mean["input"].mean()) / self.data_std["input"].mean() + + def normalize_target(self, target, batch=None): + return (target - self.data_mean["target"]) / self.data_std["target"] + + def unnormalize_target(self, target_normalized): + return target_normalized * self.data_std["target"] + self.data_mean["target"] + + ##### ADDITIONAL Methods ##### + # def log_images_for_tensorboard(self, pred, target, img_mmse, label): + # clamped_pred = torch.clamp((pred - pred.min()) / (pred.max() - pred.min()), 0, 1) + # clamped_mmse = torch.clamp((img_mmse - img_mmse.min()) / (img_mmse.max() - img_mmse.min()), 0, 1) + # if target is not None: + # clamped_input = torch.clamp((target - target.min()) / (target.max() - target.min()), 0, 1) + # img = wandb.Image(clamped_input[None].cpu().numpy()) + # self.logger.experiment.log({f'target_for{label}': img}) + # # self.trainer.logger.experiment.add_image(f'target_for{label}', clamped_input[None], self.current_epoch) + # for i in range(3): + # # self.trainer.logger.experiment.add_image(f'{label}/sample_{i}', clamped_pred[i:i + 1], self.current_epoch) + # img = wandb.Image(clamped_pred[i:i + 1].cpu().numpy()) + # self.logger.experiment.log({f'{label}/sample_{i}': img}) + + # img = wandb.Image(clamped_mmse[None].cpu().numpy()) + # self.trainer.logger.experiment.log({f'{label}/mmse (100 samples)': img}) + + @property + def global_step(self) -> int: + """Global step.""" + return self._global_step + + def increment_global_step(self): + """Increments global step by 1.""" + self._global_step += 1 + + def set_params_to_same_device_as(self, correct_device_tensor: torch.Tensor): + + self.model.likelihood.set_params_to_same_device_as(correct_device_tensor) + if isinstance(self.data_mean, torch.Tensor): + if self.data_mean.device != correct_device_tensor.device: + self.data_mean = self.data_mean.to(correct_device_tensor.device) + self.data_std = self.data_std.to(correct_device_tensor.device) + elif isinstance(self.data_mean, dict): + for k, v in self.data_mean.items(): + if v.device != correct_device_tensor.device: + self.data_mean[k] = v.to(correct_device_tensor.device) + self.data_std[k] = self.data_std[k].to(correct_device_tensor.device) + + def get_mixed_prediction( + self, prediction, prediction_logvar, data_mean, data_std, channel_weights=None + ): + pred_unorm = prediction * data_std["target"] + data_mean["target"] + if channel_weights is None: + channel_weights = 1 + + if self._input_is_sum: + mixed_prediction = torch.sum( + pred_unorm * channel_weights, dim=1, keepdim=True + ) + else: + mixed_prediction = torch.mean( + pred_unorm * channel_weights, dim=1, keepdim=True + ) + + mixed_prediction = (mixed_prediction - data_mean["input"].mean()) / data_std[ + "input" + ].mean() + + if prediction_logvar is not None: + if data_std["target"].shape == data_std["input"].shape and torch.all( + data_std["target"] == data_std["input"] + ): + assert channel_weights == 1 + logvar = prediction_logvar + else: + var = torch.exp(prediction_logvar) + var = var * (data_std["target"] / data_std["input"]) ** 2 + if channel_weights != 1: + var = var * torch.square(channel_weights) + + # sum of variance. + mixed_var = 0 + for i in range(var.shape[1]): + mixed_var += var[:, i : i + 1] + + logvar = torch.log(mixed_var) + else: + logvar = None + return mixed_prediction, logvar diff --git a/src/careamics/lvae_training/metrics.py b/src/careamics/lvae_training/metrics.py new file mode 100644 index 000000000..b4c7a4304 --- /dev/null +++ b/src/careamics/lvae_training/metrics.py @@ -0,0 +1,214 @@ +""" +This script contains the functions/classes to compute loss and metrics used to train and evaluate the performance of the model. +""" + +import numpy as np +import torch +from skimage.metrics import structural_similarity +from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure + +from careamics.models.lvae.utils import allow_numpy + + +class RunningPSNR: + """ + This class allows to compute the running PSNR during validation step in training. + In this way it is possible to compute the PSNR on the entire validation set one batch at the time. + """ + + def __init__(self): + # number of elements seen so far during the epoch + self.N = None + # running sum of the MSE over the self.N elements seen so far + self.mse_sum = None + # running max and min values of the self.N target images seen so far + self.max = self.min = None + self.reset() + + def reset(self): + """ + Used to reset the running PSNR (usually called at the end of each epoch). + """ + self.mse_sum = 0 + self.N = 0 + self.max = self.min = None + + def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None: + """ + Given a batch of reconstructed and target images, it updates the MSE and. + + Parameters + ---------- + rec: torch.Tensor + Batch of reconstructed images (B, H, W). + tar: torch.Tensor + Batch of target images (B, H, W). + """ + ins_max = torch.max(tar).item() + ins_min = torch.min(tar).item() + if self.max is None: + assert self.min is None + self.max = ins_max + self.min = ins_min + else: + self.max = max(self.max, ins_max) + self.min = min(self.min, ins_min) + + mse = (rec - tar) ** 2 + elementwise_mse = torch.mean(mse.view(len(mse), -1), dim=1) + self.mse_sum += torch.nansum(elementwise_mse) + self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse)) + + def get(self): + """ + The get the actual PSNR value given the running statistics. + """ + if self.N == 0 or self.N is None: + return None + rmse = torch.sqrt(self.mse_sum / self.N) + return 20 * torch.log10((self.max - self.min) / rmse) + + +def zero_mean(x): + return x - torch.mean(x, dim=1, keepdim=True) + + +def fix_range(gt, x): + a = torch.sum(gt * x, dim=1, keepdim=True) / (torch.sum(x * x, dim=1, keepdim=True)) + return x * a + + +def fix(gt, x): + gt_ = zero_mean(gt) + return fix_range(gt_, zero_mean(x)) + + +def _PSNR_internal(gt, pred, range_=None): + if range_ is None: + range_ = torch.max(gt, dim=1).values - torch.min(gt, dim=1).values + + mse = torch.mean((gt - pred) ** 2, dim=1) + return 20 * torch.log10(range_ / torch.sqrt(mse)) + + +@allow_numpy +def PSNR(gt, pred, range_=None): + """ + Compute PSNR. + + Parameters + ---------- + gt: array + Ground truth image. + pred: array + Predicted image. + """ + assert len(gt.shape) == 3, "Images must be in shape: (batch,H,W)" + + gt = gt.view(len(gt), -1) + pred = pred.view(len(gt), -1) + return _PSNR_internal(gt, pred, range_=range_) + + +@allow_numpy +def RangeInvariantPsnr(gt: torch.Tensor, pred: torch.Tensor): + """ + NOTE: Works only for grayscale images. + Adapted from https://github.com/juglab/ScaleInvPSNR/blob/master/psnr.py + It rescales the prediction to ensure that the prediction has the same range as the ground truth. + """ + assert len(gt.shape) == 3, "Images must be in shape: (batch,H,W)" + gt = gt.view(len(gt), -1) + pred = pred.view(len(gt), -1) + ra = (torch.max(gt, dim=1).values - torch.min(gt, dim=1).values) / torch.std( + gt, dim=1 + ) + gt_ = zero_mean(gt) / torch.std(gt, dim=1, keepdim=True) + return _PSNR_internal(zero_mean(gt_), fix(gt_, pred), ra) + + +def _avg_psnr(target, prediction, psnr_fn): + output = np.mean( + [ + psnr_fn(target[i : i + 1], prediction[i : i + 1]).item() + for i in range(len(prediction)) + ] + ) + return round(output, 2) + + +def avg_range_inv_psnr(target, prediction): + return _avg_psnr(target, prediction, RangeInvariantPsnr) + + +def avg_psnr(target, prediction): + return _avg_psnr(target, prediction, PSNR) + + +def compute_masked_psnr(mask, tar1, tar2, pred1, pred2): + mask = mask.astype(bool) + mask = mask[..., 0] + tmp_tar1 = tar1[mask].reshape((len(tar1), -1, 1)) + tmp_pred1 = pred1[mask].reshape((len(tar1), -1, 1)) + tmp_tar2 = tar2[mask].reshape((len(tar2), -1, 1)) + tmp_pred2 = pred2[mask].reshape((len(tar2), -1, 1)) + psnr1 = avg_range_inv_psnr(tmp_tar1, tmp_pred1) + psnr2 = avg_range_inv_psnr(tmp_tar2, tmp_pred2) + return psnr1, psnr2 + + +def avg_ssim(target, prediction): + ssim = [ + structural_similarity( + target[i], prediction[i], data_range=(target[i].max() - target[i].min()) + ) + for i in range(len(target)) + ] + return np.mean(ssim), np.std(ssim) + + +@allow_numpy +def range_invariant_multiscale_ssim(gt_, pred_): + """ + Computes range invariant multiscale ssim for one channel. + This has the benefit that it is invariant to scalar multiplications in the prediction. + """ + shape = gt_.shape + gt_ = torch.Tensor(gt_.reshape((shape[0], -1))) + pred_ = torch.Tensor(pred_.reshape((shape[0], -1))) + gt_ = zero_mean(gt_) + pred_ = zero_mean(pred_) + pred_ = fix(gt_, pred_) + pred_ = pred_.reshape(shape) + gt_ = gt_.reshape(shape) + + ms_ssim = MultiScaleStructuralSimilarityIndexMeasure( + data_range=gt_.max() - gt_.min() + ) + return ms_ssim(torch.Tensor(pred_[:, None]), torch.Tensor(gt_[:, None])).item() + + +def compute_multiscale_ssim(gt_, pred_, range_invariant=True): + """ + Computes multiscale ssim for each channel. + Args: + gt_: ground truth image with shape (N, H, W, C) + pred_: predicted image with shape (N, H, W, C) + range_invariant: whether to use range invariant multiscale ssim + """ + ms_ssim_values = {i: None for i in range(gt_.shape[-1])} + for ch_idx in range(gt_.shape[-1]): + tar_tmp = gt_[..., ch_idx] + pred_tmp = pred_[..., ch_idx] + if range_invariant: + ms_ssim_values[ch_idx] = range_invariant_multiscale_ssim(tar_tmp, pred_tmp) + else: + ms_ssim = MultiScaleStructuralSimilarityIndexMeasure( + data_range=tar_tmp.max() - tar_tmp.min() + ) + ms_ssim_values[ch_idx] = ms_ssim( + torch.Tensor(pred_tmp[:, None]), torch.Tensor(tar_tmp[:, None]) + ).item() + + output = [ms_ssim_values[i] for i in range(gt_.shape[-1])] + return output diff --git a/src/careamics/lvae_training/train_lvae.py b/src/careamics/lvae_training/train_lvae.py new file mode 100644 index 000000000..884b12387 --- /dev/null +++ b/src/careamics/lvae_training/train_lvae.py @@ -0,0 +1,339 @@ +""" +This script is meant to load data, intialize the model, and provide the logic for training it. +""" + +import glob +import os +import socket +import sys +from typing import Dict + +import pytorch_lightning as pl +import torch +from absl import app, flags +from ml_collections.config_flags import config_flags +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.loggers import WandbLogger +from torch.utils.data import DataLoader + +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) +print(sys.path) + +from careamics.lvae_training.data_modules import LCMultiChDloader, MultiChDloader +from careamics.lvae_training.data_utils import DataSplitType +from careamics.lvae_training.lightning_module import LadderVAELight +from careamics.lvae_training.train_utils import * + +FLAGS = flags.FLAGS + +config_flags.DEFINE_config_file( + "config", None, "Training configuration.", lock_config=False +) +flags.DEFINE_string("workdir", None, "Work directory.") +flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval") +flags.DEFINE_string( + "logdir", "/group/jug/federico/wandb_backup/", "The folder name for storing logging" +) +flags.DEFINE_string( + "datadir", "/group/jug/federico/careamics_training/data/BioSR", "Data directory." +) +flags.DEFINE_boolean("use_max_version", False, "Overwrite the max version of the model") +flags.DEFINE_string( + "load_ckptfpath", + "", + "The path to a previous ckpt from which the weights should be loaded", +) +flags.mark_flags_as_required(["workdir", "config", "mode"]) + + +def create_dataset( + config, + datadir, + eval_datasplit_type=DataSplitType.Val, + raw_data_dict=None, + skip_train_dataset=False, + kwargs_dict=None, +): + + if kwargs_dict is None: + kwargs_dict = {} + + datapath = datadir + + # Hard-coded parameters (used to be in the config file) + normalized_input = True + use_one_mu_std = True + train_aug_rotate = False + enable_random_cropping = True + lowres_supervision = False + + # 1) Data loader for Lateral Contextualization + if ( + "multiscale_lowres_count" in config.data + and config.data.multiscale_lowres_count is not None + ): + # Get padding attributes + if "padding_kwargs" not in kwargs_dict: + padding_kwargs = {} + if "padding_mode" in config.data and config.data.padding_mode is not None: + padding_kwargs["mode"] = config.data.padding_mode + else: + padding_kwargs["mode"] = "reflect" + if "padding_value" in config.data and config.data.padding_value is not None: + padding_kwargs["constant_values"] = config.data.padding_value + else: + padding_kwargs["constant_values"] = None + else: + padding_kwargs = kwargs_dict.pop("padding_kwargs") + + train_data = ( + None + if skip_train_dataset + else LCMultiChDloader( + config.data, + datapath, + datasplit_type=DataSplitType.Train, + val_fraction=0.1, + test_fraction=0.1, + normalized_input=normalized_input, + use_one_mu_std=use_one_mu_std, + enable_rotation_aug=train_aug_rotate, + enable_random_cropping=enable_random_cropping, + num_scales=config.data.multiscale_lowres_count, + lowres_supervision=lowres_supervision, + padding_kwargs=padding_kwargs, + **kwargs_dict, + allow_generation=True, + ) + ) + max_val = train_data.get_max_val() + + val_data = LCMultiChDloader( + config.data, + datapath, + datasplit_type=eval_datasplit_type, + val_fraction=0.1, + test_fraction=0.1, + normalized_input=normalized_input, + use_one_mu_std=use_one_mu_std, + enable_rotation_aug=False, # No rotation aug on validation + enable_random_cropping=False, + # No random cropping on validation. Validation is evaluated on determistic grids + num_scales=config.data.multiscale_lowres_count, + lowres_supervision=lowres_supervision, + padding_kwargs=padding_kwargs, + allow_generation=False, + **kwargs_dict, + max_val=max_val, + ) + # 2) Vanilla data loader + else: + train_data_kwargs = {"allow_generation": True, **kwargs_dict} + val_data_kwargs = {"allow_generation": False, **kwargs_dict} + + train_data_kwargs["enable_random_cropping"] = enable_random_cropping + val_data_kwargs["enable_random_cropping"] = False + + train_data = ( + None + if skip_train_dataset + else MultiChDloader( + data_config=config.data, + fpath=datapath, + datasplit_type=DataSplitType.Train, + val_fraction=0.1, + test_fraction=0.1, + normalized_input=normalized_input, + use_one_mu_std=use_one_mu_std, + enable_rotation_aug=train_aug_rotate, + **train_data_kwargs, + ) + ) + + max_val = train_data.get_max_val() + val_data = MultiChDloader( + data_config=config.data, + fpath=datapath, + datasplit_type=eval_datasplit_type, + val_fraction=0.1, + test_fraction=0.1, + normalized_input=normalized_input, + use_one_mu_std=use_one_mu_std, + enable_rotation_aug=False, # No rotation aug on validation + max_val=max_val, + **val_data_kwargs, + ) + + # For normalizing, we should be using the training data's mean and std. + mean_val, std_val = train_data.compute_mean_std() + train_data.set_mean_std(mean_val, std_val) + val_data.set_mean_std(mean_val, std_val) + + return train_data, val_data + + +def create_model_and_train( + config: ml_collections.ConfigDict, + data_mean: Dict[str, torch.Tensor], + data_std: Dict[str, torch.Tensor], + logger: WandbLogger, + checkpoint_callback: ModelCheckpoint, + train_loader: DataLoader, + val_loader: DataLoader, +): + # tensorboard previous files. + for filename in glob.glob(config.workdir + "/events*"): + os.remove(filename) + + # checkpoints + for filename in glob.glob(config.workdir + "/*.ckpt"): + os.remove(filename) + + if "num_targets" in config.model: + target_ch = config.model.num_targets + else: + target_ch = config.data.get("num_channels", 2) + + # Instantiate the model (lightning wrapper) + model = LadderVAELight( + data_mean=data_mean, data_std=data_std, config=config, target_ch=target_ch + ) + + # Load pre-trained weights if any + if config.training.pre_trained_ckpt_fpath: + print("Starting with pre-trained model", config.training.pre_trained_ckpt_fpath) + checkpoint = torch.load(config.training.pre_trained_ckpt_fpath) + _ = model.load_state_dict(checkpoint["state_dict"], strict=False) + + estop_monitor = config.model.get("monitor", "val_loss") + estop_mode = MetricMonitor(estop_monitor).mode() + + callbacks = [ + EarlyStopping( + monitor=estop_monitor, + min_delta=1e-6, + patience=config.training.earlystop_patience, + verbose=True, + mode=estop_mode, + ), + checkpoint_callback, + LearningRateMonitor(logging_interval="epoch"), + ] + + logger.experiment.config.update(config.to_dict()) + # wandb.init(config=config) + trainer = pl.Trainer( + accelerator="gpu", + max_epochs=config.training.max_epochs, + gradient_clip_val=config.training.grad_clip_norm_value, + gradient_clip_algorithm=config.training.gradient_clip_algorithm, + logger=logger, + callbacks=callbacks, + # limit_train_batches = config.training.limit_train_batches, + precision=config.training.precision, + ) + trainer.fit(model, train_loader, val_loader) + + +def train_network( + train_loader: DataLoader, + val_loader: DataLoader, + data_mean: Dict[str, torch.Tensor], + data_std: Dict[str, torch.Tensor], + config: ml_collections.ConfigDict, + model_name: str, + logdir: str, +): + ckpt_monitor = config.model.get("monitor", "val_loss") + ckpt_mode = MetricMonitor(ckpt_monitor).mode() + checkpoint_callback = ModelCheckpoint( + monitor=ckpt_monitor, + dirpath=config.workdir, + filename=model_name + "_best", + save_last=True, + save_top_k=1, + mode=ckpt_mode, + ) + checkpoint_callback.CHECKPOINT_NAME_LAST = model_name + "_last" + logger = WandbLogger( + name=os.path.join(config.hostname, config.exptname), + save_dir=logdir, + project="Disentanglement", + ) + + create_model_and_train( + config=config, + data_mean=data_mean, + data_std=data_std, + logger=logger, + checkpoint_callback=checkpoint_callback, + train_loader=train_loader, + val_loader=val_loader, + ) + + +def main(argv): + config = FLAGS.config + + assert os.path.exists(FLAGS.workdir) + cur_workdir, relative_path = get_workdir( + config, FLAGS.workdir, FLAGS.use_max_version + ) + print(f"Saving training to {cur_workdir}") + + config.workdir = cur_workdir + config.exptname = relative_path + config.hostname = socket.gethostname() + config.datadir = FLAGS.datadir + config.training.pre_trained_ckpt_fpath = FLAGS.load_ckptfpath + + if FLAGS.mode == "train": + set_logger(workdir=cur_workdir) + raw_data_dict = None + + # From now on, config cannot be changed. + config = ml_collections.FrozenConfigDict(config) + log_config(config, cur_workdir) + + train_data, val_data = create_dataset( + config, FLAGS.datadir, raw_data_dict=raw_data_dict + ) + + mean_dict, std_dict = get_mean_std_dict_for_model(config, train_data) + + batch_size = config.training.batch_size + shuffle = True + train_dloader = DataLoader( + train_data, + pin_memory=False, + num_workers=config.training.num_workers, + shuffle=shuffle, + batch_size=batch_size, + ) + val_dloader = DataLoader( + val_data, + pin_memory=False, + num_workers=config.training.num_workers, + shuffle=False, + batch_size=batch_size, + ) + + train_network( + train_loader=train_dloader, + val_loader=val_dloader, + data_mean=mean_dict, + data_std=std_dict, + config=config, + model_name="BaselineVAECL", + logdir=FLAGS.logdir, + ) + + elif FLAGS.mode == "eval": + pass + else: + raise ValueError(f"Mode {FLAGS.mode} not recognized.") + + +if __name__ == "__main__": + app.run(main) diff --git a/src/careamics/lvae_training/train_utils.py b/src/careamics/lvae_training/train_utils.py new file mode 100644 index 000000000..e7c9c8a31 --- /dev/null +++ b/src/careamics/lvae_training/train_utils.py @@ -0,0 +1,121 @@ +""" +This script contains the utility functions for training the LVAE model. +These functions are mainly used in `train.py` script. +""" + +import logging +import os +import pickle +import time +from copy import deepcopy +from datetime import datetime +from pathlib import Path + +import ml_collections + + +def log_config(config: ml_collections.ConfigDict, cur_workdir: str) -> None: + # Saving config file. + with open(os.path.join(cur_workdir, "config.pkl"), "wb") as f: + pickle.dump(config, f) + print(f"Saved config to {cur_workdir}/config.pkl") + + +def set_logger(workdir: str) -> None: + os.makedirs(workdir, exist_ok=True) + fstream = open(os.path.join(workdir, "stdout.txt"), "w") + handler = logging.StreamHandler(fstream) + formatter = logging.Formatter( + "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" + ) + handler.setFormatter(formatter) + logger = logging.getLogger() + logger.addHandler(handler) + logger.setLevel("INFO") + + +def get_new_model_version(model_dir: str) -> str: + """ + A model will have multiple runs. Each run will have a different version. + """ + versions = [] + for version_dir in os.listdir(model_dir): + try: + versions.append(int(version_dir)) + except: + print( + f"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed" + ) + exit() + if len(versions) == 0: + return "0" + return f"{max(versions) + 1}" + + +def get_model_name(config: ml_collections.ConfigDict) -> str: + return "LVAE_denoiSplit" + + +def get_workdir( + config: ml_collections.ConfigDict, + root_dir: str, + use_max_version: bool, + nested_call: int = 0, +): + rel_path = datetime.now().strftime("%y%m") + cur_workdir = os.path.join(root_dir, rel_path) + Path(cur_workdir).mkdir(exist_ok=True) + + rel_path = os.path.join(rel_path, get_model_name(config)) + cur_workdir = os.path.join(root_dir, rel_path) + Path(cur_workdir).mkdir(exist_ok=True) + + if use_max_version: + # Used for debugging. + version = int(get_new_model_version(cur_workdir)) + if version > 0: + version = f"{version - 1}" + + rel_path = os.path.join(rel_path, str(version)) + else: + rel_path = os.path.join(rel_path, get_new_model_version(cur_workdir)) + + cur_workdir = os.path.join(root_dir, rel_path) + try: + Path(cur_workdir).mkdir(exist_ok=False) + except FileExistsError: + print( + f"Workdir {cur_workdir} already exists. Probably because someother program also created the exact same directory. Trying to get a new version." + ) + time.sleep(2.5) + if nested_call > 10: + raise ValueError( + f"Cannot create a new directory. {cur_workdir} already exists." + ) + + return get_workdir(config, root_dir, use_max_version, nested_call + 1) + + return cur_workdir, rel_path + + +def get_mean_std_dict_for_model(config, train_dset): + """ + Computes the mean and std for the model. This will be subsequently passed to the model. + """ + mean_dict, std_dict = train_dset.get_mean_std() + + return deepcopy(mean_dict), deepcopy(std_dict) + + +class MetricMonitor: + def __init__(self, metric): + assert metric in ["val_loss", "val_psnr"] + self.metric = metric + + def mode(self): + if self.metric == "val_loss": + return "min" + elif self.metric == "val_psnr": + return "max" + else: + raise ValueError(f"Invalid metric:{self.metric}") diff --git a/src/careamics/models/lvae/__init__.py b/src/careamics/models/lvae/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/careamics/models/lvae/layers.py b/src/careamics/models/lvae/layers.py new file mode 100644 index 000000000..3527040d8 --- /dev/null +++ b/src/careamics/models/lvae/layers.py @@ -0,0 +1,1998 @@ +""" +Script containing the common basic blocks (nn.Module) reused by the LadderVAE architecture. + +Hierarchy in the model blocks: + +""" + +from copy import deepcopy +from typing import Callable, Dict, Iterable, Literal, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F +from torch.distributions import kl_divergence +from torch.distributions.normal import Normal + +from .utils import ( + StableLogVar, + StableMean, + crop_img_tensor, + kl_normal_mc, + pad_img_tensor, +) + + +class ResidualBlock(nn.Module): + """ + Residual block with 2 convolutional layers. + + Some architectural notes: + - The number of input, intermediate, and output channels is the same, + - Padding is always 'same', + - The 2 convolutional layers have the same groups, + - No stride allowed, + - Kernel sizes must be odd. + + The output isgiven by: `out = gate(f(x)) + x`. + The presence of the gating mechanism is optional, and f(x) has different + structures depending on the `block_type` argument. + Specifically, `block_type` is a string specifying the block's structure, with: + a = activation + b = batch norm + c = conv layer + d = dropout. + For example, "bacdbacd" defines a block with 2x[batchnorm, activation, conv, dropout]. + """ + + default_kernel_size = (3, 3) + + def __init__( + self, + channels: int, + nonlin: Callable, + kernel: Union[int, Iterable[int]] = None, + groups: int = 1, + batchnorm: bool = True, + block_type: str = None, + dropout: float = None, + gated: bool = None, + skip_padding: bool = False, + conv2d_bias: bool = True, + ): + """ + Constructor. + + Parameters + ---------- + channels: int + The number of input and output channels (they are the same). + nonlin: Callable + The non-linearity function used in the block (e.g., `nn.ReLU`). + kernel: Union[int, Iterable[int]], optional + The kernel size used in the convolutions of the block. + It can be either a single integer or a pair of integers defining the squared kernel. + Default is `None`. + groups: int, optional + The number of groups to consider in the convolutions. Default is 1. + batchnorm: bool, optional + Whether to use batchnorm layers. Default is `True`. + block_type: str, optional + A string specifying the block structure, check class docstring for more info. + Default is `None`. + dropout: float, optional + The dropout probability in dropout layers. If `None` dropout is not used. + Default is `None`. + gated: bool, optional + Whether to use gated layer. Default is `None`. + skip_padding: bool, optional + Whether to skip padding in convolutions. Default is `False`. + conv2d_bias: bool, optional + Whether to use bias term in convolutions. Default is `True`. + """ + super().__init__() + + # Set kernel size & padding + if kernel is None: + kernel = self.default_kernel_size + elif isinstance(kernel, int): + kernel = (kernel, kernel) + elif len(kernel) != 2: + raise ValueError("kernel has to be None, int, or an iterable of length 2") + assert all([k % 2 == 1 for k in kernel]), "kernel sizes have to be odd" + kernel = list(kernel) + self.skip_padding = skip_padding + pad = [0] * len(kernel) if self.skip_padding else [k // 2 for k in kernel] + # print(kernel, pad) + + modules = [] + if block_type == "cabdcabd": + for i in range(2): + conv = nn.Conv2d( + channels, + channels, + kernel[i], + padding=pad[i], + groups=groups, + bias=conv2d_bias, + ) + modules.append(conv) + modules.append(nonlin()) + if batchnorm: + modules.append(nn.BatchNorm2d(channels)) + if dropout is not None: + modules.append(nn.Dropout2d(dropout)) + elif block_type == "bacdbac": + for i in range(2): + if batchnorm: + modules.append(nn.BatchNorm2d(channels)) + modules.append(nonlin()) + conv = nn.Conv2d( + channels, + channels, + kernel[i], + padding=pad[i], + groups=groups, + bias=conv2d_bias, + ) + modules.append(conv) + if dropout is not None and i == 0: + modules.append(nn.Dropout2d(dropout)) + elif block_type == "bacdbacd": + for i in range(2): + if batchnorm: + modules.append(nn.BatchNorm2d(channels)) + modules.append(nonlin()) + conv = nn.Conv2d( + channels, + channels, + kernel[i], + padding=pad[i], + groups=groups, + bias=conv2d_bias, + ) + modules.append(conv) + modules.append(nn.Dropout2d(dropout)) + + else: + raise ValueError(f"unrecognized block type '{block_type}'") + + self.gated = gated + if gated: + modules.append(GateLayer2d(channels, 1, nonlin)) + + self.block = nn.Sequential(*modules) + + def forward(self, x): + + out = self.block(x) + if out.shape != x.shape: + return out + F.center_crop(x, out.shape[-2:]) + else: + return out + x + + +class ResidualGatedBlock(ResidualBlock): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, gated=True) + + +class GateLayer2d(nn.Module): + """ + Double the number of channels through a convolutional layer, then use + half the channels as gate for the other half. + """ + + def __init__(self, channels, kernel_size, nonlin=nn.LeakyReLU): + super().__init__() + assert kernel_size % 2 == 1 + pad = kernel_size // 2 + self.conv = nn.Conv2d(channels, 2 * channels, kernel_size, padding=pad) + self.nonlin = nonlin() + + def forward(self, x): + x = self.conv(x) + x, gate = torch.chunk(x, 2, dim=1) + x = self.nonlin(x) # TODO remove this? + gate = torch.sigmoid(gate) + return x * gate + + +class ResBlockWithResampling(nn.Module): + """ + Residual block that takes care of resampling (i.e. downsampling or upsampling) steps (by a factor 2). + It is structured as follows: + 1. `pre_conv`: a downsampling or upsampling strided convolutional layer in case of resampling, or + a 1x1 convolutional layer that maps the number of channels of the input to `inner_channels`. + 2. `ResidualBlock` + 3. `post_conv`: a 1x1 convolutional layer that maps the number of channels to `c_out`. + + Some implementation notes: + - Resampling is performed through a strided convolution layer at the beginning of the block. + - The strided convolution block has fixed kernel size of 3x3 and 1 layer of zero-padding. + - The number of channels is adjusted at the beginning and end of the block through 1x1 convolutional layers. + - The number of internal channels is by default the same as the number of output channels, but + min_inner_channels can override the behaviour. + """ + + def __init__( + self, + mode: Literal["top-down", "bottom-up"], + c_in: int, + c_out: int, + min_inner_channels: int = None, + nonlin: Callable = nn.LeakyReLU, + resample: bool = False, + res_block_kernel: Union[int, Iterable[int]] = None, + groups: int = 1, + batchnorm: bool = True, + res_block_type: str = None, + dropout: float = None, + gated: bool = None, + skip_padding: bool = False, + conv2d_bias: bool = True, + # lowres_input: bool = False, + ): + """ + Constructor. + + Parameters + ---------- + mode: Literal["top-down", "bottom-up"] + The type of resampling performed in the initial strided convolution of the block. + If "bottom-up" downsampling of a factor 2 is done. + If "top-down" upsampling of a factor 2 is done. + c_in: int + The number of input channels. + c_out: int + The number of output channels. + min_inner_channels: int, optional + The number of channels used in the inner layer of this module. + Default is `None`, meaning that the number of inner channels is set to `c_out`. + nonlin: Callable, optional + The non-linearity function used in the block. Default is `nn.LeakyReLU`. + resample: bool, optional + Whether to perform resampling in the first convolutional layer. + If `False`, the first convolutional layer just maps the input to a tensor with + `inner_channels` channels through 1x1 convolution. Deafult is `False`. + res_block_kernel: Union[int, Iterable[int]], optional + The kernel size used in the convolutions of the residual block. + It can be either a single integer or a pair of integers defining the squared kernel. + Default is `None`. + groups: int, optional + The number of groups to consider in the convolutions. Default is 1. + batchnorm: bool, optional + Whether to use batchnorm layers. Default is `True`. + res_block_type: str, optional + A string specifying the structure of residual block. + Check `ResidualBlock` doscstring for more information. + Default is `None`. + dropout: float, optional + The dropout probability in dropout layers. If `None` dropout is not used. + Default is `None`. + gated: bool, optional + Whether to use gated layer. Default is `None`. + skip_padding: bool, optional + Whether to skip padding in convolutions. Default is `False`. + conv2d_bias: bool, optional + Whether to use bias term in convolutions. Default is `True`. + """ + super().__init__() + assert mode in ["top-down", "bottom-up"] + + if min_inner_channels is None: + min_inner_channels = 0 + # inner_channels is the number of channels used in the inner layers + # of ResBlockWithResampling + inner_channels = max(c_out, min_inner_channels) + + # Define first conv layer to change num channels and/or up/downsample + if resample: + if mode == "bottom-up": # downsample + self.pre_conv = nn.Conv2d( + in_channels=c_in, + out_channels=inner_channels, + kernel_size=3, + padding=1, + stride=2, + groups=groups, + bias=conv2d_bias, + ) + elif mode == "top-down": # upsample + self.pre_conv = nn.ConvTranspose2d( + in_channels=c_in, + kernel_size=3, + out_channels=inner_channels, + padding=1, + stride=2, + groups=groups, + output_padding=1, + bias=conv2d_bias, + ) + elif c_in != inner_channels: + self.pre_conv = nn.Conv2d( + c_in, inner_channels, 1, groups=groups, bias=conv2d_bias + ) + else: + self.pre_conv = None + + # Residual block + self.res = ResidualBlock( + channels=inner_channels, + nonlin=nonlin, + kernel=res_block_kernel, + groups=groups, + batchnorm=batchnorm, + dropout=dropout, + gated=gated, + block_type=res_block_type, + skip_padding=skip_padding, + conv2d_bias=conv2d_bias, + ) + + # Define last conv layer to get correct num output channels + if inner_channels != c_out: + self.post_conv = nn.Conv2d( + inner_channels, c_out, 1, groups=groups, bias=conv2d_bias + ) + else: + self.post_conv = None + + def forward(self, x): + if self.pre_conv is not None: + x = self.pre_conv(x) + + x = self.res(x) + + if self.post_conv is not None: + x = self.post_conv(x) + return x + + +class TopDownDeterministicResBlock(ResBlockWithResampling): + + def __init__(self, *args, upsample: bool = False, **kwargs): + kwargs["resample"] = upsample + super().__init__("top-down", *args, **kwargs) + + +class BottomUpDeterministicResBlock(ResBlockWithResampling): + + def __init__(self, *args, downsample: bool = False, **kwargs): + kwargs["resample"] = downsample + super().__init__("bottom-up", *args, **kwargs) + + +class BottomUpLayer(nn.Module): + """ + Bottom-up deterministic layer. + It consists of one or a stack of `BottomUpDeterministicResBlock`'s. + The outputs are the so-called `bu_values` that are later used in the Decoder to update the + generative distributions. + + NOTE: When Lateral Contextualization is Enabled (i.e., `enable_multiscale=True`), + the low-res lateral input is first fed through a BottomUpDeterministicBlock (BUDB) + (without downsampling), and then merged to the latent tensor produced by the primary flow + of the `BottomUpLayer` through the `MergeLowRes` layer. It is meaningful to remark that + the BUDB that takes care of encoding the low-res input can be either shared with the + primary flow (and in that case it is the "same_size" BUDB (or stack of BUDBs) -> see `self.net`), + or can be a deep-copy of the primary flow's BUDB. + This behaviour is controlled by `lowres_separate_branch` parameter. + """ + + def __init__( + self, + n_res_blocks: int, + n_filters: int, + downsampling_steps: int = 0, + nonlin: Callable = None, + batchnorm: bool = True, + dropout: float = None, + res_block_type: str = None, + res_block_kernel: int = None, + res_block_skip_padding: bool = False, + gated: bool = None, + enable_multiscale: bool = False, + multiscale_lowres_size_factor: int = None, + lowres_separate_branch: bool = False, + multiscale_retain_spatial_dims: bool = False, + decoder_retain_spatial_dims: bool = False, + output_expected_shape: Iterable[int] = None, + ): + """ + Constructor. + + Parameters + ---------- + n_res_blocks: int + Number of `BottomUpDeterministicResBlock` modules stacked in this layer. + n_filters: int + Number of channels present through out the layers of this block. + downsampling_steps: int, optional + Number of downsampling steps that has to be done in this layer (typically 1). + Default is 0. + nonlin: Callable, optional + The non-linearity function used in the block. Default is `None`. + batchnorm: bool, optional + Whether to use batchnorm layers. Default is `True`. + dropout: float, optional + The dropout probability in dropout layers. If `None` dropout is not used. + Default is `None`. + res_block_type: str, optional + A string specifying the structure of residual block. + Check `ResidualBlock` doscstring for more information. + Default is `None`. + res_block_kernel: Union[int, Iterable[int]], optional + The kernel size used in the convolutions of the residual block. + It can be either a single integer or a pair of integers defining the squared kernel. + Default is `None`. + res_block_skip_padding: bool, optional + Whether to skip padding in convolutions in the Residual block. Default is `False`. + gated: bool, optional + Whether to use gated layer. Default is `None`. + enable_multiscale: bool, optional + Whether to enable multiscale (Lateral Contextualization) or not. Default is `False`. + multiscale_lowres_size_factor: int, optional + A factor the expresses the relative size of the primary flow tensor with respect to the + lower-resolution lateral input tensor. Default in `None`. + lowres_separate_branch: bool, optional + Whether the residual block(s) encoding the low-res input should be shared (`False`) or + not (`True`) with the primary flow "same-size" residual block(s). Default is `False`. + multiscale_retain_spatial_dims: bool, optional + Whether to pad the latent tensor resulting from the bottom-up layer's primary flow + to match the size of the low-res input. Default is `False`. + decoder_retain_spatial_dims: bool, optional + Default is `False`. + output_expected_shape: Iterable[int], optional + The expected shape of the layer output (only used if `enable_multiscale == True`). + Default is `None`. + """ + super().__init__() + + # Define attributes for Lateral Contextualization + self.enable_multiscale = enable_multiscale + self.lowres_separate_branch = lowres_separate_branch + self.multiscale_retain_spatial_dims = multiscale_retain_spatial_dims + self.multiscale_lowres_size_factor = multiscale_lowres_size_factor + self.decoder_retain_spatial_dims = decoder_retain_spatial_dims + self.output_expected_shape = output_expected_shape + assert self.output_expected_shape is None or self.enable_multiscale is True + + bu_blocks_downsized = [] + bu_blocks_samesize = [] + for _ in range(n_res_blocks): + do_resample = False + if downsampling_steps > 0: + do_resample = True + downsampling_steps -= 1 + block = BottomUpDeterministicResBlock( + c_in=n_filters, + c_out=n_filters, + nonlin=nonlin, + downsample=do_resample, + batchnorm=batchnorm, + dropout=dropout, + res_block_type=res_block_type, + res_block_kernel=res_block_kernel, + skip_padding=res_block_skip_padding, + gated=gated, + ) + if do_resample: + bu_blocks_downsized.append(block) + else: + bu_blocks_samesize.append(block) + + self.net_downsized = nn.Sequential(*bu_blocks_downsized) + self.net = nn.Sequential(*bu_blocks_samesize) + + # Using the same net for the low resolution (and larger sized image) + self.lowres_net = self.lowres_merge = None + if self.enable_multiscale: + self._init_multiscale( + n_filters=n_filters, + nonlin=nonlin, + batchnorm=batchnorm, + dropout=dropout, + res_block_type=res_block_type, + ) + + # msg = f'[{self.__class__.__name__}] McEnabled:{int(enable_multiscale)} ' + # if enable_multiscale: + # msg += f'McParallelBeam:{int(multiscale_retain_spatial_dims)} McFactor{multiscale_lowres_size_factor}' + # print(msg) + + def _init_multiscale( + self, + nonlin: Callable = None, + n_filters: int = None, + batchnorm: bool = None, + dropout: float = None, + res_block_type: str = None, + ) -> None: + """ + This method defines the modules responsible of merging compressed lateral inputs to the outputs + of the primary flow at different hierarchical levels in the multiresolution approach (LC). + + Specifically, the method initializes `lowres_net`, which is a stack of `BottomUpDeterministicBlock`'s + (w/out downsampling) that takes care of additionally processing the low-res input, and `lowres_merge`, + which is the module responsible of merging the compressed lateral input to the main flow. + + NOTE: The merge modality is set by default to "residual", meaning that the merge layer + performs concatenation on dim=1, followed by 1x1 convolution and a Residual Gated block. + + Parameters + ---------- + nonlin: Callable, optional + The non-linearity function used in the block. Default is `None`. + n_filters: int + Number of channels present through out the layers of this block. + batchnorm: bool, optional + Whether to use batchnorm layers. Default is `True`. + dropout: float, optional + The dropout probability in dropout layers. If `None` dropout is not used. + Default is `None`. + res_block_type: str, optional + A string specifying the structure of residual block. + Check `ResidualBlock` doscstring for more information. + Default is `None`. + """ + self.lowres_net = self.net + if self.lowres_separate_branch: + self.lowres_net = deepcopy(self.net) + + self.lowres_merge = MergeLowRes( + channels=n_filters, + merge_type="residual", + nonlin=nonlin, + batchnorm=batchnorm, + dropout=dropout, + res_block_type=res_block_type, + multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims, + multiscale_lowres_size_factor=self.multiscale_lowres_size_factor, + ) + + def forward( + self, x: torch.Tensor, lowres_x: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters + ---------- + x: torch.Tensor + The input of the `BottomUpLayer`, i.e., the input image or the output of the + previous layer. + lowres_x: torch.Tensor, optional + The low-res input used for Lateral Contextualization (LC). Default is `None`. + """ + # The input is fed through the residual downsampling block(s) + primary_flow = self.net_downsized(x) + # The downsampling output is fed through additional residual block(s) + primary_flow = self.net(primary_flow) + + # If LC is not used, simply return output of primary-flow + if self.enable_multiscale is False: + assert lowres_x is None + return primary_flow, primary_flow + + if lowres_x is not None: + # First encode the low-res lateral input + lowres_flow = self.lowres_net(lowres_x) + # Then pass the result through the MergeLowRes layer + merged = self.lowres_merge(primary_flow, lowres_flow) + else: + merged = primary_flow + + if ( + self.multiscale_retain_spatial_dims is False + or self.decoder_retain_spatial_dims is True + ): + return merged, merged + + if self.output_expected_shape is not None: + expected_shape = self.output_expected_shape + else: + fac = self.multiscale_lowres_size_factor + expected_shape = (merged.shape[-2] // fac, merged.shape[-1] // fac) + assert merged.shape[-2:] != expected_shape + + # Crop the resulting tensor so that it matches with the Decoder + value_to_use_in_topdown = crop_img_tensor(merged, expected_shape) + return merged, value_to_use_in_topdown + + +class MergeLayer(nn.Module): + """ + This layer merges two or more 4D input tensors by concatenating along dim=1 and passes the result through: + a) a convolutional 1x1 layer (`merge_type == "linear"`), or + b) a convolutional 1x1 layer and then a gated residual block (`merge_type == "residual"`), or + c) a convolutional 1x1 layer and then an ungated residual block (`merge_type == "residual_ungated"`). + """ + + def __init__( + self, + merge_type: Literal["linear", "residual", "residual_ungated"], + channels: Union[int, Iterable[int]], + nonlin: Callable = nn.LeakyReLU, + batchnorm: bool = True, + dropout: float = None, + res_block_type: str = None, + res_block_kernel: int = None, + res_block_skip_padding: bool = False, + conv2d_bias: bool = True, + ): + """ + Constructor. + + Parameters + ---------- + merge_type: Literal["linear", "residual", "residual_ungated"] + The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated". + Check the class docstring for more information about the behaviour of different merge modalities. + channels: Union[int, Iterable[int]] + The number of channels used in the convolutional blocks of this layer. + If it is an `int`: + - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels + - (Optional) ResBlock: in_channels=channels, out_channels=channels + If it is an Iterable (must have `len(channels)==3`): + - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1] + - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1] + nonlin: Callable, optional + The non-linearity function used in the block. Default is `nn.LeakyReLU`. + batchnorm: bool, optional + Whether to use batchnorm layers. Default is `True`. + dropout: float, optional + The dropout probability in dropout layers. If `None` dropout is not used. + Default is `None`. + res_block_type: str, optional + A string specifying the structure of residual block. + Check `ResidualBlock` doscstring for more information. + Default is `None`. + res_block_kernel: Union[int, Iterable[int]], optional + The kernel size used in the convolutions of the residual block. + It can be either a single integer or a pair of integers defining the squared kernel. + Default is `None`. + res_block_skip_padding: bool, optional + Whether to skip padding in convolutions in the Residual block. Default is `False`. + conv2d_bias: bool, optional + Whether to use bias term in convolutions. Default is `True`. + """ + super().__init__() + try: + iter(channels) + except TypeError: # it is not iterable + channels = [channels] * 3 + else: # it is iterable + if len(channels) == 1: + channels = [channels[0]] * 3 + + # assert len(channels) == 3 + + if merge_type == "linear": + self.layer = nn.Conv2d( + sum(channels[:-1]), channels[-1], 1, bias=conv2d_bias + ) + elif merge_type == "residual": + self.layer = nn.Sequential( + nn.Conv2d( + sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias + ), + ResidualGatedBlock( + channels[-1], + nonlin, + batchnorm=batchnorm, + dropout=dropout, + block_type=res_block_type, + kernel=res_block_kernel, + conv2d_bias=conv2d_bias, + skip_padding=res_block_skip_padding, + ), + ) + elif merge_type == "residual_ungated": + self.layer = nn.Sequential( + nn.Conv2d( + sum(channels[:-1]), channels[-1], 1, padding=0, bias=conv2d_bias + ), + ResidualBlock( + channels[-1], + nonlin, + batchnorm=batchnorm, + dropout=dropout, + block_type=res_block_type, + kernel=res_block_kernel, + conv2d_bias=conv2d_bias, + skip_padding=res_block_skip_padding, + ), + ) + + def forward(self, *args) -> torch.Tensor: + + # Concatenate the input tensors along dim=1 + x = torch.cat(args, dim=1) + + # Pass the concatenated tensor through the conv layer + x = self.layer(x) + + return x + + +class MergeLowRes(MergeLayer): + """ + Child class of `MergeLayer`, specifically designed to merge the low-resolution patches + that are used in Lateral Contextualization approach. + """ + + def __init__(self, *args, **kwargs): + self.retain_spatial_dims = kwargs.pop("multiscale_retain_spatial_dims") + self.multiscale_lowres_size_factor = kwargs.pop("multiscale_lowres_size_factor") + super().__init__(*args, **kwargs) + + def forward(self, latent: torch.Tensor, lowres: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + latent: torch.Tensor + The output latent tensor from previous layer in the LVAE hierarchy. + lowres: torch.Tensor + The low-res patch image to be merged to increase the context. + """ + if self.retain_spatial_dims: + # Pad latent tensor to match lowres tensor's shape + latent = pad_img_tensor(latent, lowres.shape[2:]) + else: + # Crop lowres tensor to match latent tensor's shape + lh, lw = lowres.shape[-2:] + h = lh // self.multiscale_lowres_size_factor + w = lw // self.multiscale_lowres_size_factor + h_pad = (lh - h) // 2 + w_pad = (lw - w) // 2 + lowres = lowres[:, :, h_pad:-h_pad, w_pad:-w_pad] + + return super().forward(latent, lowres) + + +class SkipConnectionMerger(MergeLayer): + """ + A specialized `MergeLayer` module, designed to handle skip connections in the model. + """ + + def __init__( + self, + nonlin: Callable, + channels: Union[int, Iterable[int]], + batchnorm: bool, + dropout: float, + res_block_type: str, + merge_type: Literal["linear", "residual", "residual_ungated"] = "residual", + conv2d_bias: bool = True, + res_block_kernel: int = None, + res_block_skip_padding: bool = False, + ): + """ + Constructor. + + nonlin: Callable, optional + The non-linearity function used in the block. Default is `nn.LeakyReLU`. + channels: Union[int, Iterable[int]] + The number of channels used in the convolutional blocks of this layer. + If it is an `int`: + - 1st 1x1 Conv2d: in_channels=2*channels, out_channels=channels + - (Optional) ResBlock: in_channels=channels, out_channels=channels + If it is an Iterable (must have `len(channels)==3`): + - 1st 1x1 Conv2d: in_channels=sum(channels[:-1]), out_channels=channels[-1] + - (Optional) ResBlock: in_channels=channels[-1], out_channels=channels[-1] + batchnorm: bool, optional + Whether to use batchnorm layers. Default is `True`. + dropout: float, optional + The dropout probability in dropout layers. If `None` dropout is not used. + Default is `None`. + res_block_type: str, optional + A string specifying the structure of residual block. + Check `ResidualBlock` doscstring for more information. + Default is `None`. + merge_type: Literal["linear", "residual", "residual_ungated"] + The type of merge done in the layer. It can be chosen between "linear", "residual", and "residual_ungated". + Check the class docstring for more information about the behaviour of different merge modalities. + conv2d_bias: bool, optional + Whether to use bias term in convolutions. Default is `True`. + res_block_kernel: Union[int, Iterable[int]], optional + The kernel size used in the convolutions of the residual block. + It can be either a single integer or a pair of integers defining the squared kernel. + Default is `None`. + res_block_skip_padding: bool, optional + Whether to skip padding in convolutions in the Residual block. Default is `False`. + """ + super().__init__( + channels=channels, + nonlin=nonlin, + merge_type=merge_type, + batchnorm=batchnorm, + dropout=dropout, + res_block_type=res_block_type, + res_block_kernel=res_block_kernel, + conv2d_bias=conv2d_bias, + res_block_skip_padding=res_block_skip_padding, + ) + + +class TopDownLayer(nn.Module): + """ + Top-down inference layer. + It includes: + - Stochastic sampling, + - Computation of KL divergence, + - A small deterministic ResNet that performs upsampling. + + NOTE 1: + The algorithm for generative inference approximately works as follows: + - p_params = output of top-down layer above + - bu = inferred bottom-up value at this layer + - q_params = merge(bu, p_params) + - z = stochastic_layer(q_params) + - (optional) get and merge skip connection from prev top-down layer + - top-down deterministic ResNet + + NOTE 2: + The Top-Down layer can work in two modes: inference and prediction/generative. + Depending on the particular mode, it follows distinct behaviours: + - In inference mode, parameters of q(z_i|z_i+1) are obtained from the inference path, + by merging outcomes of bottom-up and top-down passes. The exception is the top layer, + in which the parameters of q(z_L|x) are set as the output of the topmost bottom-up layer. + - On the contrary in prediciton/generative mode, parameters of q(z_i|z_i+1) can be obtained + once again by merging bottom-up and top-down outputs (CONDITIONAL GENERATION), or it is + possible to directly sample from the prior p(z_i|z_i+1) (UNCONDITIONAL GENERATION). + + NOTE 3: + When doing unconditional generation, bu_value is not available. Hence the + merge layer is not used, and z is sampled directly from p_params. + + NOTE 4: + If this is the top layer, at inference time, the uppermost bottom-up value + is used directly as q_params, and p_params are defined in this layer + (while they are usually taken from the previous layer), and can be learned. + """ + + def __init__( + self, + z_dim: int, + n_res_blocks: int, + n_filters: int, + is_top_layer: bool = False, + downsampling_steps: int = None, + nonlin: Callable = None, + merge_type: Literal["linear", "residual", "residual_ungated"] = None, + batchnorm: bool = True, + dropout: float = None, + stochastic_skip: bool = False, + res_block_type: str = None, + res_block_kernel: int = None, + res_block_skip_padding: bool = None, + groups: int = 1, + gated: bool = None, + learn_top_prior: bool = False, + top_prior_param_shape: Iterable[int] = None, + analytical_kl: bool = False, + bottomup_no_padding_mode: bool = False, + topdown_no_padding_mode: bool = False, + retain_spatial_dims: bool = False, + restricted_kl: bool = False, + vanilla_latent_hw: Iterable[int] = None, + non_stochastic_version: bool = False, + input_image_shape: Union[None, Tuple[int, int]] = None, + normalize_latent_factor: float = 1.0, + conv2d_bias: bool = True, + stochastic_use_naive_exponential: bool = False, + ): + """ + Constructor. + + Parameters + ---------- + z_dim: int + The size of the latent space. + n_res_blocks: int + The number of TopDownDeterministicResBlock blocks + n_filters: int + The number of channels present through out the layers of this block. + is_top_layer: bool, optional + Whether the current layer is at the top of the Decoder hierarchy. Default is `False`. + downsampling_steps: int, optional + The number of downsampling steps that has to be done in this layer (typically 1). + Default is `False`. + nonlin: Callable, optional + The non-linearity function used in the block (e.g., `nn.ReLU`). Deafault is `None`. + merge_type: Literal["linear", "residual", "residual_ungated"], optional + The type of merge done in the layer. It can be chosen between "linear", "residual", + and "residual_ungated". Check the `MergeLayer` class docstring for more information + about the behaviour of different merging modalities. Default is `None`. + batchnorm: bool, optional + Whether to use batchnorm layers. Default is `True`. + dropout: float, optional + The dropout probability in dropout layers. If `None` dropout is not used. + Default is `None`. + stochastic_skip: bool, optional + Whether to use skip connections between previous top-down layer's output and this layer's stochastic output. + Stochastic skip connection allows the previous layer's output has a way to directly reach this hierarchical + level, hence facilitating the gradient flow during backpropagation. Default is `False`. + res_block_type: str, optional + A string specifying the structure of residual block. + Check `ResidualBlock` documentation for more information. + Default is `None`. + res_block_kernel: Union[int, Iterable[int]], optional + The kernel size used in the convolutions of the residual block. + It can be either a single integer or a pair of integers defining the squared kernel. + Default is `None`. + res_block_skip_padding: bool, optional + Whether to skip padding in convolutions in the Residual block. Default is `None`. + groups: int, optional + The number of groups to consider in the convolutions. Default is 1. + gated: bool, optional + Whether to use gated layer in `ResidualBlock`. Default is `None`. + learn_top_prior: + Whether to set the top prior as learnable. + If this is set to `False`, in the top-most layer the prior will be N(0,1). + Otherwise, we will still have a normal distribution whose parameters will be learnt. + Deafult is `False`. + top_prior_param_shape: Iterable[int], optional + The size of the tensor which expresses the mean and the variance + of the prior for the top most layer. Default is `None`. + analytical_kl: bool, optional + If True, KL divergence is calculated according to the analytical formula. + Otherwise, an MC approximation using sampled latents is calculated. + Default is `False`. + bottomup_no_padding_mode: bool, optional + Whether padding is used in the different layers of the bottom-up pass. + It is meaningful to know this in advance in order to assess whether before + merging `bu_values` and `p_params` tensors any alignment is needed. + Default is `False`. + topdown_no_padding_mode: bool, optional + Whether padding is used in the different layers of the top-down pass. + It is meaningful to know this in advance in order to assess whether before + merging `bu_values` and `p_params` tensors any alignment is needed. + The same information is also needed in handling the skip connections between + top-down layers. Default is `False`. + retain_spatial_dims: bool, optional + If `True`, the size of Encoder's latent space is kept to `input_image_shape` within the topdown layer. + This implies that the oput spatial size equals the input spatial size. + To achieve this, we centercrop the intermediate representation. + Default is `False`. + restricted_kl: bool, optional + Whether to compute the restricted version of KL Divergence. + See `NormalStochasticBlock2d` module for more information about its computation. + Default is `False`. + vanilla_latent_hw: Iterable[int], optional + The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL). + Default is `None`. + non_stochastic_version: bool, optional + Whether to replace the stochastic layer that samples a latent variable from the latent distribiution with + a non-stochastic layer that simply drwas a sample as the mode of the latent distribution. + Default is `False`. + input_image_shape: Tuple[int, int], optionalut + The shape of the input image tensor. + When `retain_spatial_dims` is set to `True`, this is used to ensure that the shape of this layer + output has the same shape as the input. Default is `None`. + normalize_latent_factor: float, optional + A factor used to normalize the latent tensors `q_params`. + Specifically, normalization is done by dividing the latent tensor by this factor. + Default is 1.0. + conv2d_bias: bool, optional + Whether to use bias term is the convolutional blocks of this layer. + Default is `True`. + stochastic_use_naive_exponential: bool, optional + If `False`, in the NormalStochasticBlock2d exponentials are computed according + to the alternative definition provided by `StableExponential` class. + This should improve numerical stability in the training process. + Default is `False`. + """ + super().__init__() + + self.is_top_layer = is_top_layer + self.z_dim = z_dim + self.stochastic_skip = stochastic_skip + self.learn_top_prior = learn_top_prior + self.analytical_kl = analytical_kl + self.bottomup_no_padding_mode = bottomup_no_padding_mode + self.topdown_no_padding_mode = topdown_no_padding_mode + self.retain_spatial_dims = retain_spatial_dims + self.latent_shape = input_image_shape if self.retain_spatial_dims else None + self.non_stochastic_version = non_stochastic_version + self.normalize_latent_factor = normalize_latent_factor + self._vanilla_latent_hw = vanilla_latent_hw + + # Define top layer prior parameters, possibly learnable + if is_top_layer: + self.top_prior_params = nn.Parameter( + torch.zeros(top_prior_param_shape), requires_grad=learn_top_prior + ) + + # Downsampling steps left to do in this layer + dws_left = downsampling_steps + + # Define deterministic top-down block, which is a sequence of deterministic + # residual blocks with (optional) downsampling. + block_list = [] + for _ in range(n_res_blocks): + do_resample = False + if dws_left > 0: + do_resample = True + dws_left -= 1 + block_list.append( + TopDownDeterministicResBlock( + c_in=n_filters, + c_out=n_filters, + nonlin=nonlin, + upsample=do_resample, + batchnorm=batchnorm, + dropout=dropout, + res_block_type=res_block_type, + res_block_kernel=res_block_kernel, + skip_padding=res_block_skip_padding, + gated=gated, + conv2d_bias=conv2d_bias, + groups=groups, + ) + ) + self.deterministic_block = nn.Sequential(*block_list) + + # Define stochastic block with 2D convolutions + if self.non_stochastic_version: + self.stochastic = NonStochasticBlock2d( + c_in=n_filters, + c_vars=z_dim, + c_out=n_filters, + transform_p_params=(not is_top_layer), + groups=groups, + conv2d_bias=conv2d_bias, + ) + else: + self.stochastic = NormalStochasticBlock2d( + c_in=n_filters, + c_vars=z_dim, + c_out=n_filters, + transform_p_params=(not is_top_layer), + vanilla_latent_hw=vanilla_latent_hw, + restricted_kl=restricted_kl, + use_naive_exponential=stochastic_use_naive_exponential, + ) + + if not is_top_layer: + # Merge layer: it combines bottom-up inference and top-down + # generative outcomes to give posterior parameters + self.merge = MergeLayer( + channels=n_filters, + merge_type=merge_type, + nonlin=nonlin, + batchnorm=batchnorm, + dropout=dropout, + res_block_type=res_block_type, + res_block_kernel=res_block_kernel, + conv2d_bias=conv2d_bias, + ) + + # Skip connection that goes around the stochastic top-down layer + if stochastic_skip: + self.skip_connection_merger = SkipConnectionMerger( + channels=n_filters, + nonlin=nonlin, + batchnorm=batchnorm, + dropout=dropout, + res_block_type=res_block_type, + merge_type=merge_type, + conv2d_bias=conv2d_bias, + res_block_kernel=res_block_kernel, + res_block_skip_padding=res_block_skip_padding, + ) + + # print(f'[{self.__class__.__name__}] normalize_latent_factor:{self.normalize_latent_factor}') + + def sample_from_q( + self, + input_: torch.Tensor, + bu_value: torch.Tensor, + var_clip_max: float = None, + mask: torch.Tensor = None, + ) -> torch.Tensor: + """ + This method computes the latent inference distribution q(z_i|z_{i+1}) amd samples a latent tensor from it. + + Parameters + ---------- + input_: torch.Tensor + The input tensor to the layer, which is the output of the top-down layer above. + bu_value: torch.Tensor + The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass + at the correspondent hierarchical layer. + var_clip_max: float, optional + The maximum value reachable by the log-variance of the latent distribtion. + Values exceeding this threshold are clipped. Default is `None`. + mask: Union[None, torch.Tensor], optional + A tensor that is used to mask the sampled latent tensor. Default is `None`. + """ + if self.is_top_layer: # In top layer, we don't merge bu_value with p_params + q_params = bu_value + else: + # NOTE: Here the assumption is that the vampprior is only applied on the top layer. + n_img_prior = None + p_params = self.get_p_params(input_, n_img_prior) + q_params = self.merge(bu_value, p_params) + + sample = self.stochastic.sample_from_q(q_params, var_clip_max) + + if mask: + return sample[mask] + + return sample + + def get_p_params( + self, + input_: torch.Tensor, + n_img_prior: int, + ) -> torch.Tensor: + """ + This method returns the parameters of the prior distribution p(z_i|z_{i+1}) for the latent tensor + depending on the hierarchical level of the layer and other specific conditions. + + Parameters + ---------- + input_: torch.Tensor + The input tensor to the layer, which is the output of the top-down layer above. + n_img_prior: int + The number of images to be generated from the unconditional prior distribution p(z_L). + """ + p_params = None + + # If top layer, define p_params as the ones of the prior p(z_L) + if self.is_top_layer: + p_params = self.top_prior_params + + # Sample specific number of images by expanding the prior + if n_img_prior is not None: + p_params = p_params.expand(n_img_prior, -1, -1, -1) + + # Else the input from the layer above is p_params itself + else: + p_params = input_ + + return p_params + + def align_pparams_buvalue( + self, p_params: torch.Tensor, bu_value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + In case the padding is not used either (or both) in encoder and decoder, we could have a shape mismatch + in the spatial dimensions (usually, dim=2 & dim=3). + This method performs a centercrop to ensure that both remain aligned. + + Parameters + ---------- + p_params: torch.Tensor + The tensor defining the parameters /mu_p and /sigma_p for the latent distribution p(z_i|z_{i+1}). + bu_value: torch.Tensor + The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass + at the correspondent hierarchical layer. + """ + if bu_value.shape[-2:] != p_params.shape[-2:]: + assert self.bottomup_no_padding_mode is True + if self.topdown_no_padding_mode is False: + assert bu_value.shape[-1] > p_params.shape[-1] + bu_value = F.center_crop(bu_value, p_params.shape[-2:]) + else: + if bu_value.shape[-1] > p_params.shape[-1]: + bu_value = F.center_crop(bu_value, p_params.shape[-2:]) + else: + p_params = F.center_crop(p_params, bu_value.shape[-2:]) + return p_params, bu_value + + def forward( + self, + input_: torch.Tensor = None, + skip_connection_input: torch.Tensor = None, + inference_mode: bool = False, + bu_value: torch.Tensor = None, + n_img_prior: int = None, + forced_latent: torch.Tensor = None, + use_mode: bool = False, + force_constant_output: bool = False, + mode_pred: bool = False, + use_uncond_mode: bool = False, + var_clip_max: float = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: + """ + Parameters + ---------- + input_: torch.Tensor, optional + The input tensor to the layer, which is the output of the top-down layer above. + Default is `None`. + skip_connection_input: torch.Tensor, optional + The tensor brought by the skip connection between the current and the previous top-down layer. + Default is `None`. + inference_mode: bool, optional + Whether the layer is in inference mode. See NOTE 2 in class description for more info. + Default is `False`. + bu_value: torch.Tensor, optional + The tensor defining the parameters /mu_q and /sigma_q computed during the bottom-up deterministic pass + at the correspondent hierarchical layer. Default is `None`. + n_img_prior: int, optional + The number of images to be generated from the unconditional prior distribution p(z_L). + Default is `None`. + forced_latent: torch.Tensor, optional + A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent tensor and, + hence, sampling does not happen. Default is `None`. + use_mode: bool, optional + Wheteher the latent tensor should be set as the latent distribution mode. + In the case of Gaussian, the mode coincides with the mean of the distribution. + Default is `False`. + force_constant_output: bool, optional + Whether to copy the first sample (and rel. distrib parameters) over the whole batch. + This is used when doing experiment from the prior - q is not used. + Default is `False`. + mode_pred: bool, optional + Whether the model is in prediction mode. Default is `False`. + use_uncond_mode: bool, optional + Whether to use the uncoditional distribution p(z) to sample latents in prediction mode. + var_clip_max: float + The maximum value reachable by the log-variance of the latent distribtion. + Values exceeding this threshold are clipped. + """ + # Check consistency of arguments + inputs_none = input_ is None and skip_connection_input is None + if self.is_top_layer and not inputs_none: + raise ValueError("In top layer, inputs should be None") + + p_params = self.get_p_params(input_, n_img_prior) + + # Get the parameters for the latent distribution to sample from + if inference_mode: + if self.is_top_layer: + q_params = bu_value + if mode_pred is False: + p_params, bu_value = self.align_pparams_buvalue(p_params, bu_value) + else: + if use_uncond_mode: + q_params = p_params + else: + p_params, bu_value = self.align_pparams_buvalue(p_params, bu_value) + q_params = self.merge(bu_value, p_params) + # In generative mode, q is not used + else: + q_params = None + + # NOTE: Sampling is done either from q(z_i | z_{i+1}, x) or p(z_i | z_{i+1}) + # depending on the mode (hence, in practice, by checking whether q_params is None). + + # Normalization of latent space parameters: + # it is done, purely for stablity. See Very deep VAEs generalize autoregressive models. + if self.normalize_latent_factor: + q_params = q_params / self.normalize_latent_factor + + # Sample (and process) a latent tensor in the stochastic layer + x, data_stoch = self.stochastic( + p_params=p_params, + q_params=q_params, + forced_latent=forced_latent, + use_mode=use_mode, + force_constant_output=force_constant_output, + analytical_kl=self.analytical_kl, + mode_pred=mode_pred, + use_uncond_mode=use_uncond_mode, + var_clip_max=var_clip_max, + ) + + # Merge skip connection from previous layer + if self.stochastic_skip and not self.is_top_layer: + if self.topdown_no_padding_mode is True: + # If no padding is done in the current top-down pass, there may be a shape mismatch between current tensor and skip connection input. + # As an example, if the output of last TopDownLayer was of size 64*64, due to lack of padding in the current layer, the current tensor + # might become different in shape, say 60*60. + # In order to avoid shape mismatch, we do central crop of the skip connection input. + skip_connection_input = F.center_crop( + skip_connection_input, x.shape[-2:] + ) + + x = self.skip_connection_merger(x, skip_connection_input) + + # Save activation before residual block as it can be the skip connection input in the next layer + x_pre_residual = x + + if self.retain_spatial_dims: + # when we don't want to do padding in topdown as well, we need to spare some boundary pixels which would be used up. + extra_len = (self.topdown_no_padding_mode is True) * 3 + + # this means that x should be of the same size as config.data.image_size. So, we have to centercrop by a factor of 2 at this point. + # assert x.shape[-1] >= self.latent_shape[-1] // 2 + extra_len + # we assume that one topdown layer will have exactly one upscaling layer. + new_latent_shape = ( + self.latent_shape[0] // 2 + extra_len, + self.latent_shape[1] // 2 + extra_len, + ) + + # If the LC is not applied on all layers, then this can happen. + if x.shape[-1] > new_latent_shape[-1]: + x = F.center_crop(x, new_latent_shape) + + # Last top-down block (sequence of residual blocks) + x = self.deterministic_block(x) + + if self.topdown_no_padding_mode: + x = F.center_crop(x, self.latent_shape) + + # Save some metrics that will be used in the loss computation + keys = [ + "z", + "kl_samplewise", + "kl_samplewise_restricted", + "kl_spatial", + "kl_channelwise", + # 'logprob_p', + "logprob_q", + "qvar_max", + ] + data = {k: data_stoch.get(k, None) for k in keys} + data["q_mu"] = None + data["q_lv"] = None + if data_stoch["q_params"] is not None: + q_mu, q_lv = data_stoch["q_params"] + data["q_mu"] = q_mu + data["q_lv"] = q_lv + + return x, x_pre_residual, data + + +class NormalStochasticBlock2d(nn.Module): + """ + Stochastic block used in the Top-Down inference pass. + + Algorithm: + - map input parameters to q(z) and (optionally) p(z) via convolution + - sample a latent tensor z ~ q(z) + - feed z to convolution and return. + + NOTE 1: + If parameters for q are not given, sampling is done from p(z). + + NOTE 2: + The restricted KL divergence is obtained by first computing the element-wise KL divergence + (i.e., the KL computed for each element of the latent tensors). Then, the restricted version + is computed by summing over the channels and the spatial dimensions associated only to the + portion of the latent tensor that is used for prediction. + """ + + def __init__( + self, + c_in: int, + c_vars: int, + c_out: int, + kernel: int = 3, + transform_p_params: bool = True, + vanilla_latent_hw: int = None, + restricted_kl: bool = False, + use_naive_exponential: bool = False, + ): + """ + Parameters + ---------- + c_in: int + The number of channels of the input tensor. + c_vars: int + The number of channels of the latent space tensor. + c_out: int + The output of the stochastic layer. + Note that this is different from the sampled latent z. + kernel: int, optional + The size of the kernel used in convolutional layers. + Default is 3. + transform_p_params: bool, optional + Whether a transformation should be applied to the `p_params` tensor. + The transformation consists in a 2D convolution ()`conv_in_p()`) that + maps the input to a larger number of channels. + Default is `True`. + vanilla_latent_hw: int, optional + The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL). + Default is `None`. + restricted_kl: bool, optional + Whether to compute the restricted version of KL Divergence. + See NOTE 2 for more information about its computation. + Default is `False`. + use_naive_exponential: bool, optional + If `False`, exponentials are computed according to the alternative definition + provided by `StableExponential` class. This should improve numerical stability + in the training process. Default is `False`. + """ + super().__init__() + assert kernel % 2 == 1 + pad = kernel // 2 + self.transform_p_params = transform_p_params + self.c_in = c_in + self.c_out = c_out + self.c_vars = c_vars + self._use_naive_exponential = use_naive_exponential + self._vanilla_latent_hw = vanilla_latent_hw + self._restricted_kl = restricted_kl + + if transform_p_params: + self.conv_in_p = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad) + self.conv_in_q = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad) + self.conv_out = nn.Conv2d(c_vars, c_out, kernel, padding=pad) + + # def forward_swapped(self, p_params, q_mu, q_lv): + # + # if self.transform_p_params: + # p_params = self.conv_in_p(p_params) + # else: + # assert p_params.size(1) == 2 * self.c_vars + # + # # Define p(z) + # p_mu, p_lv = p_params.chunk(2, dim=1) + # p = Normal(p_mu, (p_lv / 2).exp()) + # + # # Define q(z) + # q = Normal(q_mu, (q_lv / 2).exp()) + # # Sample from q(z) + # sampling_distrib = q + # + # # Generate latent variable (typically by sampling) + # z = sampling_distrib.rsample() + # + # # Output of stochastic layer + # out = self.conv_out(z) + # + # data = { + # 'z': z, # sampled variable at this layer (batch, ch, h, w) + # 'p_params': p_params, # (b, ch, h, w) where b is 1 or batch size + # } + # return out, data + + def get_z( + self, + sampling_distrib: torch.distributions.normal.Normal, + forced_latent: torch.Tensor, + use_mode: bool, + mode_pred: bool, + use_uncond_mode: bool, + ) -> torch.Tensor: + """ + This method enables to sample a latent tensor given the distribution to sample from. + + Latent variable can be obtained is several ways: + - Sampled from the (Gaussian) latent distribution. + - Taken as a pre-defined forced latent. + - Taken as the mode (mean) of the latent distribution. + - In prediction mode (`mode_pred==True`), can be either sample or taken as the distribution mode. + + Parameters + ---------- + sampling_distrib: torch.distributions.normal.Normal + The Gaussian distribution from which latent tensor is sampled. + forced_latent: torch.Tensor + A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent tensor and, + hence, sampling does not happen. + use_mode: bool + Wheteher the latent tensor should be set as the latent distribution mode. + In the case of Gaussian, the mode coincides with the mean of the distribution. + mode_pred: bool + Whether the model is prediction mode. + use_uncond_mode: bool + Whether to use the uncoditional distribution p(z) to sample latents in prediction mode. + """ + if forced_latent is None: + if use_mode: + z = sampling_distrib.mean + else: + if mode_pred: + if use_uncond_mode: + z = sampling_distrib.mean + else: + z = sampling_distrib.rsample() + else: + z = sampling_distrib.rsample() + else: + z = forced_latent + return z + + def sample_from_q( + self, q_params: torch.Tensor, var_clip_max: float + ) -> torch.Tensor: + """ + Given an input parameter tensor defining q(z), + it processes it by calling `process_q_params()` method and + sample a latent tensor from the resulting distribution. + + Parameters + ---------- + q_params: torch.Tensor + The input tensor to be processed. + var_clip_max: float + The maximum value reachable by the log-variance of the latent distribtion. + Values exceeding this threshold are clipped. + """ + _, _, q = self.process_q_params(q_params, var_clip_max) + return q.rsample() + + def compute_kl_metrics( + self, + p: torch.distributions.normal.Normal, + p_params: torch.Tensor, + q: torch.distributions.normal.Normal, + q_params: torch.Tensor, + mode_pred: bool, + analytical_kl: bool, + z: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """ + Compute KL (analytical or MC estimate) and then process it, extracting composed versions of the metric. + Specifically, the different versions of the KL loss terms are: + - `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)]. + - `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )]. + - `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is + used for prediction and summed over channel and spatial dimensions [Shape: (batch, )]. + - `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )]. + - `kl_spatial`: KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)] + + Parameters + ---------- + p: torch.distributions.normal.Normal + The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)). + p_params: torch.Tensor + The parameters of the prior generative distribution. + q: torch.distributions.normal.Normal + The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)). + q_params: torch.Tensor + The parameters of the inference distribution. + mode_pred: bool + Whether the model is in prediction mode. + analytical_kl: bool + Whether to compute the KL divergence analytically or using Monte Carlo estimation. + z: torch.Tensor + The sampled latent tensor. + """ + kl_samplewise_restricted = None + + if mode_pred is False: # if not in prediction mode + # KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)] + if analytical_kl: + kl_elementwise = kl_divergence(q, p) + else: + kl_elementwise = kl_normal_mc(z, p_params, q_params) + + # KL term only associated to the portion of the latent tensor that is used for prediction and + # summed over channel and spatial dimensions. [Shape: (batch, )] + # NOTE: vanilla_latent_hw is the shape of the latent tensor used for prediction, hence + # the restriction has shape [Shape: (batch, ch, vanilla_latent_hw[0], vanilla_latent_hw[1])] + if self._restricted_kl: + pad = (kl_elementwise.shape[-1] - self._vanilla_latent_hw) // 2 + assert pad > 0, "Disable restricted kl since there is no restriction." + tmp = kl_elementwise[..., pad:-pad, pad:-pad] + kl_samplewise_restricted = tmp.sum((1, 2, 3)) + + # KL term associated to each sample in the batch [Shape: (batch, )] + kl_samplewise = kl_elementwise.sum((1, 2, 3)) + + # KL term associated to each sample and each channel [Shape: (batch, ch, )] + kl_channelwise = kl_elementwise.sum((2, 3)) + + # KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)] + kl_spatial = kl_elementwise.sum(1) + else: # if predicting, no need to compute KL + kl_elementwise = kl_samplewise = kl_spatial = kl_channelwise = None + + kl_dict = { + "kl_elementwise": kl_elementwise, # (batch, ch, h, w) + "kl_samplewise": kl_samplewise, # (batch, ) + "kl_samplewise_restricted": kl_samplewise_restricted, # (batch, ) + "kl_channelwise": kl_channelwise, # (batch, ch) + "kl_spatial": kl_spatial, # (batch, h, w) + } + return kl_dict + + def process_p_params( + self, p_params: torch.Tensor, var_clip_max: float + ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]: + """ + Process the input parameters to get the prior distribution p(z_i|z_{i+1}) (or p(z_L)). + + Processing consists in: + - (optionally) 2D convolution on the input tensor to increase number of channels. + - split the resulting tensor into two chunks, the mean and the log-variance. + - (optionally) clip the log-variance to an upper threshold. + - define the normal distribution p(z) given the parameter tensors above. + + Parameters + ---------- + p_params: torch.Tensor + The input tensor to be processed. + var_clip_max: float + The maximum value reachable by the log-variance of the latent distribtion. + Values exceeding this threshold are clipped. + """ + if self.transform_p_params: + p_params = self.conv_in_p(p_params) + else: + assert p_params.size(1) == 2 * self.c_vars + + # Define p(z) + p_mu, p_lv = p_params.chunk(2, dim=1) + if var_clip_max is not None: + p_lv = torch.clip(p_lv, max=var_clip_max) + + p_mu = StableMean(p_mu) + p_lv = StableLogVar(p_lv, enable_stable=not self._use_naive_exponential) + p = Normal(p_mu.get(), p_lv.get_std()) + return p_mu, p_lv, p + + def process_q_params( + self, q_params: torch.Tensor, var_clip_max: float, allow_oddsizes: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]: + """ + Process the input parameters to get the inference distribution q(z_i|z_{i+1}) (or q(z|x)). + + Processing consists in: + - 2D convolution on the input tensor to increase number of channels. + - split the resulting tensor into two chunks, the mean and the log-variance. + - (optionally) clip the log-variance to an upper threshold. + - (optionally) crop the resulting tensors to ensure that the last spatial dimension is even. + - define the normal distribution q(z) given the parameter tensors above. + + Parameters + ---------- + p_params: torch.Tensor + The input tensor to be processed. + var_clip_max: float + The maximum value reachable by the log-variance of the latent distribtion. + Values exceeding this threshold are clipped. + """ + q_params = self.conv_in_q(q_params) + + q_mu, q_lv = q_params.chunk(2, dim=1) + if var_clip_max is not None: + q_lv = torch.clip(q_lv, max=var_clip_max) + + if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False: + q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1) + q_lv = F.center_crop(q_lv, q_lv.shape[-1] - 1) + # clip_start = np.random.rand() > 0.5 + # q_mu = q_mu[:, :, 1:, 1:] if clip_start else q_mu[:, :, :-1, :-1] + # q_lv = q_lv[:, :, 1:, 1:] if clip_start else q_lv[:, :, :-1, :-1] + + q_mu = StableMean(q_mu) + q_lv = StableLogVar(q_lv, enable_stable=not self._use_naive_exponential) + q = Normal(q_mu.get(), q_lv.get_std()) + return q_mu, q_lv, q + + def forward( + self, + p_params: torch.Tensor, + q_params: torch.Tensor = None, + forced_latent: torch.Tensor = None, + use_mode: bool = False, + force_constant_output: bool = False, + analytical_kl: bool = False, + mode_pred: bool = False, + use_uncond_mode: bool = False, + var_clip_max: float = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Parameters + ---------- + p_params: torch.Tensor + The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}). + q_params: torch.Tensor, optional + The tensor resulting from merging the bu_value tensor at the same hierarchical level + from the bottom-up pass and the `p_params` tensor. Default is `None`. + forced_latent: torch.Tensor, optional + A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent + tensor and, hence, sampling does not happen. Default is `None`. + use_mode: bool, optional + Wheteher the latent tensor should be set as the latent distribution mode. + In the case of Gaussian, the mode coincides with the mean of the distribution. + Default is `False`. + force_constant_output: bool, optional + Whether to copy the first sample (and rel. distrib parameters) over the whole batch. + This is used when doing experiment from the prior - q is not used. + Default is `False`. + analytical_kl: bool, optional + Whether to compute the KL divergence analytically or using Monte Carlo estimation. + Default is `False`. + mode_pred: bool, optional + Whether the model is in prediction mode. Default is `False`. + use_uncond_mode: bool, optional + Whether to use the uncoditional distribution p(z) to sample latents in prediction mode. + Default is `False`. + var_clip_max: float, optional + The maximum value reachable by the log-variance of the latent distribtion. + Values exceeding this threshold are clipped. Default is `None`. + """ + debug_qvar_max = 0 + + # Check sampling options consistency + assert (forced_latent is None) or (not use_mode) + + # Get generative distribution p(z_i|z_{i+1}) + p_mu, p_lv, p = self.process_p_params(p_params, var_clip_max) + p_params = (p_mu, p_lv) + + if q_params is not None: + # Get inference distribution q(z_i|z_{i+1}) + # NOTE: At inference time, don't centercrop the q_params even if they are odd in size. + q_mu, q_lv, q = self.process_q_params( + q_params, var_clip_max, allow_oddsizes=mode_pred is True + ) + q_params = (q_mu, q_lv) + sampling_distrib = q + debug_qvar_max = torch.max(q_lv.get()) + + # Centercrop p_params so that their size matches the one of q_params + q_size = q_mu.get().shape[-1] + if p_mu.get().shape[-1] != q_size and mode_pred is False: + p_mu.centercrop_to_size(q_size) + p_lv.centercrop_to_size(q_size) + else: + sampling_distrib = p + + # Sample latent variable + z = self.get_z( + sampling_distrib, forced_latent, use_mode, mode_pred, use_uncond_mode + ) + + # Copy one sample (and distrib parameters) over the whole batch. + # This is used when doing experiment from the prior - q is not used. + if force_constant_output: + z = z[0:1].expand_as(z).clone() + p_params = ( + p_params[0][0:1].expand_as(p_params[0]).clone(), + p_params[1][0:1].expand_as(p_params[1]).clone(), + ) + + # Pass the sampled latent througn the output convolutional layer of stochastic block + out = self.conv_out(z) + + # Compute log p(z)# NOTE: disabling its computation. + # if mode_pred is False: + # logprob_p = p.log_prob(z).sum((1, 2, 3)) + # else: + # logprob_p = None + + if q_params is not None: + # Compute log q(z) + logprob_q = q.log_prob(z).sum((1, 2, 3)) + # Compute KL divergence metrics + kl_dict = self.compute_kl_metrics( + p, p_params, q, q_params, mode_pred, analytical_kl, z + ) + else: + kl_dict = {} + logprob_q = None + + # Store meaningful quantities to use them in following layers + data = kl_dict + data["z"] = z # sampled variable at this layer (batch, ch, h, w) + data["p_params"] = p_params # (b, ch, h, w) where b is 1 or batch size + data["q_params"] = q_params # (batch, ch, h, w) + # data['logprob_p'] = logprob_p # (batch, ) + data["logprob_q"] = logprob_q # (batch, ) + data["qvar_max"] = debug_qvar_max + + return out, data + + +class NonStochasticBlock2d(nn.Module): + """ + Non-stochastic version of the NormalStochasticBlock2d. + """ + + def __init__( + self, + c_vars: int, + c_in: int, + c_out: int, + kernel: int = 3, + groups: int = 1, + conv2d_bias: bool = True, + transform_p_params: bool = True, + ): + """ + Constructor. + + Parameters + ---------- + c_vars: int + The number of channels of the latent space tensor. + c_in: int + The number of channels of the input tensor. + c_out: int + The output of the stochastic layer. + Note that this is different from the sampled latent z. + kernel: int, optional + The size of the kernel used in convolutional layers. + Default is 3. + groups: int, optional + The number of groups to consider in the convolutions of this layer. + Default is 1. + conv2d_bias: bool, optional + Whether to use bias term is the convolutional blocks of this layer. + Default is `True`. + transform_p_params: bool, optional + Whether a transformation should be applied to the `p_params` tensor. + The transformation consists in a 2D convolution ()`conv_in_p()`) that + maps the input to a larger number of channels. + Default is `True`. + """ + super().__init__() + assert kernel % 2 == 1 + pad = kernel // 2 + self.transform_p_params = transform_p_params + self.c_in = c_in + self.c_out = c_out + self.c_vars = c_vars + + if transform_p_params: + self.conv_in_p = nn.Conv2d( + c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups + ) + self.conv_in_q = nn.Conv2d( + c_in, 2 * c_vars, kernel, padding=pad, bias=conv2d_bias, groups=groups + ) + self.conv_out = nn.Conv2d( + c_vars, c_out, kernel, padding=pad, bias=conv2d_bias, groups=groups + ) + + def compute_kl_metrics( + self, + p: torch.distributions.normal.Normal, + p_params: torch.Tensor, + q: torch.distributions.normal.Normal, + q_params: torch.Tensor, + mode_pred: bool, + analytical_kl: bool, + z: torch.Tensor, + ) -> Dict[str, None]: + """ + Compute KL (analytical or MC estimate) and then process it, extracting composed versions of the metric. + Specifically, the different versions of the KL loss terms are: + - `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)]. + - `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )]. + - `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is + used for prediction and summed over channel and spatial dimensions [Shape: (batch, )]. + - `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )]. + - `kl_spatial`: # KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)] + + NOTE: in this class all the KL metrics are set to `None`. + + Parameters + ---------- + p: torch.distributions.normal.Normal + The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)). + p_params: torch.Tensor + The parameters of the prior generative distribution. + q: torch.distributions.normal.Normal + The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)). + q_params: torch.Tensor + The parameters of the inference distribution. + mode_pred: bool + Whether the model is in prediction mode. + analytical_kl: bool + Whether to compute the KL divergence analytically or using Monte Carlo estimation. + z: torch.Tensor + The sampled latent tensor. + """ + kl_dict = { + "kl_elementwise": None, # (batch, ch, h, w) + "kl_samplewise": None, # (batch, ) + "kl_spatial": None, # (batch, h, w) + "kl_channelwise": None, # (batch, ch) + } + return kl_dict + + def process_p_params(self, p_params, var_clip_max): + if self.transform_p_params: + p_params = self.conv_in_p(p_params) + else: + + assert ( + p_params.size(1) == 2 * self.c_vars + ), f"{p_params.shape} {self.c_vars}" + + # Define p(z) + p_mu, p_lv = p_params.chunk(2, dim=1) + return p_mu, None + + def process_q_params(self, q_params, var_clip_max, allow_oddsizes=False): + # Define q(z) + q_params = self.conv_in_q(q_params) + q_mu, q_lv = q_params.chunk(2, dim=1) + + if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False: + q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1) + + return q_mu, None + + def forward( + self, + p_params: torch.Tensor, + q_params: torch.Tensor = None, + forced_latent: Union[None, torch.Tensor] = None, + use_mode: bool = False, + force_constant_output: bool = False, + analytical_kl: bool = False, + mode_pred: bool = False, + use_uncond_mode: bool = False, + var_clip_max: float = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Parameters + ---------- + p_params: torch.Tensor + The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}). + q_params: torch.Tensor, optional + The tensor resulting from merging the bu_value tensor at the same hierarchical level + from the bottom-up pass and the `p_params` tensor. Default is `None`. + forced_latent: torch.Tensor, optional + A pre-defined latent tensor. If it is not `None`, than it is used as the actual latent + tensor and, hence, sampling does not happen. Default is `None`. + use_mode: bool, optional + Wheteher the latent tensor should be set as the latent distribution mode. + In the case of Gaussian, the mode coincides with the mean of the distribution. + Default is `False`. + force_constant_output: bool, optional + Whether to copy the first sample (and rel. distrib parameters) over the whole batch. + This is used when doing experiment from the prior - q is not used. + Default is `False`. + analytical_kl: bool, optional + Whether to compute the KL divergence analytically or using Monte Carlo estimation. + Default is `False`. + mode_pred: bool, optional + Whether the model is in prediction mode. Default is `False`. + use_uncond_mode: bool, optional + Whether to use the uncoditional distribution p(z) to sample latents in prediction mode. + Default is `False`. + var_clip_max: float, optional + The maximum value reachable by the log-variance of the latent distribtion. + Values exceeding this threshold are clipped. Default is `None`. + """ + debug_qvar_max = 0 + assert (forced_latent is None) or (not use_mode) + + p_mu, _ = self.process_p_params(p_params, var_clip_max) + + p_params = (p_mu, None) + + if q_params is not None: + # At inference time, just don't centercrop the q_params even if they are odd in size. + q_mu, _ = self.process_q_params( + q_params, var_clip_max, allow_oddsizes=mode_pred is True + ) + q_params = (q_mu, None) + debug_qvar_max = torch.Tensor([1]).to(q_mu.device) + # Sample from q(z) + sampling_distrib = q_mu + q_size = q_mu.shape[-1] + if p_mu.shape[-1] != q_size and mode_pred is False: + p_mu.centercrop_to_size(q_size) + else: + # Sample from p(z) + sampling_distrib = p_mu + + # Generate latent variable (typically by sampling) + z = sampling_distrib + + # Copy one sample (and distrib parameters) over the whole batch. + # This is used when doing experiment from the prior - q is not used. + if force_constant_output: + z = z[0:1].expand_as(z).clone() + p_params = ( + p_params[0][0:1].expand_as(p_params[0]).clone(), + p_params[1][0:1].expand_as(p_params[1]).clone(), + ) + + # Output of stochastic layer + out = self.conv_out(z) + + kl_dict = {} + logprob_q = None + + data = kl_dict + data["z"] = z # sampled variable at this layer (batch, ch, h, w) + data["p_params"] = p_params # (b, ch, h, w) where b is 1 or batch size + data["q_params"] = q_params # (batch, ch, h, w) + data["logprob_q"] = logprob_q # (batch, ) + data["qvar_max"] = debug_qvar_max + + return out, data diff --git a/src/careamics/models/lvae/likelihoods.py b/src/careamics/models/lvae/likelihoods.py new file mode 100644 index 000000000..0ac55efba --- /dev/null +++ b/src/careamics/models/lvae/likelihoods.py @@ -0,0 +1,312 @@ +""" +Script containing modules for definining different likelihood functions (as nn.Module). +""" + +import math +from typing import Dict, Literal, Tuple, Union + +import numpy as np +import torch +from torch import nn + + +class LikelihoodModule(nn.Module): + """ + The base class for all likelihood modules. + It defines the fundamental structure and methods for specialized likelihood models. + """ + + def distr_params(self, x): + return None + + def set_params_to_same_device_as(self, correct_device_tensor): + pass + + @staticmethod + def logvar(params): + return None + + @staticmethod + def mean(params): + return None + + @staticmethod + def mode(params): + return None + + @staticmethod + def sample(params): + return None + + def log_likelihood(self, x, params): + return None + + def forward( + self, input_: torch.Tensor, x: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + + distr_params = self.distr_params(input_) + mean = self.mean(distr_params) + mode = self.mode(distr_params) + sample = self.sample(distr_params) + logvar = self.logvar(distr_params) + + if x is None: + ll = None + else: + ll = self.log_likelihood(x, distr_params) + + dct = { + "mean": mean, + "mode": mode, + "sample": sample, + "params": distr_params, + "logvar": logvar, + } + + return ll, dct + + +class GaussianLikelihood(LikelihoodModule): + r""" + A specialize `LikelihoodModule` for Gaussian likelihood. + + Specifically, in the LVAE model, the likelihood is defined as: + p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2) + """ + + def __init__( + self, + ch_in: int, + color_channels: int, + predict_logvar: Literal[None, "pixelwise", "global", "channelwise"] = None, + logvar_lowerbound: float = None, + conv2d_bias: bool = True, + ): + """ + Constructor. + + Parameters + ---------- + predict_logvar: Literal[None, 'global', 'pixelwise', 'channelwise'], optional + If not `None`, it expresses how to compute the log-variance. + Namely: + - if `pixelwise`, log-variance is computed for each pixel. + - if `global`, log-variance is computed as the mean of all pixel-wise entries. + - if `channelwise`, log-variance is computed as the average over the channels. + Default is `None`. + logvar_lowerbound: float, optional + The lowerbound value for log-variance. Default is `None`. + conv2d_bias: bool, optional + Whether to use bias term in convolutions. Default is `True`. + """ + super().__init__() + + # If True, then we also predict pixelwise logvar. + self.predict_logvar = predict_logvar + self.logvar_lowerbound = logvar_lowerbound + self.conv2d_bias = conv2d_bias + assert self.predict_logvar in [None, "global", "pixelwise", "channelwise"] + + # logvar_ch_needed = self.predict_logvar is not None + # self.parameter_net = nn.Conv2d(ch_in, + # color_channels * (1 + logvar_ch_needed), + # kernel_size=3, + # padding=1, + # bias=self.conv2d_bias) + self.parameter_net = nn.Identity() + + print( + f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}" + ) + + def get_mean_lv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Given the output of the top-down pass, compute the mean and log-variance of the + Gaussian distribution defining the likelihood. + + Parameters + ---------- + x: torch.Tensor + The input tensor to the likelihood module, i.e., the output of the top-down pass. + """ + # Feed the output of the top-down pass to a parameter network + # This network can be either a Conv2d or Identity module + x = self.parameter_net(x) + + if self.predict_logvar is not None: + # Get pixel-wise mean and logvar + mean, lv = x.chunk(2, dim=1) + + # Optionally, compute the global or channel-wise logvar + if self.predict_logvar in ["channelwise", "global"]: + if self.predict_logvar == "channelwise": + # logvar should be of the following shape (batch, num_channels, ). Other dims would be singletons. + N = np.prod(lv.shape[:2]) + new_shape = (*mean.shape[:2], *([1] * len(mean.shape[2:]))) + elif self.predict_logvar == "global": + # logvar should be of the following shape (batch, ). Other dims would be singletons. + N = lv.shape[0] + new_shape = (*mean.shape[:1], *([1] * len(mean.shape[1:]))) + else: + raise ValueError( + f"Invalid value for self.predict_logvar:{self.predict_logvar}" + ) + + lv = torch.mean(lv.reshape(N, -1), dim=1) + lv = lv.reshape(new_shape) + + # Optionally, clip log-var to a lower bound + if self.logvar_lowerbound is not None: + lv = torch.clip(lv, min=self.logvar_lowerbound) + else: + mean = x + lv = None + return mean, lv + + def distr_params(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + """ + Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood. + + Parameters + ---------- + x: torch.Tensor + The input tensor to the likelihood module, i.e., the output of the top-down pass. + """ + mean, lv = self.get_mean_lv(x) + params = { + "mean": mean, + "logvar": lv, + } + return params + + @staticmethod + def mean(params): + return params["mean"] + + @staticmethod + def mode(params): + return params["mean"] + + @staticmethod + def sample(params): + # p = Normal(params['mean'], (params['logvar'] / 2).exp()) + # return p.rsample() + return params["mean"] + + @staticmethod + def logvar(params): + return params["logvar"] + + def log_likelihood(self, x, params): + if self.predict_logvar is not None: + logprob = log_normal(x, params["mean"], params["logvar"]) + else: + logprob = -0.5 * (params["mean"] - x) ** 2 + return logprob + + +def log_normal( + x: torch.Tensor, mean: torch.Tensor, logvar: torch.Tensor +) -> torch.Tensor: + """ + Compute the log-probability at `x` of a Gaussian distribution + with parameters `(mean, exp(logvar))`. + + NOTE: In the case of LVAE, the log-likeihood formula becomes: + \\mathbb{E}_{z_1\\sim{q_\\phi}}[\\log{p_\theta(x|z_1)}]=-\frac{1}{2}(\\mathbb{E}_{z_1\\sim{q_\\phi}}[\\log{2\\pi\\sigma_{p,0}^2(z_1)}] +\\mathbb{E}_{z_1\\sim{q_\\phi}}[\frac{(x-\\mu_{p,0}(z_1))^2}{\\sigma_{p,0}^2(z_1)}]) + + Parameters + ---------- + x: torch.Tensor + The ground-truth tensor. Shape is (batch, channels, dim1, dim2). + mean: torch.Tensor + The inferred mean of distribution. Shape is (batch, channels, dim1, dim2). + logvar: torch.Tensor + The inferred log-variance of distribution. Shape has to be either scalar or broadcastable. + """ + var = torch.exp(logvar) + log_prob = -0.5 * ( + ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log() + ) + return log_prob + + +class NoiseModelLikelihood(LikelihoodModule): + + def __init__( + self, + ch_in: int, + color_channels: int, + data_mean: Union[Dict[str, torch.Tensor], torch.Tensor], + data_std: Union[Dict[str, torch.Tensor], torch.Tensor], + noiseModel: nn.Module, + ): + super().__init__() + self.parameter_net = ( + nn.Identity() + ) # nn.Conv2d(ch_in, color_channels, kernel_size=3, padding=1) + self.data_mean = data_mean + self.data_std = data_std + self.noiseModel = noiseModel + + def set_params_to_same_device_as(self, correct_device_tensor): + if isinstance(self.data_mean, torch.Tensor): + if self.data_mean.device != correct_device_tensor.device: + self.data_mean = self.data_mean.to(correct_device_tensor.device) + self.data_std = self.data_std.to(correct_device_tensor.device) + elif isinstance(self.data_mean, dict): + for key in self.data_mean.keys(): + self.data_mean[key] = self.data_mean[key].to( + correct_device_tensor.device + ) + self.data_std[key] = self.data_std[key].to(correct_device_tensor.device) + + def get_mean_lv(self, x): + return self.parameter_net(x), None + + def distr_params(self, x): + mean, lv = self.get_mean_lv(x) + # mean, lv = x.chunk(2, dim=1) + + params = { + "mean": mean, + "logvar": lv, + } + return params + + @staticmethod + def mean(params): + return params["mean"] + + @staticmethod + def mode(params): + return params["mean"] + + @staticmethod + def sample(params): + # p = Normal(params['mean'], (params['logvar'] / 2).exp()) + # return p.rsample() + return params["mean"] + + def log_likelihood(self, x: torch.Tensor, params: Dict[str, torch.Tensor]): + """ + Compute the log-likelihood given the parameters `params` obtained from the reconstruction tensor and the target tensor `x`. + """ + predicted_s_denormalized = ( + params["mean"] * self.data_std["target"] + self.data_mean["target"] + ) + x_denormalized = x * self.data_std["target"] + self.data_mean["target"] + # predicted_s_cloned = predicted_s_denormalized + # predicted_s_reduced = predicted_s_cloned.permute(1, 0, 2, 3) + + # x_cloned = x_denormalized + # x_cloned = x_cloned.permute(1, 0, 2, 3) + # x_reduced = x_cloned[0, ...] + # import pdb;pdb.set_trace() + likelihoods = self.noiseModel.likelihood( + x_denormalized, predicted_s_denormalized + ) + # likelihoods = self.noiseModel.likelihood(x, params['mean']) + logprob = torch.log(likelihoods) + return logprob diff --git a/src/careamics/models/lvae/lvae.py b/src/careamics/models/lvae/lvae.py new file mode 100644 index 000000000..f7d86a036 --- /dev/null +++ b/src/careamics/models/lvae/lvae.py @@ -0,0 +1,985 @@ +""" +Ladder VAE (LVAE) Model + +The current implementation is based on "Interpretable Unsupervised Diversity Denoising and Artefact Removal, Prakash et al." +""" + +from typing import Dict, Iterable, List, Tuple, Union + +import ml_collections +import numpy as np +import torch +import torch.nn as nn + +from .layers import ( + BottomUpDeterministicResBlock, + BottomUpLayer, + TopDownDeterministicResBlock, + TopDownLayer, +) +from .likelihoods import GaussianLikelihood, NoiseModelLikelihood +from .noise_models import get_noise_model +from .utils import Interpolate, LossType, ModelType, crop_img_tensor, pad_img_tensor + + +class LadderVAE(nn.Module): + + def __init__( + self, + data_mean: Union[np.ndarray, Dict[str, torch.Tensor]], + data_std: Union[np.ndarray, Dict[str, torch.Tensor]], + config: ml_collections.ConfigDict, + use_uncond_mode_at: Iterable[int] = [], + target_ch: int = 2, + ): + """ + Constructor. + + Parameters + ---------- + data_mean: Union[np.ndarray, Dict[str, torch.Tensor]] + The mean of the data used for normalization. + data_std: Union[np.ndarray, Dict[str, torch.Tensor]] + The standard deviation of the data used for normalization. + config: ml_collections.ConfigDict + The configuration object of the model. + use_uncond_mode_at: Iterable[int], optional + A sequence of indexes associated to the layers in which sampling is disabled + and the mode (mean value) is used instead. Default is `[]`. + target_ch: int, optional + The number of target channels (e.g., 1 for super-resolution or 2 for splitting). + Default is `2`. + """ + super().__init__() + + # ------------------------------------------------------- + # Customizable attributes + self.image_size = config.data.image_size + self._multiscale_count = config.data.multiscale_lowres_count + self.z_dims = config.model.z_dims + self.encoder_n_filters = config.model.n_filters + self.decoder_n_filters = config.model.n_filters + self.encoder_dropout = config.model.dropout + self.decoder_dropout = config.model.dropout + self.nonlin = config.model.nonlin + self.predict_logvar = config.model.predict_logvar + self.enable_noise_model = config.model.enable_noise_model + self.noise_model_ch1_fpath = config.model.noise_model_ch1_fpath + self.noise_model_ch2_fpath = config.model.noise_model_ch2_fpath + self.analytical_kl = config.model.analytical_kl + # ------------------------------------------------------- + + # ------------------------------------------------------- + # Model attributes -> Hardcoded + self.model_type = ModelType.LadderVae + self.encoder_blocks_per_layer = 1 + self.decoder_blocks_per_layer = 1 + self.bottomup_batchnorm = True + self.topdown_batchnorm = True + self.topdown_conv2d_bias = True + self.gated = True + self.encoder_res_block_kernel = 3 + self.decoder_res_block_kernel = 3 + self.encoder_res_block_skip_padding = False + self.decoder_res_block_skip_padding = False + self.merge_type = "residual" + self.no_initial_downscaling = True + self.skip_bottomk_buvalues = 0 + self.non_stochastic_version = False + self.stochastic_skip = True + self.learn_top_prior = True + self.res_block_type = "bacdbacd" + self.mode_pred = False + self.logvar_lowerbound = -5 + self._var_clip_max = 20 + self._stochastic_use_naive_exponential = False + self._enable_topdown_normalize_factor = True + + # Noise model attributes -> Hardcoded + self.noise_model_type = "gmm" + self.denoise_channel = ( + "input" # 4 values for denoise_channel {'Ch1', 'Ch2', 'input','all'} + ) + self.noise_model_learnable = False + + # Attributes that handle LC -> Hardcoded + self.enable_multiscale = ( + self._multiscale_count is not None and self._multiscale_count > 1 + ) + self.multiscale_retain_spatial_dims = True + self.multiscale_lowres_separate_branch = False + self.multiscale_decoder_retain_spatial_dims = ( + self.multiscale_retain_spatial_dims and self.enable_multiscale + ) + + # Derived attributes + self.n_layers = len(self.z_dims) + self.encoder_no_padding_mode = ( + self.encoder_res_block_skip_padding is True + and self.encoder_res_block_kernel > 1 + ) + self.decoder_no_padding_mode = ( + self.decoder_res_block_skip_padding is True + and self.decoder_res_block_kernel > 1 + ) + + # Others... + self._tethered_to_input = False + self._tethered_ch1_scalar = self._tethered_ch2_scalar = None + if self._tethered_to_input: + target_ch = 1 + requires_grad = False + self._tethered_ch1_scalar = nn.Parameter( + torch.ones(1) * 0.5, requires_grad=requires_grad + ) + self._tethered_ch2_scalar = nn.Parameter( + torch.ones(1) * 2.0, requires_grad=requires_grad + ) + # ------------------------------------------------------- + + # ------------------------------------------------------- + # Data attributes + self.color_ch = 1 + self.img_shape = (self.image_size, self.image_size) + self.normalized_input = True + # ------------------------------------------------------- + + # ------------------------------------------------------- + # Loss attributes + self._restricted_kl = False # HC + # enabling reconstruction loss on mixed input + self.mixed_rec_w = 0 + self.nbr_consistency_w = 0 + + # Setting the loss_type + self.loss_type = config.loss.get("loss_type", LossType.DenoiSplitMuSplit) + # ------------------------------------------------------- + + # ------------------------------------------------------- + # # Training attributes + # # can be used to tile the validation predictions + # self._val_idx_manager = val_idx_manager + # self._val_frame_creator = None + # # initialize the learning rate scheduler params. + # self.lr_scheduler_monitor = self.lr_scheduler_mode = None + # self._init_lr_scheduler_params(config) + # self._global_step = 0 + # ------------------------------------------------------- + + # ------------------------------------------------------- + # Attributes from constructor arguments + self.target_ch = target_ch + self.use_uncond_mode_at = use_uncond_mode_at + + # Data mean and std used for normalization + if isinstance(data_mean, np.ndarray): + self.data_mean = torch.Tensor(data_mean) + self.data_std = torch.Tensor(data_std) + elif isinstance(data_mean, dict): + for k in data_mean.keys(): + data_mean[k] = ( + torch.Tensor(data_mean[k]) + if not isinstance(data_mean[k], dict) + else data_mean[k] + ) + data_std[k] = ( + torch.Tensor(data_std[k]) + if not isinstance(data_std[k], dict) + else data_std[k] + ) + self.data_mean = data_mean + self.data_std = data_std + else: + raise NotImplementedError( + "data_mean and data_std must be either a numpy array or a dictionary" + ) + + assert self.data_std is not None + assert self.data_mean is not None + + # Initialize the Noise Model + self.likelihood_gm = self.likelihood_NM = None + self.noiseModel = get_noise_model( + enable_noise_model=self.enable_noise_model, + model_type=self.model_type, + noise_model_type=self.noise_model_type, + noise_model_ch1_fpath=self.noise_model_ch1_fpath, + noise_model_ch2_fpath=self.noise_model_ch2_fpath, + noise_model_learnable=self.noise_model_learnable, + ) + + if self.noiseModel is None: + self.likelihood_form = "gaussian" + else: + self.likelihood_form = "noise_model" + + # Calculate the downsampling happening in the network + self.downsample = [1] * self.n_layers + self.overall_downscale_factor = np.power(2, sum(self.downsample)) + if not self.no_initial_downscaling: # by default do another downscaling + self.overall_downscale_factor *= 2 + + assert max(self.downsample) <= self.encoder_blocks_per_layer + assert len(self.downsample) == self.n_layers + # ------------------------------------------------------- + + # ------------------------------------------------------- + ### CREATE MODEL BLOCKS + # First bottom-up layer: change num channels + downsample by factor 2 + # unless we want to prevent this + stride = 1 if self.no_initial_downscaling else 2 + self.first_bottom_up = self.create_first_bottom_up(stride) + + # Input Branches for Lateral Contextualization + self.lowres_first_bottom_ups = None + self._init_multires() + + # Other bottom-up layers + self.bottom_up_layers = self.create_bottom_up_layers( + self.multiscale_lowres_separate_branch + ) + + # Top-down layers + self.top_down_layers = self.create_top_down_layers() + self.final_top_down = self.create_final_topdown_layer( + not self.no_initial_downscaling + ) + + # Likelihood module + self.likelihood = self.create_likelihood_module() + + # Output layer --> Project to target_ch many channels + logvar_ch_needed = self.predict_logvar is not None + self.output_layer = self.parameter_net = nn.Conv2d( + self.decoder_n_filters, + self.target_ch * (1 + logvar_ch_needed), + kernel_size=3, + padding=1, + bias=self.topdown_conv2d_bias, + ) + + # # gradient norms. updated while training. this is also logged. + # self.grad_norm_bottom_up = 0.0 + # self.grad_norm_top_down = 0.0 + # PSNR computation on validation. + # self.label1_psnr = RunningPSNR() + # self.label2_psnr = RunningPSNR() + + # msg =f'[{self.__class__.__name__}] Stoc:{not self.non_stochastic_version} RecMode:{self.reconstruction_mode} TethInput:{self._tethered_to_input}' + # msg += f' TargetCh: {self.target_ch}' + # print(msg) + + ### SET OF METHODS TO CREATE MODEL BLOCKS + def create_first_bottom_up( + self, + init_stride: int, + num_res_blocks: int = 1, + ) -> nn.Sequential: + """ + This method creates the first bottom-up block of the Encoder. + Its role is to perform a first image compression step. + It is composed by a sequence of nn.Conv2d + non-linearity + + BottomUpDeterministicResBlock (1 or more, default is 1). + + Parameters + ---------- + init_stride: int + The stride used by the intial Conv2d block. + num_res_blocks: int, optional + The number of BottomUpDeterministicResBlocks to include in the layer, default is 1. + """ + nonlin = self.get_nonlin() + modules = [ + nn.Conv2d( + in_channels=self.color_ch, + out_channels=self.encoder_n_filters, + kernel_size=self.encoder_res_block_kernel, + padding=( + 0 + if self.encoder_res_block_skip_padding + else self.encoder_res_block_kernel // 2 + ), + stride=init_stride, + ), + nonlin(), + ] + + for _ in range(num_res_blocks): + modules.append( + BottomUpDeterministicResBlock( + c_in=self.encoder_n_filters, + c_out=self.encoder_n_filters, + nonlin=nonlin, + downsample=False, + batchnorm=self.bottomup_batchnorm, + dropout=self.encoder_dropout, + res_block_type=self.res_block_type, + skip_padding=self.encoder_res_block_skip_padding, + res_block_kernel=self.encoder_res_block_kernel, + ) + ) + + return nn.Sequential(*modules) + + def create_bottom_up_layers(self, lowres_separate_branch: bool) -> nn.ModuleList: + """ + This method creates the stack of bottom-up layers of the Encoder + that are used to generate the so-called `bu_values`. + + NOTE: + If `self._multiscale_count < self.n_layers`, then LC is done only in the first + `self._multiscale_count` bottom-up layers (starting from the bottom). + + Parameters + ---------- + lowres_separate_branch: bool + Whether the residual block(s) used for encoding the low-res input are shared (`False`) or + not (`True`) with the "same-size" residual block(s) in the `BottomUpLayer`'s primary flow. + """ + multiscale_lowres_size_factor = 1 + nonlin = self.get_nonlin() + + bottom_up_layers = nn.ModuleList([]) + for i in range(self.n_layers): + # Whether this is the top layer + is_top = i == self.n_layers - 1 + + # LC is applied only to the first (_multiscale_count - 1) bottom-up layers + layer_enable_multiscale = ( + self.enable_multiscale and self._multiscale_count > i + 1 + ) + + # This factor determines the factor by which the low-resolution tensor is larger + # N.B. Only used if layer_enable_multiscale == True, so we updated it only in that case + multiscale_lowres_size_factor *= 1 + int(layer_enable_multiscale) + + output_expected_shape = ( + (self.img_shape[0] // 2 ** (i + 1), self.img_shape[1] // 2 ** (i + 1)) + if self._multiscale_count > 1 + else None + ) + + # Add bottom-up deterministic layer at level i. + # It's a sequence of residual blocks (BottomUpDeterministicResBlock), possibly with downsampling between them. + bottom_up_layers.append( + BottomUpLayer( + n_res_blocks=self.encoder_blocks_per_layer, + n_filters=self.encoder_n_filters, + downsampling_steps=self.downsample[i], + nonlin=nonlin, + batchnorm=self.bottomup_batchnorm, + dropout=self.encoder_dropout, + res_block_type=self.res_block_type, + res_block_kernel=self.encoder_res_block_kernel, + res_block_skip_padding=self.encoder_res_block_skip_padding, + gated=self.gated, + lowres_separate_branch=lowres_separate_branch, + enable_multiscale=self.enable_multiscale, # shouldn't the arg be `layer_enable_multiscale` here? + multiscale_retain_spatial_dims=self.multiscale_retain_spatial_dims, + multiscale_lowres_size_factor=multiscale_lowres_size_factor, + decoder_retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims, + output_expected_shape=output_expected_shape, + ) + ) + + return bottom_up_layers + + def create_top_down_layers(self) -> nn.ModuleList: + """ + This method creates the stack of top-down layers of the Decoder. + In these layer the `bu`_values` from the Encoder are merged with the `p_params` from the previous layer + of the Decoder to get `q_params`. Then, a stochastic layer generates a sample from the latent distribution + with parameters `q_params`. Finally, this sample is fed through a TopDownDeterministicResBlock to + compute the `p_params` for the layer below. + + NOTE 1: + The algorithm for generative inference approximately works as follows: + - p_params = output of top-down layer above + - bu = inferred bottom-up value at this layer + - q_params = merge(bu, p_params) + - z = stochastic_layer(q_params) + - (optional) get and merge skip connection from prev top-down layer + - top-down deterministic ResNet + + NOTE 2: + When doing unconditional generation, bu_value is not available. Hence the + merge layer is not used, and z is sampled directly from p_params. + + Parameters + ---------- + """ + top_down_layers = nn.ModuleList([]) + nonlin = self.get_nonlin() + # NOTE: top-down layers are created starting from the bottom-most + for i in range(self.n_layers): + # Check if this is the top layer + is_top = i == self.n_layers - 1 + + if self._enable_topdown_normalize_factor: + normalize_latent_factor = ( + 1 / np.sqrt(2 * (1 + i)) if len(self.z_dims) > 4 else 1.0 + ) + else: + normalize_latent_factor = 1.0 + + top_down_layers.append( + TopDownLayer( + z_dim=self.z_dims[i], + n_res_blocks=self.decoder_blocks_per_layer, + n_filters=self.decoder_n_filters, + is_top_layer=is_top, + downsampling_steps=self.downsample[i], + nonlin=nonlin, + merge_type=self.merge_type, + batchnorm=self.topdown_batchnorm, + dropout=self.decoder_dropout, + stochastic_skip=self.stochastic_skip, + learn_top_prior=self.learn_top_prior, + top_prior_param_shape=self.get_top_prior_param_shape(), + res_block_type=self.res_block_type, + res_block_kernel=self.decoder_res_block_kernel, + res_block_skip_padding=self.decoder_res_block_skip_padding, + gated=self.gated, + analytical_kl=self.analytical_kl, + restricted_kl=self._restricted_kl, + vanilla_latent_hw=self.get_latent_spatial_size(i), + # in no_padding_mode, what gets passed from the encoder are not multiples of 2 and so merging operation does not work natively. + bottomup_no_padding_mode=self.encoder_no_padding_mode, + topdown_no_padding_mode=self.decoder_no_padding_mode, + retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims, + non_stochastic_version=self.non_stochastic_version, + input_image_shape=self.img_shape, + normalize_latent_factor=normalize_latent_factor, + conv2d_bias=self.topdown_conv2d_bias, + stochastic_use_naive_exponential=self._stochastic_use_naive_exponential, + ) + ) + return top_down_layers + + def create_final_topdown_layer(self, upsample: bool) -> nn.Sequential: + """ + This method creates the final top-down layer of the Decoder. + + Parameters + ---------- + upsample: bool + Whether to upsample the input of the final top-down layer + by bilinear interpolation with `scale_factor=2`. + """ + # Final top-down layer + modules = list() + + if upsample: + modules.append(Interpolate(scale=2)) + + for i in range(self.decoder_blocks_per_layer): + modules.append( + TopDownDeterministicResBlock( + c_in=self.decoder_n_filters, + c_out=self.decoder_n_filters, + nonlin=self.get_nonlin(), + batchnorm=self.topdown_batchnorm, + dropout=self.decoder_dropout, + res_block_type=self.res_block_type, + res_block_kernel=self.decoder_res_block_kernel, + skip_padding=self.decoder_res_block_skip_padding, + gated=self.gated, + conv2d_bias=self.topdown_conv2d_bias, + ) + ) + return nn.Sequential(*modules) + + def create_likelihood_module(self): + """ + This method defines the likelihood module for the current LVAE model. + The existing likelihood modules are `GaussianLikelihood` and `NoiseModelLikelihood`. + """ + self.likelihood_gm = GaussianLikelihood( + self.decoder_n_filters, + self.target_ch, + predict_logvar=self.predict_logvar, + logvar_lowerbound=self.logvar_lowerbound, + conv2d_bias=self.topdown_conv2d_bias, + ) + + self.likelihood_NM = None + if self.enable_noise_model: + self.likelihood_NM = NoiseModelLikelihood( + self.decoder_n_filters, + self.target_ch, + self.data_mean, + self.data_std, + self.noiseModel, + ) + if self.loss_type == LossType.DenoiSplitMuSplit or self.likelihood_NM is None: + return self.likelihood_gm + + return self.likelihood_NM + + def _init_multires(self, config: ml_collections.ConfigDict = None) -> nn.ModuleList: + """ + This method defines the input block/branch to encode/compress low-res lateral inputs at different hierarchical levels + in the multiresolution approach (LC). The role of the input branches is similar to the one of the first bottom-up layer + in the primary flow of the Encoder, namely to compress the lateral input image to a degree that is compatible with the + one of the primary flow. + + NOTE 1: Each input branch consists of a sequence of Conv2d + non-linearity + BottomUpDeterministicResBlock. + It is meaningful to observe that the `BottomUpDeterministicResBlock` shares the same model attributes with the blocks + in the primary flow of the Encoder (e.g., c_in, c_out, dropout, etc. etc.). Moreover, it does not perform downsampling. + + NOTE 2: `_multiscale_count` attribute defines the total number of inputs to the bottom-up pass. + In other terms if we have the input patch and n_LC additional lateral inputs, we will have a total of (n_LC + 1) inputs. + """ + stride = 1 if self.no_initial_downscaling else 2 + nonlin = self.get_nonlin() + if self._multiscale_count is None: + self._multiscale_count = 1 + + msg = "Multiscale count({}) should not exceed the number of bottom up layers ({}) by more than 1" + msg = msg.format(self._multiscale_count, self.n_layers) + assert ( + self._multiscale_count <= 1 or self._multiscale_count <= 1 + self.n_layers + ), msg + + msg = ( + "if multiscale is enabled, then we are just working with monocrome images." + ) + assert self._multiscale_count == 1 or self.color_ch == 1, msg + + lowres_first_bottom_ups = [] + for _ in range(1, self._multiscale_count): + first_bottom_up = nn.Sequential( + nn.Conv2d( + in_channels=self.color_ch, + out_channels=self.encoder_n_filters, + kernel_size=5, + padding=2, + stride=stride, + ), + nonlin(), + BottomUpDeterministicResBlock( + c_in=self.encoder_n_filters, + c_out=self.encoder_n_filters, + nonlin=nonlin, + downsample=False, + batchnorm=self.bottomup_batchnorm, + dropout=self.encoder_dropout, + res_block_type=self.res_block_type, + skip_padding=self.encoder_res_block_skip_padding, + ), + ) + lowres_first_bottom_ups.append(first_bottom_up) + + self.lowres_first_bottom_ups = ( + nn.ModuleList(lowres_first_bottom_ups) + if len(lowres_first_bottom_ups) + else None + ) + + ### SET OF FORWARD-LIKE METHODS + def bottomup_pass(self, inp: torch.Tensor) -> List[torch.Tensor]: + """ + Wrapper of _bottomup_pass(). + """ + return self._bottomup_pass( + inp, + self.first_bottom_up, + self.lowres_first_bottom_ups, + self.bottom_up_layers, + ) + + def _bottomup_pass( + self, + inp: torch.Tensor, + first_bottom_up: nn.Sequential, + lowres_first_bottom_ups: nn.ModuleList, + bottom_up_layers: nn.ModuleList, + ) -> List[torch.Tensor]: + """ + This method defines the forward pass throught the LVAE Encoder, the so-called + Bottom-Up pass. + + Parameters + ---------- + inp: torch.Tensor + The input tensor to the bottom-up pass of shape (B, 1+n_LC, H, W), where n_LC + is the number of lateral low-res inputs used in the LC approach. + In particular, the first channel corresponds to the input patch, while the + remaining ones are associated to the lateral low-res inputs. + first_bottom_up: nn.Sequential + The module defining the first bottom-up layer of the Encoder. + lowres_first_bottom_ups: nn.ModuleList + The list of modules defining Lateral Contextualization. + bottom_up_layers: nn.ModuleList + The list of modules defining the stack of bottom-up layers of the Encoder. + """ + if self._multiscale_count > 1: + x = first_bottom_up(inp[:, :1]) + else: + x = first_bottom_up(inp) + + # Loop from bottom to top layer, store all deterministic nodes we + # need for the top-down pass in bu_values list + bu_values = [] + for i in range(self.n_layers): + lowres_x = None + if self._multiscale_count > 1 and i + 1 < inp.shape[1]: + lowres_x = lowres_first_bottom_ups[i](inp[:, i + 1 : i + 2]) + + x, bu_value = bottom_up_layers[i](x, lowres_x=lowres_x) + bu_values.append(bu_value) + + return bu_values + + def topdown_pass( + self, + bu_values: torch.Tensor = None, + n_img_prior: torch.Tensor = None, + mode_layers: Iterable[int] = None, + constant_layers: Iterable[int] = None, + forced_latent: List[torch.Tensor] = None, + top_down_layers: nn.ModuleList = None, + final_top_down_layer: nn.Sequential = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + This method defines the forward pass throught the LVAE Decoder, the so-called + Top-Down pass. + + Parameters + ---------- + bu_values: torch.Tensor, optional + Output of the bottom-up pass. It will have values from multiple layers of the ladder. + n_img_prior: optional + When `bu_values` is `None`, `n_img_prior` indicates the number of images to generate + from the prior (so bottom-up pass is not used at all here). + mode_layers: Iterable[int], optional + A sequence of indexes associated to the layers in which sampling is disabled and + the mode (mean value) is used instead. Set to `None` to avoid this behaviour. + constant_layers: Iterable[int], optional + A sequence of indexes associated to the layers in which a single instance's z is + copied over the entire batch (bottom-up path is not used, so only prior is used here). + Set to `None` to avoid this behaviour. + forced_latent: List[torch.Tensor], optional + A list of tensors that are used as fixed latent variables (hence, sampling doesn't take + place in this case). + top_down_layers: nn.ModuleList, optional + A list of top-down layers to use in the top-down pass. If `None`, the method uses the + default layers defined in the contructor. + final_top_down_layer: nn.Sequential, optional + The last top-down layer of the top-down pass. If `None`, the method uses the default + layers defined in the contructor. + """ + if top_down_layers is None: + top_down_layers = self.top_down_layers + if final_top_down_layer is None: + final_top_down_layer = self.final_top_down + + # Default: no layer is sampled from the distribution's mode + if mode_layers is None: + mode_layers = [] + if constant_layers is None: + constant_layers = [] + prior_experiment = len(mode_layers) > 0 or len(constant_layers) > 0 + + # If the bottom-up inference values are not given, don't do + # inference, sample from prior instead + inference_mode = bu_values is not None + + # Check consistency of arguments + if inference_mode != (n_img_prior is None): + msg = ( + "Number of images for top-down generation has to be given " + "if and only if we're not doing inference" + ) + raise RuntimeError(msg) + if ( + inference_mode + and prior_experiment + and (self.non_stochastic_version is False) + ): + msg = ( + "Prior experiments (e.g. sampling from mode) are not" + " compatible with inference mode" + ) + raise RuntimeError(msg) + + # Sampled latent variables at each layer + z = [None] * self.n_layers + + # KL divergence of each layer + kl = [None] * self.n_layers + # Kl divergence restricted, only for the LC enabled setup denoiSplit. + kl_restricted = [None] * self.n_layers + + # mean from which z is sampled. + q_mu = [None] * self.n_layers + # log(var) from which z is sampled. + q_lv = [None] * self.n_layers + + # Spatial map of KL divergence for each layer + kl_spatial = [None] * self.n_layers + + debug_qvar_max = [None] * self.n_layers + + kl_channelwise = [None] * self.n_layers + + if forced_latent is None: + forced_latent = [None] * self.n_layers + + # log p(z) where z is the sample in the topdown pass + # logprob_p = 0. + + # Top-down inference/generation loop + out = out_pre_residual = None + for i in reversed(range(self.n_layers)): + + # If available, get deterministic node from bottom-up inference + try: + bu_value = bu_values[i] + except TypeError: + bu_value = None + + # Whether the current layer should be sampled from the mode + use_mode = i in mode_layers + constant_out = i in constant_layers + use_uncond_mode = i in self.use_uncond_mode_at + + # Input for skip connection + skip_input = out # TODO or n? or both? + + # Full top-down layer, including sampling and deterministic part + out, out_pre_residual, aux = top_down_layers[i]( + input_=out, + skip_connection_input=skip_input, + inference_mode=inference_mode, + bu_value=bu_value, + n_img_prior=n_img_prior, + use_mode=use_mode, + force_constant_output=constant_out, + forced_latent=forced_latent[i], + mode_pred=self.mode_pred, + use_uncond_mode=use_uncond_mode, + var_clip_max=self._var_clip_max, + ) + + # Save useful variables + z[i] = aux["z"] # sampled variable at this layer (batch, ch, h, w) + kl[i] = aux["kl_samplewise"] # (batch, ) + kl_restricted[i] = aux["kl_samplewise_restricted"] + kl_spatial[i] = aux["kl_spatial"] # (batch, h, w) + q_mu[i] = aux["q_mu"] + q_lv[i] = aux["q_lv"] + + kl_channelwise[i] = aux["kl_channelwise"] + debug_qvar_max[i] = aux["qvar_max"] + # if self.mode_pred is False: + # logprob_p += aux['logprob_p'].mean() # mean over batch + # else: + # logprob_p = None + + # Final top-down layer + out = final_top_down_layer(out) + + # Store useful variables in a dict to return them + data = { + "z": z, # list of tensors with shape (batch, ch[i], h[i], w[i]) + "kl": kl, # list of tensors with shape (batch, ) + "kl_restricted": kl_restricted, # list of tensors with shape (batch, ) + "kl_spatial": kl_spatial, # list of tensors w shape (batch, h[i], w[i]) + "kl_channelwise": kl_channelwise, # list of tensors with shape (batch, ch[i]) + # 'logprob_p': logprob_p, # scalar, mean over batch + "q_mu": q_mu, + "q_lv": q_lv, + "debug_qvar_max": debug_qvar_max, + } + return out, data + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Parameters + ---------- + x: torch.Tensor + The input tensor of shape (B, C, H, W). + """ + img_size = x.size()[2:] + + # Pad input to size equal to the closest power of 2 + x_pad = self.pad_input(x) + + # Bottom-up inference: return list of length n_layers (bottom to top) + bu_values = self.bottomup_pass(x_pad) + for i in range(0, self.skip_bottomk_buvalues): + bu_values[i] = None + + mode_layers = range(self.n_layers) if self.non_stochastic_version else None + + # Top-down inference/generation + out, td_data = self.topdown_pass(bu_values, mode_layers=mode_layers) + + if out.shape[-1] > img_size[-1]: + # Restore original image size + out = crop_img_tensor(out, img_size) + + out = self.output_layer(out) + if self._tethered_to_input: + assert out.shape[1] == 1 + ch2 = self.get_other_channel(out, x_pad) + out = torch.cat([out, ch2], dim=1) + + return out, td_data + + ### SET OF UTILS METHODS + # def sample_prior( + # self, + # n_imgs, + # mode_layers=None, + # constant_layers=None + # ): + + # # Generate from prior + # out, _ = self.topdown_pass(n_img_prior=n_imgs, mode_layers=mode_layers, constant_layers=constant_layers) + # out = crop_img_tensor(out, self.img_shape) + + # # Log likelihood and other info (per data point) + # _, likelihood_data = self.likelihood(out, None) + + # return likelihood_data['sample'] + + # ### ??? + # def sample_from_q(self, x, masks=None): + # """ + # This method performs the bottomup_pass() and samples from the + # obtained distribution. + # """ + # img_size = x.size()[2:] + + # # Pad input to make everything easier with conv strides + # x_pad = self.pad_input(x) + + # # Bottom-up inference: return list of length n_layers (bottom to top) + # bu_values = self.bottomup_pass(x_pad) + # return self._sample_from_q(bu_values, masks=masks) + # ### ??? + + # def _sample_from_q(self, bu_values, top_down_layers=None, final_top_down_layer=None, masks=None): + # if top_down_layers is None: + # top_down_layers = self.top_down_layers + # if final_top_down_layer is None: + # final_top_down_layer = self.final_top_down + # if masks is None: + # masks = [None] * len(bu_values) + + # msg = "Multiscale is not supported as of now. You need the output from the previous layers to do this." + # assert self.n_layers == 1, msg + # samples = [] + # for i in reversed(range(self.n_layers)): + # bu_value = bu_values[i] + + # # Note that the first argument can be set to None since we are just dealing with one level + # sample = top_down_layers[i].sample_from_q(None, bu_value, var_clip_max=self._var_clip_max, mask=masks[i]) + # samples.append(sample) + + # return samples + + # def reset_for_different_output_size(self, output_size): + # for i in range(self.n_layers): + # sz = output_size // 2**(1 + i) + # self.bottom_up_layers[i].output_expected_shape = (sz, sz) + # self.top_down_layers[i].latent_shape = (output_size, output_size) + + def pad_input(self, x): + """ + Pads input x so that its sizes are powers of 2 + :param x: + :return: Padded tensor + """ + size = self.get_padded_size(x.size()) + x = pad_img_tensor(x, size) + return x + + ### SET OF GETTERS + def get_nonlin(self): + nonlin = { + "relu": nn.ReLU, + "leakyrelu": nn.LeakyReLU, + "elu": nn.ELU, + "selu": nn.SELU, + } + return nonlin[self.nonlin] + + def get_padded_size(self, size): + """ + Returns the smallest size (H, W) of the image with actual size given + as input, such that H and W are powers of 2. + :param size: input size, tuple either (N, C, H, w) or (H, W) + :return: 2-tuple (H, W) + """ + # Make size argument into (heigth, width) + if len(size) == 4: + size = size[2:] + if len(size) != 2: + msg = ( + "input size must be either (N, C, H, W) or (H, W), but it " + f"has length {len(size)} (size={size})" + ) + raise RuntimeError(msg) + + if self.multiscale_decoder_retain_spatial_dims is True: + # In this case, we can go much more deeper and so this is not required + # (in the way it is. ;). More work would be needed if this was to be correctly implemented ) + return list(size) + + # Overall downscale factor from input to top layer (power of 2) + dwnsc = self.overall_downscale_factor + + # Output smallest powers of 2 that are larger than current sizes + padded_size = list(((s - 1) // dwnsc + 1) * dwnsc for s in size) + + return padded_size + + def get_latent_spatial_size(self, level_idx: int): + """ + level_idx: 0 is the bottommost layer, the highest resolution one. + """ + actual_downsampling = level_idx + 1 + dwnsc = 2**actual_downsampling + sz = self.get_padded_size(self.img_shape) + h = sz[0] // dwnsc + w = sz[1] // dwnsc + assert h == w + return h + + def get_top_prior_param_shape(self, n_imgs: int = 1): + # TODO num channels depends on random variable we're using + + # Compute the total downscaling performed in the Encoder + if self.multiscale_decoder_retain_spatial_dims is False: + dwnsc = self.overall_downscale_factor + else: + # LC allow the encoder latents to keep the same (H, W) size at different levels + actual_downsampling = self.n_layers + 1 - self._multiscale_count + dwnsc = 2**actual_downsampling + + sz = self.get_padded_size(self.img_shape) + h = sz[0] // dwnsc + w = sz[1] // dwnsc + c = self.z_dims[-1] * 2 # mu and logvar + top_layer_shape = (n_imgs, c, h, w) + return top_layer_shape + + def get_other_channel(self, ch1, input): + assert self.data_std["target"].squeeze().shape == (2,) + assert self.data_mean["target"].squeeze().shape == (2,) + assert self.target_ch == 2 + ch1_un = ( + ch1[:, :1] * self.data_std["target"][:, :1] + + self.data_mean["target"][:, :1] + ) + input_un = input * self.data_std["input"] + self.data_mean["input"] + ch2_un = self._tethered_ch2_scalar * ( + input_un - ch1_un * self._tethered_ch1_scalar + ) + ch2 = (ch2_un - self.data_mean["target"][:, -1:]) / self.data_std["target"][ + :, -1: + ] + return ch2 diff --git a/src/careamics/models/lvae/noise_models.py b/src/careamics/models/lvae/noise_models.py new file mode 100644 index 000000000..10e8c0c35 --- /dev/null +++ b/src/careamics/models/lvae/noise_models.py @@ -0,0 +1,409 @@ +import json +import os + +import numpy as np +import torch +import torch.nn as nn + +from .utils import ModelType + + +class DisentNoiseModel(nn.Module): + + def __init__(self, *nmodels): + """ + Constructor. + + This class receives as input a variable number of noise models, each one corresponding to a channel. + """ + super().__init__() + # self.nmodels = nmodels + for i, nmodel in enumerate(nmodels): + if nmodel is not None: + self.add_module(f"nmodel_{i}", nmodel) + + self._nm_cnt = 0 + for nmodel in nmodels: + if nmodel is not None: + self._nm_cnt += 1 + + print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}") + + def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor: + + if obs.shape[1] == 1: + assert signal.shape[1] == 1 + assert self.n2model is None + return self.nmodel_0.likelihood(obs, signal) + + assert obs.shape[1] == self._nm_cnt, f"{obs.shape[1]} != {self._nm_cnt}" + + ll_list = [] + for ch_idx in range(obs.shape[1]): + nmodel = getattr(self, f"nmodel_{ch_idx}") + ll_list.append( + nmodel.likelihood( + obs[:, ch_idx : ch_idx + 1], signal[:, ch_idx : ch_idx + 1] + ) + ) + + return torch.cat(ll_list, dim=1) + + +def last2path(fpath: str): + return os.path.join(*fpath.split("/")[-2:]) + + +def get_nm_config(noise_model_fpath: str): + config_fpath = os.path.join(os.path.dirname(noise_model_fpath), "config.json") + with open(config_fpath) as f: + noise_model_config = json.load(f) + return noise_model_config + + +def fastShuffle(series, num): + length = series.shape[0] + for i in range(num): + series = series[np.random.permutation(length), :] + return series + + +def get_noise_model( + enable_noise_model: bool, + model_type: ModelType, + noise_model_type: str, + noise_model_ch1_fpath: str, + noise_model_ch2_fpath: str, + noise_model_learnable: bool = False, + denoise_channel: str = "input", +): + if enable_noise_model: + nmodels = [] + # HDN -> one single output -> one single noise model + if model_type == ModelType.Denoiser: + if noise_model_type == "hist": + raise NotImplementedError( + '"hist" noise model is not supported for now.' + ) + elif noise_model_type == "gmm": + if denoise_channel == "Ch1": + nmodel_fpath = noise_model_ch1_fpath + print(f"Noise model Ch1: {nmodel_fpath}") + nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath)) + nmodel2 = None + nmodels = [nmodel1, nmodel2] + elif denoise_channel == "Ch2": + nmodel_fpath = noise_model_ch2_fpath + print(f"Noise model Ch2: {nmodel_fpath}") + nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath)) + nmodel2 = None + nmodels = [nmodel1, nmodel2] + elif denoise_channel == "input": + nmodel_fpath = noise_model_ch1_fpath + print(f"Noise model input: {nmodel_fpath}") + nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath)) + nmodel2 = None + nmodels = [nmodel1, nmodel2] + else: + raise ValueError(f"Invalid denoise_channel: {denoise_channel}") + # muSplit -> two outputs -> two noise models + elif noise_model_type == "gmm": + print(f"Noise model Ch1: {noise_model_ch1_fpath}") + print(f"Noise model Ch2: {noise_model_ch2_fpath}") + + nmodel1 = GaussianMixtureNoiseModel(params=np.load(noise_model_ch1_fpath)) + nmodel2 = GaussianMixtureNoiseModel(params=np.load(noise_model_ch2_fpath)) + + nmodels = [nmodel1, nmodel2] + + # if 'noise_model_ch3_fpath' in config.model: + # print(f'Noise model Ch3: {config.model.noise_model_ch3_fpath}') + # nmodel3 = GaussianMixtureNoiseModel(params=np.load(config.model.noise_model_ch3_fpath)) + # nmodels = [nmodel1, nmodel2, nmodel3] + # else: + # nmodels = [nmodel1, nmodel2] + else: + raise ValueError(f"Invalid noise_model_type: {noise_model_type}") + + if noise_model_learnable: + for nmodel in nmodels: + if nmodel is not None: + nmodel.make_learnable() + + return DisentNoiseModel(*nmodels) + return None + + +class GaussianMixtureNoiseModel(nn.Module): + """ + The GaussianMixtureNoiseModel class describes a noise model which is parameterized as a mixture of gaussians. + If you would like to initialize a new object from scratch, then set `params`= None and specify the other parameters as keyword arguments. + If you are instead loading a model, use only `params`. + + Parameters + ---------- + **kwargs: keyworded, variable-length argument dictionary. + Arguments include: + min_signal : float + Minimum signal intensity expected in the image. + max_signal : float + Maximum signal intensity expected in the image. + path: string + Path to the directory where the trained noise model (*.npz) is saved in the `train` method. + weight : array + A [3*n_gaussian, n_coeff] sized array containing the values of the weights describing the noise model. + Each gaussian contributes three parameters (mean, standard deviation and weight), hence the number of rows in `weight` are 3*n_gaussian. + If `weight=None`, the weight array is initialized using the `min_signal` and `max_signal` parameters. + n_gaussian: int + Number of gaussians. + n_coeff: int + Number of coefficients to describe the functional relationship between gaussian parameters and the signal. + 2 implies a linear relationship, 3 implies a quadratic relationship and so on. + device: device + GPU device. + min_sigma: int + All values of sigma (`standard deviation`) below min_sigma are clamped to become equal to min_sigma. + params: dictionary + Use `params` if one wishes to load a model with trained weights. + While initializing a new object of the class `GaussianMixtureNoiseModel` from scratch, set this to `None`. + """ + + def __init__(self, **kwargs): + super().__init__() + self._learnable = False + + if kwargs.get("params") is None: + weight = kwargs.get("weight") + n_gaussian = kwargs.get("n_gaussian") + n_coeff = kwargs.get("n_coeff") + min_signal = kwargs.get("min_signal") + max_signal = kwargs.get("max_signal") + # self.device = kwargs.get('device') + self.path = kwargs.get("path") + self.min_sigma = kwargs.get("min_sigma") + if weight is None: + weight = np.random.randn(n_gaussian * 3, n_coeff) + weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal) + weight = torch.from_numpy( + weight.astype(np.float32) + ).float() # .to(self.device) + weight = nn.Parameter(weight, requires_grad=True) + + self.n_gaussian = weight.shape[0] // 3 + self.n_coeff = weight.shape[1] + self.weight = weight + self.min_signal = torch.Tensor([min_signal]) # .to(self.device) + self.max_signal = torch.Tensor([max_signal]) # .to(self.device) + self.tol = torch.Tensor([1e-10]) # .to(self.device) + else: + params = kwargs.get("params") + # self.device = kwargs.get('device') + + self.min_signal = torch.Tensor(params["min_signal"]) # .to(self.device) + self.max_signal = torch.Tensor(params["max_signal"]) # .to(self.device) + + self.weight = torch.nn.Parameter( + torch.Tensor(params["trained_weight"]), requires_grad=False + ) # .to(self.device) + self.min_sigma = params["min_sigma"].item() + self.n_gaussian = self.weight.shape[0] // 3 + self.n_coeff = self.weight.shape[1] + self.tol = torch.Tensor([1e-10]) # .to(self.device) + self.min_signal = torch.Tensor([self.min_signal]) # .to(self.device) + self.max_signal = torch.Tensor([self.max_signal]) # .to(self.device) + + print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}") + + def make_learnable(self): + print(f"[{self.__class__.__name__}] Making noise model learnable") + + self._learnable = True + self.weight.requires_grad = True + + # + + def to_device(self, cuda_tensor): + # move everything to GPU + if self.min_signal.device != cuda_tensor.device: + self.max_signal = self.max_signal.to(cuda_tensor.device) + self.min_signal = self.min_signal.to(cuda_tensor.device) + self.tol = self.tol.to(cuda_tensor.device) + self.weight = self.weight.to(cuda_tensor.device) + if self._learnable: + self.weight.requires_grad = True + + def polynomialRegressor(self, weightParams, signals): + """Combines `weightParams` and signal `signals` to regress for the gaussian parameter values. + + Parameters + ---------- + weightParams : torch.cuda.FloatTensor + Corresponds to specific rows of the `self.weight` + signals : torch.cuda.FloatTensor + Signals + + Returns + ------- + value : torch.cuda.FloatTensor + Corresponds to either of mean, standard deviation or weight, evaluated at `signals` + """ + value = 0 + for i in range(weightParams.shape[0]): + value += weightParams[i] * ( + ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i + ) + return value + + def normalDens(self, x, m_=0.0, std_=None): + """Evaluates the normal probability density at `x` given the mean `m` and standard deviation `std`. + + Parameters + ---------- + x: torch.cuda.FloatTensor + Observations + m_: torch.cuda.FloatTensor + Mean + std_: torch.cuda.FloatTensor + Standard-deviation + + Returns + ------- + tmp: torch.cuda.FloatTensor + Normal probability density of `x` given `m_` and `std_` + """ + tmp = -((x - m_) ** 2) + tmp = tmp / (2.0 * std_ * std_) + tmp = torch.exp(tmp) + tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_) + return tmp + + def likelihood(self, observations, signals): + """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters. + + Parameters + ---------- + observations : torch.cuda.FloatTensor + Noisy observations + signals : torch.cuda.FloatTensor + Underlying signals + + Returns + ------- + value :p + self.tol + Likelihood of observations given the signals and the GMM noise model + """ + self.to_device(signals) + gaussianParameters = self.getGaussianParameters(signals) + p = 0 + for gaussian in range(self.n_gaussian): + p += ( + self.normalDens( + observations, + gaussianParameters[gaussian], + gaussianParameters[self.n_gaussian + gaussian], + ) + * gaussianParameters[2 * self.n_gaussian + gaussian] + ) + return p + self.tol + + def getGaussianParameters(self, signals): + """Returns the noise model for given signals + + Parameters + ---------- + signals : torch.cuda.FloatTensor + Underlying signals + + Returns + ------- + noiseModel: list of torch.cuda.FloatTensor + Contains a list of `mu`, `sigma` and `alpha` for the `signals` + + """ + noiseModel = [] + mu = [] + sigma = [] + alpha = [] + kernels = self.weight.shape[0] // 3 + for num in range(kernels): + mu.append(self.polynomialRegressor(self.weight[num, :], signals)) + # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W)) + expval = torch.exp(self.weight[kernels + num, :]) + # self.maxval = max(self.maxval, expval.max().item()) + sigmaTemp = self.polynomialRegressor(expval, signals) + sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma) + sigma.append(torch.sqrt(sigmaTemp)) + + # expval = torch.exp( + # torch.clamp( + # self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) + self.tol, MAX_ALPHA_W)) + expval = torch.exp( + self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) + + self.tol + ) + # self.maxval = max(self.maxval, expval.max().item()) + alpha.append(expval) + + sum_alpha = 0 + for al in range(kernels): + sum_alpha = alpha[al] + sum_alpha + + # sum of alpha is forced to be 1. + for ker in range(kernels): + alpha[ker] = alpha[ker] / sum_alpha + + sum_means = 0 + # sum_means is the alpha weighted average of the means + for ker in range(kernels): + sum_means = alpha[ker] * mu[ker] + sum_means + + mu_shifted = [] + # subtracting the alpha weighted average of the means from the means + # ensures that the GMM has the inclination to have the mean=signals. + # its like a residual conection. I don't understand why we need to learn the mean? + for ker in range(kernels): + mu[ker] = mu[ker] - sum_means + signals + + for i in range(kernels): + noiseModel.append(mu[i]) + for j in range(kernels): + noiseModel.append(sigma[j]) + for k in range(kernels): + noiseModel.append(alpha[k]) + + return noiseModel + + def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip): + """Returns the Signal-Observation pixel intensities as a two-column array + + Parameters + ---------- + signal : numpy array + Clean Signal Data + observation: numpy array + Noisy observation Data + lowerClip: float + Lower percentile bound for clipping. + upperClip: float + Upper percentile bound for clipping. + + Returns + ------- + noiseModel: list of torch floats + Contains a list of `mu`, `sigma` and `alpha` for the `signals` + """ + lb = np.percentile(signal, lowerClip) + ub = np.percentile(signal, upperClip) + stepsize = observation[0].size + n_observations = observation.shape[0] + n_signals = signal.shape[0] + sig_obs_pairs = np.zeros((n_observations * stepsize, 2)) + + for i in range(n_observations): + j = i // (n_observations // n_signals) + sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel() + sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel() + sig_obs_pairs = sig_obs_pairs[ + (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub) + ] + return fastShuffle(sig_obs_pairs, 2) diff --git a/src/careamics/models/lvae/utils.py b/src/careamics/models/lvae/utils.py new file mode 100644 index 000000000..6e51db501 --- /dev/null +++ b/src/careamics/models/lvae/utils.py @@ -0,0 +1,395 @@ +""" +Script for utility functions needed by the LVAE model. +""" + +from typing import Iterable + +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms.functional as F +from torch.distributions.normal import Normal + + +def torch_nanmean(inp): + return torch.mean(inp[~inp.isnan()]) + + +def compute_batch_mean(x): + N = len(x) + return x.view(N, -1).mean(dim=1) + + +def power_of_2(self, x): + assert isinstance(x, int) + if x == 1: + return True + if x == 0: + # happens with validation + return False + if x % 2 == 1: + return False + return self.power_of_2(x // 2) + + +class Enum: + @classmethod + def name(cls, enum_type): + for key, value in cls.__dict__.items(): + if enum_type == value: + return key + + @classmethod + def contains(cls, enum_type): + for key, value in cls.__dict__.items(): + if enum_type == value: + return True + return False + + @classmethod + def from_name(cls, enum_type_str): + for key, value in cls.__dict__.items(): + if key == enum_type_str: + return value + assert f"{cls.__name__}:{enum_type_str} doesnot exist." + + +class LossType(Enum): + Elbo = 0 + ElboWithCritic = 1 + ElboMixedReconstruction = 2 + MSE = 3 + ElboWithNbrConsistency = 4 + ElboSemiSupMixedReconstruction = 5 + ElboCL = 6 + ElboRestrictedReconstruction = 7 + DenoiSplitMuSplit = 8 + + +class ModelType(Enum): + LadderVae = 3 + LadderVaeTwinDecoder = 4 + LadderVAECritic = 5 + # Separate vampprior: two optimizers + LadderVaeSepVampprior = 6 + # one encoder for mixed input, two for separate inputs. + LadderVaeSepEncoder = 7 + LadderVAEMultiTarget = 8 + LadderVaeSepEncoderSingleOptim = 9 + UNet = 10 + BraveNet = 11 + LadderVaeStitch = 12 + LadderVaeSemiSupervised = 13 + LadderVaeStitch2Stage = 14 # Note that previously trained models will have issue. + # since earlier, LadderVaeStitch2Stage = 13, LadderVaeSemiSupervised = 14 + LadderVaeMixedRecons = 15 + LadderVaeCL = 16 + LadderVaeTwoDataSet = ( + 17 # on one subdset, apply disentanglement, on other apply reconstruction + ) + LadderVaeTwoDatasetMultiBranch = 18 + LadderVaeTwoDatasetMultiOptim = 19 + LVaeDeepEncoderIntensityAug = 20 + AutoRegresiveLadderVAE = 21 + LadderVAEInterleavedOptimization = 22 + Denoiser = 23 + DenoiserSplitter = 24 + SplitterDenoiser = 25 + LadderVAERestrictedReconstruction = 26 + LadderVAETwoDataSetRestRecon = 27 + LadderVAETwoDataSetFinetuning = 28 + + +def _pad_crop_img(x, size, mode) -> torch.Tensor: + """Pads or crops a tensor. + Pads or crops a tensor of shape (batch, channels, h, w) to new height + and width given by a tuple. + Args: + x (torch.Tensor): Input image + size (list or tuple): Desired size (height, width) + mode (str): Mode, either 'pad' or 'crop' + Returns: + The padded or cropped tensor + """ + assert x.dim() == 4 and len(size) == 2 + size = tuple(size) + x_size = x.size()[2:4] + if mode == "pad": + cond = x_size[0] > size[0] or x_size[1] > size[1] + elif mode == "crop": + cond = x_size[0] < size[0] or x_size[1] < size[1] + else: + raise ValueError(f"invalid mode '{mode}'") + if cond: + raise ValueError(f"trying to {mode} from size {x_size} to size {size}") + dr, dc = (abs(x_size[0] - size[0]), abs(x_size[1] - size[1])) + dr1, dr2 = dr // 2, dr - (dr // 2) + dc1, dc2 = dc // 2, dc - (dc // 2) + if mode == "pad": + return nn.functional.pad(x, [dc1, dc2, dr1, dr2, 0, 0, 0, 0]) + elif mode == "crop": + return x[:, :, dr1 : x_size[0] - dr2, dc1 : x_size[1] - dc2] + + +def pad_img_tensor(x, size) -> torch.Tensor: + """Pads a tensor. + Pads a tensor of shape (batch, channels, h, w) to a desired height and width. + Args: + x (torch.Tensor): Input image + size (list or tuple): Desired size (height, width) + + Returns + ------- + The padded tensor + """ + return _pad_crop_img(x, size, "pad") + + +def crop_img_tensor(x, size) -> torch.Tensor: + """Crops a tensor. + Crops a tensor of shape (batch, channels, h, w) to a desired height and width + given by a tuple. + Args: + x (torch.Tensor): Input image + size (list or tuple): Desired size (height, width) + + Returns + ------- + The cropped tensor + """ + return _pad_crop_img(x, size, "crop") + + +class StableExponential: + """ + Class that redefines the definition of exp() to increase numerical stability. + Naturally, also the definition of log() must change accordingly. + However, it is worth noting that the two operations remain one the inverse of the other, + meaning that x = log(exp(x)) and x = exp(log(x)) are always true. + + Definition: + exp(x) = { + exp(x) if x<=0 + x+1 if x>0 + } + + log(x) = { + x if x<=0 + log(1+x) if x>0 + } + + NOTE 1: + Within the class everything is done on the tensor given as input to the constructor. + Therefore, when exp() is called, self._tensor.exp() is computed. + When log() is called, torch.log(self._tensor.exp()) is computed instead. + + NOTE 2: + Given the output from exp(), torch.log() or the log() method of the class give identical results. + """ + + def __init__(self, tensor): + self._raw_tensor = tensor + posneg_dic = self.posneg_separation(self._raw_tensor) + self.pos_f, self.neg_f = posneg_dic["filter"] + self.pos_data, self.neg_data = posneg_dic["value"] + + def posneg_separation(self, tensor): + pos = tensor > 0 + pos_tensor = torch.clip(tensor, min=0) + + neg = tensor <= 0 + neg_tensor = torch.clip(tensor, max=0) + + return {"filter": [pos, neg], "value": [pos_tensor, neg_tensor]} + + def exp(self): + return torch.exp(self.neg_data) * self.neg_f + (1 + self.pos_data) * self.pos_f + + def log(self): + return self.neg_data * self.neg_f + torch.log(1 + self.pos_data) * self.pos_f + + +class StableLogVar: + """ + Class that provides a numerically stable implementation of Log-Variance. + Specifically, it uses the exp() and log() formulas defined in `StableExponential` class. + """ + + def __init__( + self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6 + ): + """ + Contructor. + + Parameters + ---------- + logvar: torch.Tensor + The input (true) logvar vector, to be converted in the Stable version. + enable_stable: bool, optional + Whether to compute the stable version of log-variance. Default is `True`. + var_eps: float, optional + The minimum value attainable by the variance. Default is `1e-6`. + """ + self._lv = logvar + self._enable_stable = enable_stable + self._eps = var_eps + + def get(self) -> torch.Tensor: + if self._enable_stable is False: + return self._lv + + return torch.log(self.get_var()) + + def get_var(self) -> torch.Tensor: + """ + Get Variance from Log-Variance. + """ + if self._enable_stable is False: + return torch.exp(self._lv) + return StableExponential(self._lv).exp() + self._eps + + def get_std(self) -> torch.Tensor: + return torch.sqrt(self.get_var()) + + def centercrop_to_size(self, size: Iterable[int]) -> None: + """ + Centercrop the log-variance tensor to the desired size. + + Parameters + ---------- + size: torch.Tensor + The desired size of the log-variance tensor. + """ + if self._lv.shape[-1] == size: + return + + diff = self._lv.shape[-1] - size + assert diff > 0 and diff % 2 == 0 + self._lv = F.center_crop(self._lv, (size, size)) + + +class StableMean: + + def __init__(self, mean): + self._mean = mean + + def get(self) -> torch.Tensor: + return self._mean + + def centercrop_to_size(self, size: Iterable[int]) -> None: + """ + Centercrop the mean tensor to the desired size. + + Parameters + ---------- + size: torch.Tensor + The desired size of the log-variance tensor. + """ + if self._mean.shape[-1] == size: + return + + diff = self._mean.shape[-1] - size + assert diff > 0 and diff % 2 == 0 + self._mean = F.center_crop(self._mean, (size, size)) + + +def allow_numpy(func): + """ + All optional arguements are passed as is. positional arguments are checked. if they are numpy array, + they are converted to torch Tensor. + """ + + def numpy_wrapper(*args, **kwargs): + new_args = [] + for arg in args: + if isinstance(arg, np.ndarray): + arg = torch.Tensor(arg) + new_args.append(arg) + new_args = tuple(new_args) + + output = func(*new_args, **kwargs) + return output + + return numpy_wrapper + + +class Interpolate(nn.Module): + """Wrapper for torch.nn.functional.interpolate.""" + + def __init__(self, size=None, scale=None, mode="bilinear", align_corners=False): + super().__init__() + assert (size is None) == (scale is not None) + self.size = size + self.scale = scale + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + out = F.interpolate( + x, + size=self.size, + scale_factor=self.scale, + mode=self.mode, + align_corners=self.align_corners, + ) + return out + + +def kl_normal_mc(z, p_mulv, q_mulv): + """ + One-sample estimation of element-wise KL between two diagonal + multivariate normal distributions. Any number of dimensions, + broadcasting supported (be careful). + :param z: + :param p_mulv: + :param q_mulv: + :return: + """ + assert isinstance(p_mulv, tuple) + assert isinstance(q_mulv, tuple) + p_mu, p_lv = p_mulv + q_mu, q_lv = q_mulv + + p_std = p_lv.get_std() + q_std = q_lv.get_std() + + p_distrib = Normal(p_mu.get(), p_std) + q_distrib = Normal(q_mu.get(), q_std) + return q_distrib.log_prob(z) - p_distrib.log_prob(z) + + +def free_bits_kl( + kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6 +) -> torch.Tensor: + """ + Computes free-bits version of KL divergence. + Ensures that the KL doesn't go to zero for any latent dimension. + Hence, it contributes to use latent variables more efficiently, + leading to better representation learning. + + NOTE: + Takes in the KL with shape (batch size, layers), returns the KL with + free bits (for optimization) with shape (layers,), which is the average + free-bits KL per layer in the current batch. + If batch_average is False (default), the free bits are per layer and + per batch element. Otherwise, the free bits are still per layer, but + are assigned on average to the whole batch. In both cases, the batch + average is returned, so it's simply a matter of doing mean(clamp(KL)) + or clamp(mean(KL)). + + Args: + kl (torch.Tensor) + free_bits (float) + batch_average (bool, optional)) + eps (float, optional) + + Returns + ------- + The KL with free bits + """ + assert kl.dim() == 2 + if free_bits < eps: + return kl.mean(0) + if batch_average: + return kl.mean(0).clamp(min=free_bits) + return kl.clamp(min=free_bits).mean(0) diff --git a/src/careamics/prediction_utils/stitch_prediction.py b/src/careamics/prediction_utils/stitch_prediction.py index eace8381c..1a5679759 100644 --- a/src/careamics/prediction_utils/stitch_prediction.py +++ b/src/careamics/prediction_utils/stitch_prediction.py @@ -37,7 +37,9 @@ def stitch_prediction( last_tiles = [tile_info.last_tile for tile_info in tile_infos] last_tile_position = np.where(last_tiles)[0] image_slices = [ - slice(None if i == 0 else last_tile_position[i - 1], last_tile_position[i] + 1) + slice( + None if i == 0 else last_tile_position[i - 1] + 1, last_tile_position[i] + 1 + ) for i in range(len(last_tile_position)) ] image_predictions = [] diff --git a/src/careamics/transforms/tta.py b/src/careamics/transforms/tta.py index fff518c06..0e6919852 100644 --- a/src/careamics/transforms/tta.py +++ b/src/careamics/transforms/tta.py @@ -1,11 +1,8 @@ """Test-time augmentations.""" -from typing import List - from torch import Tensor, flip, mean, rot90, stack -# TODO add tests class ImageRestorationTTA: """ Test-time augmentation for image restoration tasks. @@ -13,62 +10,79 @@ class ImageRestorationTTA: The augmentation is performed using all 90 deg rotations and their flipped version, as well as the original image flipped. - Tensors should be of shape SC(Z)YX + Tensors should be of shape SC(Z)YX. This transformation is used in the LightningModule in order to perform test-time - agumentation. + augmentation. """ - def __init__(self) -> None: - """Constructor.""" - pass - - def forward(self, x: Tensor) -> List[Tensor]: + def forward(self, input_tensor: Tensor) -> list[Tensor]: """ Apply test-time augmentation to the input tensor. Parameters ---------- - x : Tensor + input_tensor : Tensor Input tensor, shape SC(Z)YX. Returns ------- - List[Tensor] + list of torch.Tensor List of augmented tensors. """ + # axes: only applies to YX axes + axes = (-2, -1) + augmented = [ - x, - rot90(x, 1, dims=(-2, -1)), - rot90(x, 2, dims=(-2, -1)), - rot90(x, 3, dims=(-2, -1)), + # original + input_tensor, + # rotations + rot90(input_tensor, 1, dims=axes), + rot90(input_tensor, 2, dims=axes), + rot90(input_tensor, 3, dims=axes), + # original flipped + flip(input_tensor, dims=(axes[0],)), + flip(input_tensor, dims=(axes[1],)), ] - augmented_flip = augmented.copy() - for x_ in augmented: - augmented_flip.append(flip(x_, dims=(-3, -1))) - return augmented_flip - def backward(self, x: List[Tensor]) -> Tensor: + # rotated once, flipped + augmented.extend( + [ + flip(augmented[1], dims=(axes[0],)), + flip(augmented[1], dims=(axes[1],)), + ] + ) + + return augmented + + def backward(self, x: list[Tensor]) -> Tensor: """Undo the test-time augmentation. Parameters ---------- x : Any - List of augmented tensors. + List of augmented tensors of shape SC(Z)YX. Returns ------- Any Original tensor. """ + axes = (-2, -1) + reverse = [ + # original x[0], - rot90(x[1], -1, dims=(-2, -1)), - rot90(x[2], -2, dims=(-2, -1)), - rot90(x[3], -3, dims=(-2, -1)), - flip(x[4], dims=(-3, -1)), - rot90(flip(x[5], dims=(-3, -1)), -1, dims=(-2, -1)), - rot90(flip(x[6], dims=(-3, -1)), -2, dims=(-2, -1)), - rot90(flip(x[7], dims=(-3, -1)), -3, dims=(-2, -1)), + # rotated + rot90(x[1], -1, dims=axes), + rot90(x[2], -2, dims=axes), + rot90(x[3], -3, dims=axes), + # original flipped + flip(x[4], dims=(axes[0],)), + flip(x[5], dims=(axes[1],)), + # rotated once, flipped + rot90(flip(x[6], dims=(axes[0],)), -1, dims=axes), + rot90(flip(x[7], dims=(axes[1],)), -1, dims=axes), ] + return mean(stack(reverse), dim=0) diff --git a/src/careamics/utils/running_stats.py b/src/careamics/utils/running_stats.py deleted file mode 100644 index 1268d3e43..000000000 --- a/src/careamics/utils/running_stats.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Running stats submodule, used in the Zarr dataset.""" - -# from multiprocessing import Value -# from typing import Tuple - -# import numpy as np - - -# class RunningStats: -# """Calculates running mean and std.""" - -# def __init__(self) -> None: -# self.reset() - -# def reset(self) -> None: -# """Reset the running stats.""" -# self.avg_mean = Value("d", 0) -# self.avg_std = Value("d", 0) -# self.m2 = Value("d", 0) -# self.count = Value("i", 0) - -# def init(self, mean: float, std: float) -> None: -# """Initialize running stats.""" -# with self.avg_mean.get_lock(): -# self.avg_mean.value += mean -# with self.avg_std.get_lock(): -# self.avg_std.value = std - -# def compute_std(self) -> Tuple[float, float]: -# """Compute std.""" -# if self.count.value >= 2: -# self.avg_std.value = np.sqrt(self.m2.value / self.count.value) - -# def update(self, value: float) -> None: -# """Update running stats.""" -# with self.count.get_lock(): -# self.count.value += 1 -# delta = value - self.avg_mean.value -# with self.avg_mean.get_lock(): -# self.avg_mean.value += delta / self.count.value -# delta2 = value - self.avg_mean.value -# with self.m2.get_lock(): -# self.m2.value += delta * delta2 diff --git a/tests/dataset/dataset_utils/test_compute_normalization_stats.py b/tests/dataset/dataset_utils/test_compute_normalization_stats.py index cf6a157ee..d03e527db 100644 --- a/tests/dataset/dataset_utils/test_compute_normalization_stats.py +++ b/tests/dataset/dataset_utils/test_compute_normalization_stats.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from careamics.dataset.dataset_utils import compute_normalization_stats +from careamics.dataset.dataset_utils.running_stats import compute_normalization_stats @pytest.mark.parametrize("samples, channels", [[1, 2], [1, 2]]) diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py index b5c24646a..d3e90febb 100644 --- a/tests/dataset/test_iterable_dataset.py +++ b/tests/dataset/test_iterable_dataset.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pytest import tifffile @@ -145,29 +147,20 @@ def test_extracting_val_files(tmp_path, ordered_array, percentage): ((32, 32, 32), "ZYX", (8, 8, 8)), ], ) -def test_compute_mean_std_transform_iterable( - tmp_path, ordered_array, shape, axes, patch_size -): +def test_compute_mean_std_transform_welford(tmp_path, shape, axes, patch_size): """Test that mean and std are computed and correctly added to the configuration and transform.""" - # create array - n_files = 3 - array = ordered_array(shape) - - # save three files + n_files = 100 files = [] - - # create test array with channel axis - if "C" not in axes: - stacked_array = np.stack([array] * n_files)[:, np.newaxis, ...] - else: - stacked_array = np.stack([array] * n_files) + array = np.random.randint(0, np.iinfo(np.uint16).max, (n_files, *shape)) for i in range(n_files): file = tmp_path / f"array{i}.tif" - tifffile.imwrite(file, array) + tifffile.imwrite(file, array[i]) files.append(file) + array = array[:, np.newaxis, ...] if "C" not in axes else array + # create config config_dict = { "data_type": SupportedData.TIFF.value, @@ -181,9 +174,67 @@ def test_compute_mean_std_transform_iterable( data_config=config, src_files=files, read_source_func=read_tiff ) - axes = tuple(np.delete(np.arange(stacked_array.ndim), 1)) + # define axes for mean and std computation + stats_axes = tuple(np.delete(np.arange(array.ndim), 1)) + + assert np.allclose(array.mean(axis=stats_axes), dataset.data_config.image_means) + assert np.allclose(array.std(axis=stats_axes), dataset.data_config.image_stds) + - assert np.array_equal( - stacked_array.mean(axis=axes), dataset.data_config.image_means +@pytest.mark.parametrize( + "shape, axes, patch_size", + [ + ((32, 32), "YX", (8, 8)), + ((2, 32, 32), "CYX", (8, 8)), + ((32, 32, 32), "ZYX", (8, 8, 8)), + ], +) +def test_compute_mean_std_transform_welford_with_targets( + tmp_path, shape, axes, patch_size +): + """Test that mean and std are computed and correctly added to the configuration + and transform.""" + n_files = 100 + files = [] + target_files = [] + array = np.random.randint(0, np.iinfo(np.uint16).max, (n_files, *shape)) + target_array = np.random.randint(0, np.iinfo(np.uint16).max, (n_files, *shape)) + + for i in range(n_files): + file = tmp_path / "images" / f"array{i}.tif" + target_file = tmp_path / "targets" / f"array{i}.tif" + os.makedirs(file.parent, exist_ok=True) + os.makedirs(target_file.parent, exist_ok=True) + tifffile.imwrite(file, array[i]) + tifffile.imwrite(target_file, target_array[i]) + files.append(file) + target_files.append(target_file) + + array = array[:, np.newaxis, ...] if "C" not in axes else array + target_array = target_array[:, np.newaxis, ...] if "C" not in axes else target_array + + # create config + config_dict = { + "data_type": SupportedData.TIFF.value, + "patch_size": patch_size, + "axes": axes, + } + config = DataConfig(**config_dict) + + # create dataset + dataset = PathIterableDataset( + data_config=config, + src_files=files, + target_files=target_files, + read_source_func=read_tiff, + ) + + # define axes for mean and std computation + stats_axes = tuple(np.delete(np.arange(array.ndim), 1)) + + assert np.allclose( + target_array.mean(axis=stats_axes), dataset.data_config.target_means + ) + assert np.allclose( + target_array.std(axis=stats_axes), dataset.data_config.target_stds ) - assert np.array_equal(stacked_array.std(axis=axes), dataset.data_config.image_stds) diff --git a/tests/test_careamist.py b/tests/test_careamist.py index 49756c451..fea9ed990 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -661,15 +661,20 @@ def test_predict_tiled_channel( assert predicted.squeeze().shape == train_array.shape +@pytest.mark.parametrize("tiled", [True, False]) +@pytest.mark.parametrize("n_samples", [1, 2]) @pytest.mark.parametrize("batch_size", [1, 2]) -def test_predict_path(tmp_path: Path, minimum_configuration: dict, batch_size): +def test_predict_path( + tmp_path: Path, minimum_configuration: dict, batch_size, n_samples, tiled +): """Test that CAREamics can predict with tiff files.""" # training data train_array = random_array((32, 32)) # save files - train_file = tmp_path / "train.tiff" - tifffile.imwrite(train_file, train_array) + for i in range(n_samples): + train_file = tmp_path / f"train_{i}.tiff" + tifffile.imwrite(train_file, train_array) # create configuration config = Configuration(**minimum_configuration) @@ -685,11 +690,27 @@ def test_predict_path(tmp_path: Path, minimum_configuration: dict, batch_size): # train CAREamist careamist.train(train_source=train_file) + if tiled: + tile_size = (16, 16) + tile_overlap = (4, 4) + else: + tile_size = None + tile_overlap = None + # predict CAREamist - predicted = careamist.predict(train_file, batch_size=batch_size) + predicted = careamist.predict( + train_file, + batch_size=batch_size, + tile_size=tile_size, + tile_overlap=tile_overlap, + ) # check that it predicted - assert predicted.squeeze().shape == train_array.shape + if isinstance(predicted, list): + for p in predicted: + assert p.squeeze().shape == train_array.shape + else: + assert predicted.squeeze().shape == train_array.shape # export to BMZ careamist.export_to_bmz( diff --git a/tests/transforms/test_compose.py b/tests/transforms/test_compose.py index de7dcab9f..12f04e2e2 100644 --- a/tests/transforms/test_compose.py +++ b/tests/transforms/test_compose.py @@ -7,7 +7,7 @@ XYFlipModel, XYRandomRotate90Model, ) -from careamics.dataset.dataset_utils import compute_normalization_stats +from careamics.dataset.dataset_utils.running_stats import compute_normalization_stats from careamics.transforms import Compose, Normalize, XYFlip, XYRandomRotate90 diff --git a/tests/transforms/test_normalize.py b/tests/transforms/test_normalize.py index ee8303eba..98ec36420 100644 --- a/tests/transforms/test_normalize.py +++ b/tests/transforms/test_normalize.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from careamics.dataset.dataset_utils import compute_normalization_stats +from careamics.dataset.dataset_utils.running_stats import compute_normalization_stats from careamics.transforms import Denormalize, Normalize from careamics.transforms.normalize import _reshape_stats diff --git a/tests/transforms/test_tta.py b/tests/transforms/test_tta.py new file mode 100644 index 000000000..49ee266a7 --- /dev/null +++ b/tests/transforms/test_tta.py @@ -0,0 +1,105 @@ +import numpy as np +import pytest +import torch + +from careamics.transforms import ImageRestorationTTA, XYFlip, XYRandomRotate90 + + +@pytest.mark.parametrize( + "shape", + [ + # 2D + (1, 1, 8, 8), + (2, 1, 8, 8), + (1, 3, 8, 8), + (2, 3, 8, 8), + # 3D + (1, 1, 8, 8, 8), + (2, 1, 8, 8, 8), + (1, 3, 8, 8, 8), + (2, 3, 8, 8, 8), + ], +) +def test_forward(shape): + """Test that the transformations are correct.""" + array = np.arange(np.prod(shape)).reshape(shape) + tensor = torch.Tensor(array) + + tta = ImageRestorationTTA() + augmented = tta.forward(tensor) + assert len(augmented) == 8 + + # check that the shape is the same + for aug in augmented: + assert aug.shape == tensor.shape + + # check that all transformed tensors are unique + for i, aug1 in enumerate(augmented): + for j, aug2 in enumerate(augmented): + if i != j: + assert not torch.allclose(aug1, aug2) + + +@pytest.mark.parametrize( + "shape", + [ + # 2D + (1, 1, 8, 8), + (1, 3, 8, 8), + # 3D + (1, 1, 8, 8, 8), + (1, 3, 8, 8, 8), + ], +) +def test_same_transforms(shape): + """Check that all arrays produced by the rotation and flip + transforms are generated by the TTA. + + Note that the transforms require no sample dimension. + """ + array = np.arange(np.prod(shape)).reshape(shape) + tensor = torch.Tensor(array) + + # apply forward transformation + tta = ImageRestorationTTA() + augmented = tta.forward(tensor) + + # transofrms + rot = XYRandomRotate90(seed=42) + flip = XYFlip(seed=42) + + for _ in range(100): + rotated, _ = rot(tensor) + flipped, _ = flip(rotated) + + # check that is in the augmented list + assert any(torch.allclose(torch.Tensor(flipped), aug) for aug in augmented) + + +@pytest.mark.parametrize( + "shape", + [ + # 2D + (1, 1, 8, 8), + (2, 1, 8, 8), + (1, 3, 8, 8), + (2, 3, 8, 8), + # 3D + (1, 1, 8, 8, 8), + (2, 1, 8, 8, 8), + (1, 3, 8, 8, 8), + (2, 3, 8, 8, 8), + ], +) +def test_backward(shape): + """Test that the backward transformation returns the original tensor.""" + array = np.arange(np.prod(shape)).reshape(shape) + tensor = torch.Tensor(array) + + # apply forward transformation + tta = ImageRestorationTTA() + augmented = tta.forward(tensor) + + # apply backward transformation + original = tta.backward(augmented) + assert torch.allclose(tensor, original)