diff --git a/minerva/__init__.py b/minerva/__init__.py index e69de29..7b1cdcd 100644 --- a/minerva/__init__.py +++ b/minerva/__init__.py @@ -0,0 +1,10 @@ + +import minerva +import minerva.analysis +import minerva.callbacks +import minerva.data +import minerva.losses +import minerva.models +import minerva.pipelines +import minerva.transforms +import minerva.utils diff --git a/minerva/analysis/metrics/transformed_metrics.py b/minerva/analysis/metrics/transformed_metrics.py new file mode 100644 index 0000000..e3d3e56 --- /dev/null +++ b/minerva/analysis/metrics/transformed_metrics.py @@ -0,0 +1,191 @@ +import warnings +from typing import Optional + +import torch +from torchmetrics import Metric + + +class CroppedMetric(Metric): + def __init__( + self, + target_h_size: int, + target_w_size: int, + metric: Metric, + dist_sync_on_step: bool = False, + ): + """ + Initializes a new instance of CroppedMetric. + + Parameters + ---------- + target_h_size: int + The target height size. + target_w_size: int + The target width size. + dist_sync_on_step: bool, optional + Whether to synchronize metric state across processes at each step. + Defaults to False. + """ + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.metric = metric + self.target_h_size = target_h_size + self.target_w_size = target_w_size + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Updates the metric state with the predictions and targets. + + Parameters + ---------- + preds: torch.Tensor + The predicted tensor. + target: + torch.Tensor The target tensor. + """ + + preds = self.crop(preds) + target = self.crop(target) + self.metric.update(preds, target) + + def compute(self) -> float: + """ + Computes the cropped metric. + + Returns: + float: The cropped metric. + """ + return self.metric.compute() + + def crop(self, x: torch.Tensor) -> torch.Tensor: + """crops the input tensor to the target size. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + The cropped tensor. + """ + h, w = x.shape[-2:] + start_h = (h - self.target_h_size) // 2 + start_w = (w - self.target_w_size) // 2 + end_h = start_h + self.target_h_size + end_w = start_w + self.target_w_size + + return x[..., start_h:end_h, start_w:end_w] + + +class ResizedMetric(Metric): + def __init__( + self, + target_h_size: Optional[int], + target_w_size: Optional[int], + metric: Metric, + keep_aspect_ratio: bool = False, + dist_sync_on_step: bool = False, + ): + """ + Initializes a new instance of ResizeMetric. + + Parameters + ---------- + target_h_size: int + The target height size. + target_w_size: int + The target width size. + dist_sync_on_step: bool, optional + Whether to synchronize metric state across processes at each step. + Defaults to False. + """ + super().__init__(dist_sync_on_step=dist_sync_on_step) + + if target_h_size is None and target_w_size is None: + raise ValueError( + "At least one of target_h_size or target_w_size must be provided." + ) + + if ( + target_h_size is not None and target_w_size is None + ) and keep_aspect_ratio is False: + warnings.warn( + "A target_w_size is not provided, but keep_aspect_ratio is set to False. keep_aspect_ratio will be set to True. If you want to resize the image to a specific width, please provide a target_w_size." + ) + keep_aspect_ratio = True + + if ( + target_w_size is not None and target_h_size is None + ) and keep_aspect_ratio is False: + warnings.warn( + "A target_h_size is not provided, but keep_aspect_ratio is set to False. 
keep_aspect_ratio will be set to True. If you want to resize the image to a specific height, please provide a target_h_size." + ) + keep_aspect_ratio = True + + self.metric = metric + self.target_h_size = target_h_size + self.target_w_size = target_w_size + self.keep_aspect_ratio = keep_aspect_ratio + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Updates the metric state with the predictions and targets. + + Parameters + ---------- + preds: torch.Tensor + The predicted tensor. + target: + torch.Tensor The target tensor. + """ + + preds = self.resize(preds) + target = self.resize(target) + self.metric.update(preds, target) + + def compute(self) -> float: + """ + Computes the resized metric. + + Returns: + float: The resized metric. + """ + return self.metric.compute() + + def resize(self, x: torch.Tensor) -> torch.Tensor: + """Resizes the input tensor to the target size. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + The resized tensor. + """ + h, w = x.shape[-2:] + + target_h_size = self.target_h_size + target_w_size = self.target_w_size + if self.keep_aspect_ratio: + if self.target_h_size is None: + scale = target_w_size / w + target_h_size = int(h * scale) + elif self.target_w_size is None: + scale = target_h_size / h + target_w_size = int(w * scale) + type_convert = False + if "LongTensor" in x.type(): + x = x.to(torch.uint8) + type_convert = True + + return ( + torch.nn.functional.interpolate(x, size=(target_h_size, target_w_size)) + if not type_convert + else torch.nn.functional.interpolate( + x, size=(target_h_size, target_w_size) + ).to(torch.long) + ) diff --git a/minerva/callbacks/HyperSearchCallbacks.py b/minerva/callbacks/HyperSearchCallbacks.py new file mode 100644 index 0000000..e24e790 --- /dev/null +++ b/minerva/callbacks/HyperSearchCallbacks.py @@ -0,0 +1,108 @@ +import os +import shutil +import tempfile +from pathlib import Path + +import lightning.pytorch as L +from ray import train +from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray.train import Checkpoint + + +class TrainerReportOnIntervalCallback(L.Callback): + + CHECKPOINT_NAME = "checkpoint.ckpt" + + def __init__(self, interval: int = 1) -> None: + super().__init__() + self.trial_name = train.get_context().get_trial_name() + self.local_rank = train.get_context().get_local_rank() + self.tmpdir_prefix = Path(tempfile.gettempdir(), self.trial_name).as_posix() + self.interval = interval + self.step = 0 + if os.path.isdir(self.tmpdir_prefix) and self.local_rank == 0: + shutil.rmtree(self.tmpdir_prefix) + + record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1") + + def on_train_epoch_end( + self, trainer: L.Trainer, pl_module: L.LightningModule + ) -> None: + + # Fetch metrics + metrics = trainer.callback_metrics + metrics = {k: v.item() for k, v in metrics.items()} + + # (Optional) Add customized metrics + metrics["epoch"] = trainer.current_epoch + metrics["step"] = trainer.global_step + + tmpdir = Path(self.tmpdir_prefix, str(trainer.current_epoch)).as_posix() + os.makedirs(tmpdir, exist_ok=True) + + if self.step % self.interval == 0: + + # Save checkpoint to local + ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix() + trainer.save_checkpoint(ckpt_path, weights_only=False) + + # Report to train session + checkpoint = Checkpoint.from_directory(tmpdir) + train.report(metrics=metrics, checkpoint=checkpoint) + else: + train.report(metrics=metrics) + + # Add a barrier to ensure all 
workers finished reporting here + trainer.strategy.barrier() + + if self.local_rank == 0: + shutil.rmtree(tmpdir) + + self.step += 1 + + +class TrainerReportKeepOnlyLastCallback(L.Callback): + + CHECKPOINT_NAME = "checkpoint.ckpt" + + def __init__(self) -> None: + super().__init__() + self.trial_name = train.get_context().get_trial_name() + self.local_rank = train.get_context().get_local_rank() + self.tmpdir_prefix = Path(tempfile.gettempdir(), self.trial_name).as_posix() + if os.path.isdir(self.tmpdir_prefix) and self.local_rank == 0: + shutil.rmtree(self.tmpdir_prefix) + + record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1") + + def on_train_epoch_end( + self, trainer: L.Trainer, pl_module: L.LightningModule + ) -> None: + # Fetch metrics + metrics = trainer.callback_metrics + metrics = {k: v.item() for k, v in metrics.items()} + + # (Optional) Add customized metrics + metrics["epoch"] = trainer.current_epoch + metrics["step"] = trainer.global_step + + tmpdir = Path(self.tmpdir_prefix, "last").as_posix() + os.makedirs(tmpdir, exist_ok=True) + + # Delete previous checkpoint + if os.path.isdir(tmpdir): + shutil.rmtree(tmpdir) + + # Save checkpoint to local + ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix() + trainer.save_checkpoint(ckpt_path, weights_only=False) + + # Report to train session + checkpoint = Checkpoint.from_directory(tmpdir) + train.report(metrics=metrics, checkpoint=checkpoint) + + # Add a barrier to ensure all workers finished reporting here + trainer.strategy.barrier() + + if self.local_rank == 0: + shutil.rmtree(tmpdir) diff --git a/minerva/callbacks/__init__.py b/minerva/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/minerva/data/datasets/supervised_dataset.py b/minerva/data/datasets/supervised_dataset.py index cccf98e..7160077 100644 --- a/minerva/data/datasets/supervised_dataset.py +++ b/minerva/data/datasets/supervised_dataset.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import numpy as np @@ -15,7 +15,7 @@ class SupervisedReconstructionDataset(SimpleDataset): Usually, both input and target data have the same shape. This dataset is useful for supervised tasks such as image reconstruction, - segmantic segmentation, and object detection, where the input data is the + semantic segmentation, and object detection, where the input data is the original data and the target is a mask or a segmentation map. Examples @@ -45,7 +45,12 @@ class SupervisedReconstructionDataset(SimpleDataset): ``` """ - def __init__(self, readers: List[_Reader], transforms: Optional[_Transform] = None): + def __init__( + self, + readers: List[_Reader], + transforms: Optional[_Transform] = None, + support_context_transforms: bool = False, + ): """A simple dataset class for supervised reconstruction tasks. Parameters @@ -62,12 +67,13 @@ def __init__(self, readers: List[_Reader], transforms: Optional[_Transform] = No AssertionError: If the number of readers is not exactly 2. """ super().__init__(readers, transforms) + self.support_context_transforms = support_context_transforms assert ( len(self.readers) == 2 ), "SupervisedReconstructionDataset requires exactly 2 readers" - def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: int) -> Tuple[Any, Any]: """Load data from sources and apply specified transforms. The same transform is applied to both input and target data. 
@@ -78,10 +84,29 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
 
         Returns
         -------
-        Tuple[np.ndarray, np.ndarray]
-            A tuple containing two numpy arrays representing the data.
+        Tuple[Any, Any]
+            A tuple containing two elements: the input data and the target data.
 
         """
-        data = super().__getitem__(index)
-
-        return (data[0], data[1])
+        if not self.support_context_transforms:
+            data = super().__getitem__(index)
+
+            return (data[0], data[1])
+        else:
+
+            data = []
+
+            # For each reader and transform, read the data and apply the transform.
+            # Then, append the transformed data to the list of data.
+            for reader, transform in zip(reversed(self.readers), self.transforms):
+                sample = reader[index]
+                # Apply the transform if it is not None
+                if transform is not None:
+                    sample = transform(sample)
+                data.append(sample)
+            # Return the list of transformed data, or a single sample if
+            # return_single is True and there is only one reader.
+            if self.return_single:
+                return data[1]
+            else:
+                return tuple(reversed(data))
diff --git a/minerva/engines/engine.py b/minerva/engines/engine.py
new file mode 100644
index 0000000..017c20c
--- /dev/null
+++ b/minerva/engines/engine.py
@@ -0,0 +1,23 @@
+from typing import Any, Union
+
+import lightning.pytorch as L
+import numpy as np
+import torch
+
+
+class _Engine:
+    """Main interface for Engine classes. Engines are used to alter the behavior of a model's prediction.
+    An engine should be able to take a `model` and input data `x` and return a prediction.
+    A use case for Engines is patched inference, where the model's default input size is smaller than the desired input size.
+    The engine can be used to make predictions in patches and combine these predictions into a single output.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __call__(
+        self,
+        model: Union[L.LightningModule, torch.nn.Module],
+        x: Union[torch.Tensor, np.ndarray],
+    ):
+        raise NotImplementedError
diff --git a/minerva/engines/patch_inferencer_engine.py b/minerva/engines/patch_inferencer_engine.py
index f88147c..a4363bd 100644
--- a/minerva/engines/patch_inferencer_engine.py
+++ b/minerva/engines/patch_inferencer_engine.py
@@ -1,48 +1,216 @@
-from typing import List, Tuple, Optional, Dict, Any
-import torch
-import numpy as np
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
 import lightning as L
+import numpy as np
+import torch
 
+from minerva.engines.engine import _Engine
+from minerva.models.nets.base import SimpleSupervisedModel
 
-class BasePatchInferencer:
-    """Inference in patches for models
 
-    This class provides utility methods for performing inference in patches
+class PatchInferencer(L.LightningModule):
+    """This class acts as a normal `L.LightningModule` that wraps a
+    `SimpleSupervisedModel` model, allowing it to perform inference in patches.
+    This is useful when the model's default input size is smaller than the
+    desired input size (sample size). In this case, the engine splits the input
+    tensor into patches, performs inference on each patch, and combines the
+    results into a single output of the desired size. The combination of
+    patches can be parametrized by a `weight_function`, allowing a customizable
+    combination of patches (e.g., combining using a weighted average). It is
+    important to note that only the model's forward method is wrapped; thus,
+    any method that relies on it (e.g., training_step, predict_step) is
+    performed in patches, transparently to the user.
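+
+    Examples
+    --------
+    A minimal, illustrative sketch of wrapping a trained model for patched
+    inference. The `model` variable, shapes, and padding values below are
+    assumed example values, taken from the parameter descriptions that follow:
+
+    ```python
+    inferencer = PatchInferencer(
+        model=model,                     # assumed: a trained SimpleSupervisedModel
+        input_shape=(1, 128, 128),       # patch size expected by the wrapped model
+        output_shape=(10, 1, 128, 128),  # 10-class logits per 128x128 patch
+        padding={"pad": (0, 512, 512), "mode": "reflect"},
+    )
+    y_hat = inferencer(x)  # x: assumed tensor of shape (N, 1, 512, 512)
+    ```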
""" def __init__( self, - model: L.LightningModule, - input_shape: Tuple, - output_shape: Optional[Tuple] = None, - weight_function: Optional[callable] = None, - offsets: Optional[List[Tuple]] = None, + model: SimpleSupervisedModel, + input_shape: Tuple[int, ...], + output_shape: Optional[Tuple[int, ...]] = None, + weight_function: Optional[Callable[[Tuple[int, ...]], torch.Tensor]] = None, + offsets: Optional[List[Tuple[int, ...]]] = None, padding: Optional[Dict[str, Any]] = None, + return_tuple: Optional[int] = None, ): - """Initialize the patch inference auxiliary class + """Wrap a `SimpleSupervisedModel` model's forward method to perform + inference in patches, transparently splitting the input tensor into + patches, performing inference in each patch, and combining them into a + single output of the desired size. Parameters ---------- - model : L.LightningModule - Model used in inference. - input_shape : Tuple - Expected input shape of the model - output_shape : Tuple, optional - Expected output shape of the model. Defaults to input_shape - weight_function: callable, optional - Function that receives a tensor shape and returns the weights for each position of a tensor with the given shape - Useful when regions of the inference present diminishing performance when getting closer to borders, for instance. - offsets : Tuple, optional - List of tuples with offsets that determine the shift of the initial position of the patch subdivision + model : SimpleSupervisedModel + Model to be wrapped. + input_shape : Tuple[int, ...] + Expected input shape of the wrapped model. + output_shape : Tuple[int, ...], optional + Expected output shape of the wrapped model. For models that return + logits (e.g., classification models), the `output_shape` must + include an additional dimension at the beginning to accommodate + the number of output classes. For example, if the model processes + an input tensor of shape (1, 128, 128) and outputs logits for 10 + classes, the expected `output_shape` should be (10, 1, 128, 128). + If the model does not return logits (e.g., return a tensor after + applying an `argmax` operation, or a regression models that usually + returns a tensor with the same shape as the input tensor), the + `output_shape` should have the same number of dimensions as the + input shape. Defaults to None, which assumes the output shape is + the same as the `input_shape` parameter. + weight_function: Callable[[Tuple[int, ...]], torch.Tensor], optional + Function that receives a tensor shape and returns the weights for + each position of a tensor with the given shape. Useful when regions + of the inference present diminishing performance when getting + closer to borders, for instance. + offsets : List[Tuple[int, ...]], optional + List of tuples with offsets that determine the shift of the initial + position of the patch subdivision. padding : Dict[str, Any], optional Dictionary describing padding strategy. Keys: - pad: tuple with pad width (int) for each dimension, e.g. (0, 3, 3) when working with a tensor with 3 dimensions - mode (optional): 'constant', 'reflect', 'replicate' or 'cicular'. Defaults to 'constant'. - value (optional): fill value for 'constante'. Defaults to 0. + - pad (mandatory): tuple with pad width (int) for each + dimension, e.g.(0, 3, 3) when working with a tensor with 3 + dimensions. + - mode (optional): 'constant', 'reflect', 'replicate' or + 'circular'. Defaults to 'constant'. + - value (optional): fill value for 'constant'. Defaults to 0. + If None, no padding is applied. 
+ return_tuple: int, optional + Some models may return multiple outputs for a single sample (e.g., + outputs from multiple auxiliary heads). This parameter is a integer + that defines the number of outputs the model generates. By default, + it is None, which indicates that the model produces a single output + for a single input. When set, it indicates the number of outputs + the model produces. """ + super().__init__() self.model = model - self.input_shape = input_shape - self.output_shape = output_shape if output_shape is not None else input_shape + self.patch_inferencer = PatchInferencerEngine( + input_shape, + output_shape, + offsets, + padding, + weight_function, + return_tuple, + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.forward(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform inference in patches. + + Parameters + ---------- + x : torch.Tensor + Batch of input data. + """ + return self.patch_inferencer(self.model, x) + + def _single_step( + self, batch: torch.Tensor, batch_idx: int, step_name: str + ) -> torch.Tensor: + """Perform a single step of the training/validation loop. + + Parameters + ---------- + batch : torch.Tensor + The input data. + batch_idx : int + The index of the batch. + step_name : str + The name of the step, either "train" or "val". + + Returns + ------- + torch.Tensor + The loss value. + """ + x, y = batch + y_hat = self.forward(x.float()) + loss = self.model._loss_func(y_hat, y.squeeze(1)) + + metrics = self.model._compute_metrics(y_hat, y, step_name) + for metric_name, metric_value in metrics.items(): + self.log( + metric_name, + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + self.log( + f"{step_name}_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + return loss + + def training_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "train") + + def validation_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "val") + + def test_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "test") + + +class PatchInferencerEngine(_Engine): + def __init__( + self, + input_shape: Tuple[int, ...], + output_shape: Optional[Tuple[int, ...]] = None, + offsets: Optional[List[Tuple[int, ...]]] = None, + padding: Optional[Dict[str, Any]] = None, + weight_function: Optional[Callable] = None, + return_tuple: Optional[int] = None, + ): + """ + Parameters + ---------- + input_shape : Tuple[int, ...] + Shape of each patch to process. + output_shape : Tuple[int, ...], optional + Expected output shape of the model. For models that return logits, + the `output_shape` must include an additional dimension at the + beginning to accommodate the number of output classes. Else, the + `output_shape` should have the same number of dimensions as the + `input_shape` (i.e., no logits are returned). Defaults to + input_shape. + padding : Dict[str, Any], optional + Padding configuration with keys: + - 'pad': Tuple of padding for each expected final dimension, + e.g., (0, 512, 512) - (c, h, w). + - 'mode': Padding mode, e.g., 'constant', 'reflect'. + - 'value': Padding value if mode is 'constant'. + Defaults to None, which means no padding is applyied. + weight_function : Callable, optional + Function to calculate the weight of each patch. Defaults to None. 
+ return_tuple : int, optional + Number of outputs to return. This is useful when the model returns + multiple outputs for a single input (e.g., from multiple auxiliary + heads). Defaults to None. + """ + self.input_shape = (1, *input_shape) + self.output_shape = ( + (1, *output_shape) if output_shape is not None else self.input_shape + ) + + # Check if possible classification task (has logits) + self.logits_dim = len(self.input_shape) != len(self.output_shape) + self.output_simplified_shape = ( + tuple([*self.output_shape[:1], *self.output_shape[2:]]) + if self.logits_dim + else self.output_shape + ) + self.weight_function = weight_function if offsets is not None: @@ -59,45 +227,49 @@ def __init__( padding["pad"] ), f"Pad tuple does not match expected size ({len(input_shape)})" self.padding = padding + self.padding["pad"] = (0, *self.padding["pad"]) else: - self.padding = {"pad": tuple([0] * len(input_shape))} - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - return self.forward(x) + self.padding = {"pad": tuple([0] * (len(input_shape) + 1))} + self.return_tuple = return_tuple def _reconstruct_patches( self, patches: torch.Tensor, index: Tuple[int], - weights: bool, - inner_dim: int = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Rearranges patches to reconstruct area of interest from patches and weights + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Rearranges patches to reconstruct area of interest from patches and + weights. """ + index = tuple([index[0], 1, *index[1:]]) if self.logits_dim else index reconstruct_shape = np.array(self.output_shape) * np.array(index) - if weights: - weight = torch.zeros(tuple(reconstruct_shape)) - base_weight = ( - self.weight_function(self.input_shape) - if self.weight_function - else torch.ones(self.input_shape) + + weight = ( + torch.zeros( + tuple([*reconstruct_shape[:1], *reconstruct_shape[2:]]), + device=patches.device, ) - else: - weight = None - if inner_dim is not None: - reconstruct_shape = np.append(reconstruct_shape, inner_dim) - reconstruct = torch.zeros(tuple(reconstruct_shape)) + if self.logits_dim + else torch.zeros(tuple(reconstruct_shape), device=patches.device) + ) + + base_weight = ( + self.weight_function(self.output_simplified_shape) + if self.weight_function + else torch.ones(self.output_simplified_shape, device=patches.device) + ) + + reconstruct = torch.zeros(tuple(reconstruct_shape), device=patches.device) for patch_index, patch in zip(np.ndindex(index), patches): sl = [ slice(idx * patch_len, (idx + 1) * patch_len, None) - for idx, patch_len in zip(patch_index, self.input_shape) + for idx, patch_len in zip(patch_index, self.output_shape) ] - if weights: - weight[tuple(sl)] = base_weight - if inner_dim is not None: - sl.append(slice(None, None, None)) reconstruct[tuple(sl)] = patch + if self.logits_dim: + sl.pop(1) + weight[tuple(sl)] = base_weight + if self.logits_dim: + weight = weight.unsqueeze(1) return reconstruct, weight def _adjust_patches( @@ -107,33 +279,24 @@ def _adjust_patches( offset: Tuple[int], pad_value: int = 0, ) -> List[torch.Tensor]: + """Pads reconstructed patches with `pad_value` to have same shape as + the reference shape from the base patch set. 
""" - Pads reconstructed_patches with 'pad_value' to have same shape as the reference shape from the base patch set - """ - has_inner_dim = len(offset) < len(arrays[0].shape) pad_width = [] sl = [] ref_shape = list(ref_shape) arr_shape = list(arrays[0].shape) - if has_inner_dim: - arr_shape = arr_shape[:-1] - for idx, lenght, ref in zip(offset, arr_shape, ref_shape): + adjusted_offset = [0, 0, *offset] if self.logits_dim else [0, *offset] + for idx, length, ref in zip(adjusted_offset, arr_shape, ref_shape): if idx > 0: - sl.append(slice(0, min(lenght, ref), None)) - pad_width = [idx, max(ref - lenght - idx, 0)] + pad_width + sl.append(slice(0, min(length, ref - idx), None)) + pad_width = [idx, max(ref - length - idx, 0)] + pad_width else: - sl.append(slice(np.abs(idx), min(lenght, ref - idx), None)) - pad_width = [0, max(ref - lenght - idx, 0)] + pad_width + sl.append(slice(np.abs(idx), min(length, ref - idx), None)) + pad_width = [0, max(ref - length - idx, 0)] + pad_width adjusted = [ ( torch.nn.functional.pad( - arr[tuple([*sl, slice(None, None, None)])], - pad=tuple([0, 0, *pad_width]), - mode="constant", - value=pad_value, - ) - if has_inner_dim - else torch.nn.functional.pad( arr[tuple(sl)], pad=tuple(pad_width), mode="constant", @@ -150,18 +313,25 @@ def _combine_patches( offsets: List[Tuple[int]], indexes: List[Tuple[int]], ) -> torch.Tensor: - """ - How results are combined is dependent on what is being combined. - RegressionPatchInferencer uses Weighted Average - ClassificationPatchInferencer uses Voting (hard or soft) - """ - raise NotImplementedError("Combine patches method must be implemented") + """Performs the combination of patches based on the weight function.""" + reconstructed = [] + weights = [] + for patches, offset, shape in zip(results, offsets, indexes): + reconstruct, weight = self._reconstruct_patches(patches, shape) + reconstruct, weight = self._adjust_patches( + [reconstruct, weight], self.ref_shape, offset + ) + reconstructed.append(reconstruct) + weights.append(weight) + reconstructed = torch.stack(reconstructed, dim=0) + weights = torch.stack(weights, dim=0) + return torch.sum(reconstructed * weights, dim=0) / torch.sum(weights, dim=0) def _extract_patches( self, data: torch.Tensor, patch_shape: Tuple[int] ) -> Tuple[torch.Tensor, Tuple[int]]: - """ - Patch extraction method. It will be called once for the base patch set and also for the requested offsets (overlapping patch sets) + """Patch extraction method. It will be called once for the base patch + set and also for the requested offsets (overlapping patch sets). """ indexes = tuple(np.array(data.shape) // np.array(patch_shape)) patches = [] @@ -174,42 +344,64 @@ def _extract_patches( return torch.stack(patches), indexes def _compute_output_shape(self, tensor: torch.Tensor) -> Tuple[int]: - """ - Computes PatchInferencer output shape based on input tensor shape, and model's input and output shapes. + """Computes `PatchInferencer` output shape based on input tensor shape, + and model's input and output shapes. 
""" if self.input_shape == self.output_shape: return tensor.shape shape = [] - for i, o, t in zip(self.input_shape, self.output_shape, tensor.shape): + for i, o, t in zip( + self.input_shape, self.output_simplified_shape, tensor.shape + ): if i != o: shape.append(int(t * o // i)) else: shape.append(t) + + if self.logits_dim: + shape.insert(1, self.output_shape[1]) + return tuple(shape) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def _compute_base_padding(self, tensor: torch.Tensor): + """Computes the padding for the base patch set based on the input + tensor shape and the model's input shape. """ - Perform Inference in Patches + padding = [0, 0] + for i, t in zip(self.padding["pad"][2:], tensor.shape[2:]): + padding.append(max(0, i - t)) + return padding + + def __call__( + self, model: Union[L.LightningModule, torch.nn.Module], x: torch.Tensor + ): + """Perform inference in patches, from the input tensor `x` using the + model `model`. Parameters ---------- + model: Union[L.LightningModule, torch.nn.Module] + Model to perform inference. x : torch.Tensor - Input Tensor. + Input tensor of the sample. It can be a single sample or a batch + of samples. """ - assert len(x.shape) == len( - self.input_shape - ), "Input and self.input_shape sizes must match" + if len(x.shape) == len(self.input_shape) - 1: + x = x.unsqueeze(0) + elif len(x.shape) == len(self.input_shape): + pass + else: + raise RuntimeError("Invalid input shape") self.ref_shape = self._compute_output_shape(x) offsets = list(self.offsets) - base = self.padding["pad"] - offsets.insert(0, tuple([0] * len(base))) - + base = self._compute_base_padding(x) + offsets.insert(0, tuple([0] * (len(base) - 1))) slices = [ tuple( [ - slice(i + base, None) # TODO: if ((i + base >= 0) and (i < in_dim)) - for i, base, in_dim in zip(offset, base, x.shape) + slice(i, None) # TODO: if ((i + base >= 0) and (i < in_dim)) + for i, in_dim in zip([0, *offset], x.shape) ] ) for offset in offsets @@ -217,153 +409,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch_pad = [] for pad_value in reversed(base): - torch_pad = torch_pad + [pad_value, pad_value] + torch_pad = torch_pad + [0, pad_value] x_padded = torch.nn.functional.pad( x, pad=tuple(torch_pad), mode=self.padding.get("mode", "constant"), value=self.padding.get("value", 0), ) - results = [] + results = ( + tuple([] for _ in range(self.return_tuple)) if self.return_tuple else [] + ) indexes = [] for sl in slices: patch_set, patch_idx = self._extract_patches(x_padded[sl], self.input_shape) - results.append(self.model(patch_set)) + patch_set = patch_set.squeeze(1) + inference = model(patch_set) + if self.return_tuple: + for i in range(self.return_tuple): + results[i].append(inference[i]) + else: + results.append(inference) indexes.append(patch_idx) - output_slice = tuple( - [slice(0, lenght) for lenght in x.shape] - ) - return self._combine_patches(results, offsets, indexes)[output_slice] - - -class WeightedAvgPatchInferencer(BasePatchInferencer): - """ - PatchInferencer with Weighted Average combination function. 
- """ - - def _combine_patches( - self, - results: List[torch.Tensor], - offsets: List[Tuple[int]], - indexes: List[Tuple[int]], - ) -> torch.Tensor: - reconstructed = [] - weights = [] - for patches, offset, shape in zip(results, offsets, indexes): - reconstruct, weight = self._reconstruct_patches( - patches, shape, weights=True - ) - reconstruct, weight = self._adjust_patches( - [reconstruct, weight], self.ref_shape, offset - ) - - reconstructed.append(reconstruct) - weights.append(weight) - reconstructed = torch.stack(reconstructed, dim=0) - weights = torch.stack(weights, dim=0) - return torch.sum(reconstructed * weights, dim=0) / torch.sum(weights, dim=0) - - -class VotingPatchInferencer(BasePatchInferencer): - """ - PatchInferencer with Voting combination function. - Note: Models used with VotingPatchInferencer must return class probabilities in inner dimension - """ - - def __init__( - self, - model: L.LightningModule, - num_classes: int, - input_shape: Tuple, - output_shape: Optional[Tuple] = None, - weight_function: Optional[callable] = None, - offsets: Optional[List[Tuple]] = None, - padding: Optional[Dict[str, Any]] = None, - voting: str = "soft", - ): - """Initialize the patch inference auxiliary class - - Parameters - ---------- - model : L.LightningModule - Model used in inference. - num_classes: int - number of classes of the classification task - input_shape : Tuple - Expected input shape of the model - output_shape : Tuple, optional - Expected output shape of the model. Defaults to input_shape - weight_function: callable, optional - Function that receives a tensor shape and returns the weights for each position of a tensor with the given shape - Useful when regions of the inference present diminishing performance when getting closer to borders, for instance. - offsets : Tuple, optional - List of tuples with offsets that determine the shift of the initial position of the patch subdivision - padding : Dict[str, Any], optional - Dictionary describing padding strategy. Keys: - pad: tuple with pad width (int) for each dimension, e.g. (0, 3, 3) when working with a tensor with 3 dimensions - mode (optional): 'constant', 'reflect', 'replicate' or 'cicular'. Defaults to 'constant'. - value (optional): fill value for 'constante'. Defaults to 0. - voting: str - voting method to use, can be either 'soft'or 'hard'. Defaults to 'soft'. 
- """ - super().__init__( - model, input_shape, output_shape, weight_function, offsets, padding - ) - assert voting in ["soft", "hard"], "voting should be either 'soft' or 'hard'" - self.num_classes = num_classes - self.voting = voting - - def _combine_patches( - self, - results: List[torch.Tensor], - offsets: List[Tuple[int]], - indexes: List[Tuple[int]], - ) -> torch.Tensor: - voting_method = getattr(self, f"_{self.voting}_voting") - return voting_method(results, offsets, indexes) - - def _hard_voting( - self, - results: List[torch.Tensor], - offsets: List[Tuple[int]], - indexes: List[Tuple[int]], - ) -> torch.Tensor: - """ - Hard voting combination function - """ - # torch.mode does not work like scipy.stats.mode - raise NotImplementedError("Hard voting not yet supported") - # reconstructed = [] - # for patches, offset, shape in zip(results, offsets, indexes): - # reconstruct, _ = self._reconstruct_patches( - # patches, shape, weights=False, inner_dim=self.num_classes - # ) - # reconstruct = torch.argmax(reconstruct, dim=-1).float() - # reconstruct = self._adjust_patches( - # [reconstruct], self.ref_shape, offset, pad_value=torch.nan - # )[0] - # reconstructed.append(reconstruct) - # reconstructed = torch.stack(reconstructed, dim=0) - # ret = torch.mode(reconstructed, dim=0, keepdims=False)[ - # 0 - # ] # TODO check behaviour on GPU, according to issues may have nonsense results - # return ret - - def _soft_voting( - self, - results: List[torch.Tensor], - offsets: List[Tuple[int]], - indexes: List[Tuple[int]], - ) -> torch.Tensor: - """ - Soft voting combination function - """ - reconstructed = [] - for patches, offset, shape in zip(results, offsets, indexes): - reconstruct, _ = self._reconstruct_patches( - patches, shape, weights=False, inner_dim=self.num_classes - ) - reconstruct = self._adjust_patches([reconstruct], self.ref_shape, offset)[0] - reconstructed.append(reconstruct) - reconstructed = torch.stack(reconstructed, dim=0) - return torch.argmax(torch.sum(reconstructed, dim=0), dim=-1) + output_slice = tuple([slice(0, length) for length in self.ref_shape]) + if self.return_tuple: + comb_list = [] + for i in range(self.return_tuple): + comb = self._combine_patches(results[i], offsets, indexes) + comb = comb[output_slice] + comb_list.append(comb) + comb = tuple(comb_list) + else: + comb = self._combine_patches(results, offsets, indexes) + comb = comb[output_slice] + return comb diff --git a/minerva/models/nets/image/setr.py b/minerva/models/nets/image/setr.py index f9fe225..797c847 100644 --- a/minerva/models/nets/image/setr.py +++ b/minerva/models/nets/image/setr.py @@ -1,20 +1,20 @@ import warnings from typing import Dict, List, Optional, Tuple, Union -import lightning as L +import lightning.pytorch as L import torch from torch import nn +from torch.optim.adam import Adam from torchmetrics import Metric +from minerva.engines.engine import _Engine from minerva.models.nets.image.vit import _VisionTransformerBackbone from minerva.utils.upsample import Upsample class _SETRUPHead(nn.Module): - """Naive upsampling head and Progressive upsampling head of SETR. - - Naive or PUP head of `SETR `_. - + """Naive upsampling head and Progressive upsampling head of SETR + (as in https://arxiv.org/pdf/2012.15840.pdf). """ def __init__( @@ -32,8 +32,7 @@ def __init__( dropout: float, interpolate_mode: str, ): - """ - Initializes the SETR model. + """The SETR PUP Head. 
Parameters ---------- @@ -120,9 +119,10 @@ def forward(self, x): class _SETRMLAHead(nn.Module): - """Multi level feature aggretation head of SETR. + """Multi level feature aggretation head of SETR (as in + https://arxiv.org/pdf/2012.15840.pdf) - MLA head of `SETR `_. + Note: This has not been tested yet! """ def __init__( @@ -130,7 +130,7 @@ def __init__( channels: int, conv_norm: Optional[nn.Module], conv_act: Optional[nn.Module], - in_channels: list[int], + in_channels: List[int], out_channels: int, num_classes: int, mla_channels: int = 128, @@ -219,7 +219,6 @@ def forward(self, x): class _SetR_PUP(nn.Module): - def __init__( self, image_size: Union[int, Tuple[int, int]], @@ -240,11 +239,11 @@ def __init__( conv_norm: nn.Module, conv_act: nn.Module, align_corners: bool, - aux_output: bool = False, - aux_output_layers: Optional[List[int]] = None, + aux_output: bool, + aux_output_layers: Optional[List[int]], + original_resolution: Optional[Tuple[int, int]], ): - """ - Initializes the SETR PUP model. + """Initializes the SETR PUP head. Parameters ---------- @@ -259,7 +258,7 @@ def __init__( hidden_dim : int The hidden dimension of the transformer encoder. mlp_dim : int - The dimension of the feed-forward network in the transformer encoder. + The dimension of the feed-forward network in the transformer encoder num_convs : int The number of convolutional layers in the decoder. num_classes : int @@ -279,11 +278,21 @@ def __init__( interpolate_mode : str The mode for interpolation during upsampling. conv_norm : nn.Module - The normalization layer to be used in the decoder convolutional layers. + The normalization layer to be used in the decoder convolutional + layers. conv_act : nn.Module - The activation function to be used in the decoder convolutional layers. + The activation function to be used in the decoder convolutional + layers. align_corners : bool Whether to align corners during upsampling. + aux_output: bool + Whether to use auxiliary outputs. If True, aux_output_layers must + be provided. + aux_output_layers: List[int], optional + The layers to use for auxiliary outputs. Must have exacly 3 values. + original_resolution: Tuple[int, int], optional + The original resolution of the input image in the pre-training + weights. When None, positional embeddings will not be interpolated. """ super().__init__() @@ -307,6 +316,7 @@ def __init__( dropout=encoder_dropout, aux_output=aux_output, aux_output_layers=aux_output_layers, + original_resolution=original_resolution, ) self.decoder = _SETRUPHead( @@ -382,12 +392,42 @@ def forward(self, x: torch.Tensor): x = self.encoder(x) x = self.decoder(x) return x - - def load_backbone(self, path: str): - self.encoder.load_state_dict(torch.load(path)) + + def load_backbone(self, path: str, freeze: bool = False): + self.encoder.load_backbone(path) + if freeze: + for param in self.encoder.parameters(): + param.requires_grad = False class SETR_PUP(L.LightningModule): + """SET-R model with PUP head for image segmentation. + + Methods + ------- + forward(x: torch.Tensor) -> torch.Tensor + Forward pass of the model. + _compute_metrics(y_hat: torch.Tensor, y: torch.Tensor, step_name: str) + Compute metrics for the given step. + _loss_func(y_hat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], y: torch.Tensor) -> torch.Tensor + Calculate the loss between the output and the input data. + _single_step(batch: torch.Tensor, batch_idx: int, step_name: str) + Perform a single step of the training/validation loop. 
+ training_step(batch: torch.Tensor, batch_idx: int) + Perform a single training step. + validation_step(batch: torch.Tensor, batch_idx: int) + Perform a single validation step. + test_step(batch: torch.Tensor, batch_idx: int) + Perform a single test step. + predict_step(batch: torch.Tensor, batch_idx: int, dataloader_idx: Optional[int] = None) + Perform a single prediction step. + load_backbone(path: str, freeze: bool = False) + Load a pre-trained backbone. + configure_optimizers() + Configure the optimizer for the model. + create_from_dict(config: Dict) -> "SETR_PUP" + Create an instance of SETR_PUP from a configuration dictionary. + """ def __init__( self, @@ -410,72 +450,119 @@ def __init__( conv_act: Optional[nn.Module] = None, interpolate_mode: str = "bilinear", loss_fn: Optional[nn.Module] = None, + optimizer_type: Optional[type] = None, + optimizer_params: Optional[Dict] = None, train_metrics: Optional[Dict[str, Metric]] = None, val_metrics: Optional[Dict[str, Metric]] = None, test_metrics: Optional[Dict[str, Metric]] = None, aux_output: bool = True, - aux_output_layers: Optional[List[int]] = [9, 14, 19], - aux_weights: List[float] = [0.3, 0.3, 0.3], + aux_output_layers: Optional[list[int]] = None, + aux_weights: Optional[list[float]] = None, + load_backbone_path: Optional[str] = None, + freeze_backbone_on_load: bool = True, + learning_rate: float = 1e-3, + loss_weights: Optional[list[float]] = None, + original_resolution: Optional[Tuple[int, int]] = None, + head_lr_factor: float = 1.0, + test_engine: Optional[_Engine] = None, ): - """ - Initializes the SetR model. + """Initialize the SETR model with Progressive Upsampling Head. Parameters ---------- - image_size : int or Tuple[int, int] - The input image size. Defaults to 512. - patch_size : int - The size of each patch. Defaults to 16. - num_layers : int - The number of layers in the transformer encoder. Defaults to 24. - num_heads : int - The number of attention heads in the transformer encoder. Defaults to 16. - hidden_dim : int - The hidden dimension of the transformer encoder. Defaults to 1024. - mlp_dim : int - The dimension of the MLP layers in the transformer encoder. Defaults to 4096. - encoder_dropout : float - The dropout rate for the transformer encoder. Defaults to 0.1. - num_classes : int - The number of output classes. Defaults to 1000. - norm_layer : nn.Module, optional - The normalization layer to be used in the decoder. Defaults to None. - decoder_channels : int - The number of channels in the decoder. Defaults to 256. - num_convs : int - The number of convolutional layers in the decoder. Defaults to 4. - up_scale : int - The scale factor for upsampling in the decoder. Defaults to 2. - kernel_size : int - The kernel size for convolutional layers in the decoder. Defaults to 3. - align_corners : bool - Whether to align corners during interpolation in the decoder. Defaults to False. - decoder_dropout : float - The dropout rate for the decoder. Defaults to 0.1. - conv_norm : nn.Module, optional - The normalization layer to be used in the convolutional layers of the decoder. Defaults to None. - conv_act : nn.Module, optional - The activation function to be used in the convolutional layers of the decoder. Defaults to None. - interpolate_mode : str - The interpolation mode for upsampling in the decoder. Defaults to "bilinear". - loss_fn : nn.Module, optional - The loss function to be used during training. Defaults to None. - train_metrics : Dict[str, Metric], optional - The metrics to be used for training evaluation. 
Defaults to None. - val_metrics : Dict[str, Metric], optional - The metrics to be used for validation evaluation. Defaults to None. - test_metrics : Dict[str, Metric], optional - The metrics to be used for testing evaluation. Defaults to None. - aux_output : bool - Whether to include auxiliary output heads in the model. Defaults to True. - aux_output_layers : List[int], optional - The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19]. - aux_weights : List[float] - The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3]. - + image_size : Union[int, Tuple[int, int]], optional + Size of the input image, by default 512. + patch_size : int, optional + Size of the patches to be extracted from the input image, by + default 16. + num_layers : int, optional + Number of transformer layers, by default 24. + num_heads : int, optional + Number of attention heads, by default 16. + hidden_dim : int, optional + Dimension of the hidden layer, by default 1024. + mlp_dim : int, optional + Dimension of the MLP layer, by default 4096. + encoder_dropout : float, optional + Dropout rate for the encoder, by default 0.1. + num_classes : int, optional + Number of output classes, by default 1000. + norm_layer : Optional[nn.Module], optional + Normalization layer, by default None. + decoder_channels : int, optional + Number of channels in the decoder, by default 256. + num_convs : int, optional + Number of convolutional layers in the decoder, by default 4. + up_scale : int, optional + Upscaling factor for the decoder, by default 2. + kernel_size : int, optional + Kernel size for the convolutional layers, by default 3. + align_corners : bool, optional + Whether to align corners when interpolating, by default False. + decoder_dropout : float, optional + Dropout rate for the decoder, by default 0.1. + conv_norm : Optional[nn.Module], optional + Normalization layer for the convolutional layers, by default None. + conv_act : Optional[nn.Module], optional + Activation function for the convolutional layers, by default None. + interpolate_mode : str, optional + Interpolation mode, by default "bilinear". + loss_fn : Optional[nn.Module], optional + Loss function, when None defaults to nn.CrossEntropyLoss, by + default None. + optimizer_type : Optional[type], optional + Type of optimizer, by default None. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer, by default None. + train_metrics : Optional[Dict[str, Metric]], optional + Metrics for training, by default None. + val_metrics : Optional[Dict[str, Metric]], optional + Metrics for validation, by default None. + test_metrics : Optional[Dict[str, Metric]], optional + Metrics for testing, by default None. + aux_output : bool, optional + Whether to use auxiliary outputs, by default True. + aux_output_layers : list[int], optional + Layers for auxiliary outputs, when None it defaults to [9, 14, 19]. + aux_weights : list[float], optional + Weights for auxiliary outputs, when None it defaults [0.3, 0.3, 0.3]. + load_backbone_path : Optional[str], optional + Path to load the backbone model, by default None. + freeze_backbone_on_load : bool, optional + Whether to freeze the backbone model on load, by default True. + learning_rate : float, optional + Learning rate, by default 1e-3. + loss_weights : Optional[list[float]], optional + Weights for the loss function, by default None. + original_resolution : Optional[Tuple[int, int]], optional + The original resolution of the input image in the pre-training + weights. 
When None, positional embeddings will not be interpolated. + Defaults to None. + head_lr_factor : float, optional + Learning rate factor for the head. used if you need different + learning rates for backbone and prediction head, by default 1.0. + test_engine : Optional[_Engine], optional + Engine used for test and validation steps. When None, behavior of + all steps, training, testing and validation is the same, by default None. """ super().__init__() - self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss() + + if head_lr_factor != 1: + self.automatic_optimization = False + self.multiple_optimizers = True + else: + self.automatic_optimization = True + self.multiple_optimizers = False + + self.loss_fn = ( + loss_fn + if loss_fn is not None + else nn.CrossEntropyLoss( + weight=( + torch.tensor(loss_weights) if loss_weights is not None else None + ) + ) + ) norm_layer = norm_layer if norm_layer is not None else nn.LayerNorm(hidden_dim) conv_norm = ( conv_norm if conv_norm is not None else nn.SyncBatchNorm(decoder_channels) @@ -483,16 +570,28 @@ def __init__( conv_act = conv_act if conv_act is not None else nn.ReLU() if aux_output: - assert aux_output_layers is not None, "aux_output_layers must be provided." + if aux_output_layers is None: + aux_output_layers = [9, 14, 19] + warnings.warn( + "aux_output_layers not provided. Using default values [9, 14, 19]." + ) + if aux_weights is None: + aux_weights = [0.3, 0.3, 0.3] + warnings.warn( + "aux_weights not provided. Using default values [0.3, 0.3, 0.3]." + ) assert ( len(aux_output_layers) == 3 ), "aux_output_layers must have 3 values. Only 3 aux heads are supported." - assert len(aux_weights) == len( - aux_output_layers - ), "aux_weights must have the same length as aux_output_layers." + + self.optimizer_type = optimizer_type + if optimizer_type is not None: + assert optimizer_params is not None, "optimizer_params must be provided." + self.optimizer_params = optimizer_params self.num_classes = num_classes self.aux_weights = aux_weights + self.head_lr_factor = head_lr_factor self.metrics = { "train": train_metrics, @@ -521,7 +620,13 @@ def __init__( align_corners=align_corners, aux_output=aux_output, aux_output_layers=aux_output_layers, + original_resolution=original_resolution, ) + if load_backbone_path is not None: + self.model.load_backbone(load_backbone_path, freeze_backbone_on_load) + + self.learning_rate = learning_rate + self.test_engine = test_engine def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) @@ -540,7 +645,8 @@ def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str) def _loss_func( self, y_hat: Union[ - torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], ], y: torch.Tensor, ) -> torch.Tensor: @@ -592,10 +698,14 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str): The loss value. 
""" x, y = batch - y_hat = self.model(x.float()) - loss = self._loss_func(y_hat[0], y.squeeze(1)) + if self.test_engine and (step_name == "test" or step_name == "val"): + y_hat = self.test_engine(self.model, x) + else: + y_hat = self.model(x) metrics = self._compute_metrics(y_hat[0], y, step_name) + loss = self._loss_func(y_hat, y.squeeze(1)) + for metric_name, metric_value in metrics.items(): self.log( metric_name, @@ -620,7 +730,20 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str): return loss def training_step(self, batch: torch.Tensor, batch_idx: int): - return self._single_step(batch, batch_idx, "train") + if self.multiple_optimizers: + optimizers_list = self.optimizers() + + for opt in optimizers_list: + opt.zero_grad() + + loss = self._single_step(batch, batch_idx, "train") + + self.manual_backward(loss) + + for opt in optimizers_list: + opt.step() + else: + return self._single_step(batch, batch_idx, "train") def validation_step(self, batch: torch.Tensor, batch_idx: int): return self._single_step(batch, batch_idx, "val") @@ -629,13 +752,52 @@ def test_step(self, batch: torch.Tensor, batch_idx: int): return self._single_step(batch, batch_idx, "test") def predict_step( - self, batch: torch.Tensor, batch_idx: int, dataloader_idx: Optional[int] = None + self, + batch: torch.Tensor, + batch_idx: int, + dataloader_idx: Optional[int] = None, ): x, _ = batch return self.model(x)[0] - def load_backbone(self, path: str): - self.model.load_backbone(path) + def load_backbone(self, path: str, freeze: bool = False): + self.model.load_backbone(path, freeze) def configure_optimizers(self): - return torch.optim.Adam(self.model.parameters(), lr=1e-3) + if self.multiple_optimizers: + return ( + [ + self.optimizer_type( + self.model.encoder.parameters(), + lr=self.learning_rate, + **self.optimizer_params, + ), + self.optimizer_type( + list(self.model.decoder.parameters()) + + list(self.model.aux_head1.parameters()) + + list(self.model.aux_head2.parameters()) + + list(self.model.aux_head3.parameters()), + lr=self.learning_rate * self.head_lr_factor, + **self.optimizer_params, + ), + ] + if self.optimizer_type is not None + else [ + Adam(self.model.encoder.parameters(), lr=self.learning_rate), + Adam(self.model.decoder.parameters(), lr=self.learning_rate), + ] + ) + else: + return ( + self.optimizer_type( + self.model.parameters(), + lr=self.learning_rate, + **self.optimizer_params, + ) + if self.optimizer_type is not None + else Adam(self.model.parameters(), lr=self.learning_rate) + ) + + @staticmethod + def create_from_dict(config: Dict) -> "SETR_PUP": + return SETR_PUP(**config) diff --git a/minerva/models/nets/image/vit.py b/minerva/models/nets/image/vit.py index 238c63f..8cf5ba3 100644 --- a/minerva/models/nets/image/vit.py +++ b/minerva/models/nets/image/vit.py @@ -8,6 +8,7 @@ import torch.nn as nn from timm.models.vision_transformer import Block, PatchEmbed from torch import nn +from torch.nn import functional as F from torchvision.models.vision_transformer import ( Conv2dNormActivation, ConvStemConfig, @@ -87,6 +88,7 @@ def __init__( num_heads: int, hidden_dim: int, mlp_dim: int, + original_resolution: Optional[Tuple[int, int]] = None, dropout: float = 0.0, attention_dropout: float = 0.0, num_classes: int = 1000, @@ -102,7 +104,8 @@ def __init__( ---------- image_size : int or Tuple[int, int] The size of the input image. If an int is provided, it is assumed - to be a square image. If a tuple of ints is provided, it represents the height and width of the image. 
+ to be a square image. If a tuple of ints is provided, it represents + the height and width of the image. patch_size : int The size of each patch in the image. num_layers : int @@ -112,19 +115,25 @@ def __init__( hidden_dim : int The dimensionality of the hidden layers in the transformer. mlp_dim : int - The dimensionality of the feed-forward MLP layers in the transformer. + The dimensionality of the feed-forward MLP layers in the transformer + original_resolution : Tuple[int, int], optional + The original resolution of the input image in the pre-training + weights. When None, positional embeddings will not be interpolated. + Defaults to None. dropout : float, optional The dropout rate to apply. Defaults to 0.0. attention_dropout : float, optional - The dropout rate to apply to the attention weights. Defaults to 0.0. + The dropout rate to apply to the attention weights. Defaults to 0.0 num_classes : int, optional The number of output classes. Defaults to 1000. norm_layer : Callable[..., torch.nn.Module], optional - The normalization layer to use. Defaults to nn.LayerNorm with epsilon=1e-6. + The normalization layer to use. Defaults to nn.LayerNorm with + epsilon=1e-6. conv_stem_configs : List[ConvStemConfig], optional The configuration for the convolutional stem layers. - If provided, the input image will be processed by these convolutional layers before being passed to - the transformer. Defaults to None. + If provided, the input image will be processed by these + convolutional layers before being passed to the transformer. + Defaults to None. """ super().__init__() @@ -138,7 +147,8 @@ def __init__( if isinstance(image_size, int): torch._assert( - image_size % patch_size == 0, "Input shape indivisible by patch size!" + image_size % patch_size == 0, + "Input shape indivisible by patch size!", ) elif isinstance(image_size, tuple): torch._assert( @@ -156,6 +166,9 @@ def __init__( self.norm_layer = norm_layer self.aux_output = aux_output self.aux_output_layers = aux_output_layers + self.original_resolution = ( + original_resolution if original_resolution else image_size + ) if conv_stem_configs is not None: # As per https://arxiv.org/abs/2106.14881 @@ -177,7 +190,9 @@ def __init__( seq_proj.add_module( "conv_last", nn.Conv2d( - in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1 + in_channels=prev_channels, + out_channels=hidden_dim, + kernel_size=1, ), ) self.conv_proj: nn.Module = seq_proj @@ -241,7 +256,8 @@ def _process_input(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: x (torch.Tensor): The input tensor. Returns: - Tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows, and number of columns. + Tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows, + and number of columns. """ n, c, h, w = x.shape p = self.patch_size @@ -284,6 +300,97 @@ def _process_input(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: return x, n_h, n_w + def interpolate_pos_embeddings(self, pretrained_pos_embed, new_img_size): + """Interpolate encoder's positional embeddings to fit a new input size. + + Args: + pretrained_pos_embed (torch.Tensor): Pretrained positional embeddings. + new_img_size (Tuple[int, int]): New height and width of the input image. 
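+
+        For example, with patch_size=16 and original_resolution=(512, 512), the
+        pretrained positional embeddings form a 32x32 grid (plus the CLS token);
+        for new_img_size=(256, 256) they are bilinearly interpolated down to a
+        16x16 grid before the CLS token is concatenated back.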
+ """ + h, w = ( + new_img_size[0] // self.patch_size, + new_img_size[1] // self.patch_size, + ) + new_grid_size = (h, w) + + # Reshape pretrained positional embeddings to match the original grid size + + original_resolution = ( + self.original_resolution + if isinstance(self.original_resolution, Tuple) + else (self.original_resolution, self.original_resolution) + ) + + pos_embed_reshaped = pretrained_pos_embed[:, 1:].reshape( + 1, + original_resolution[0] // self.patch_size, + original_resolution[1] // self.patch_size, + -1, + ) + + # Interpolate positional embeddings to the new grid size + pos_embed_interpolated = ( + F.interpolate( + pos_embed_reshaped.permute( + 0, 3, 1, 2 + ), # (1, C, H, W) for interpolation + size=new_grid_size, + mode="bilinear", + align_corners=False, + ) + .permute(0, 2, 3, 1) + .reshape(1, -1, pos_embed_reshaped.shape[-1]) + ) + + # Concatenate the CLS token and the interpolated positional embeddings + cls_token = pretrained_pos_embed[:, :1] + pos_embed_interpolated = torch.cat((cls_token, pos_embed_interpolated), dim=1) + + return pos_embed_interpolated + + return pos_embed_interpolated + + def load_backbone(self, path: str, freeze: bool = False): + """Loads pretrained weights and handles positional embedding resizing + if necessary.""" + # Load the pretrained state dict + state_dict = torch.load(path) + + # Expected shape for positional embeddings based on current model image size + + image_size = ( + self.image_size + if isinstance(self.image_size, Tuple) + else (self.image_size, self.image_size) + ) + + expected_pos_embed_shape = ( + 1, + (image_size[0] // self.patch_size) * (image_size[1] // self.patch_size) + 1, + self.hidden_dim, + ) + + # Check if positional embeddings need interpolation + if state_dict["encoder.pos_embedding"].shape != expected_pos_embed_shape: + # Extract the positional embeddings from the state dict + pretrained_pos_embed = state_dict["encoder.pos_embedding"] + + # Interpolate to match the current image size + print("Interpolating positional embeddings to match the new image size.") + with torch.no_grad(): + pos_embed_interpolated = self.interpolate_pos_embeddings( + pretrained_pos_embed, (image_size[0], image_size[1]) + ) + state_dict["encoder.pos_embedding"] = pos_embed_interpolated + + # Load the (potentially modified) state dict into the encoder + self.encoder.load_state_dict(state_dict, strict=False) + + # Optionally freeze parameters + if freeze: + for param in self.encoder.parameters(): + param.requires_grad = False + def forward(self, x: torch.Tensor): """Forward pass of the Vision Transformer Backbone. @@ -293,6 +400,7 @@ def forward(self, x: torch.Tensor): Returns: torch.Tensor: The output tensor. 
""" + # Reshape and permute the input tensor x, n_h, n_w = self._process_input(x) n = x.shape[0] @@ -328,6 +436,45 @@ def forward(self, x: torch.Tensor): return x + def load_weights(self, weights_path: str, freeze: bool = False): + + state_dict = torch.load(weights_path) + + # Get expected positional embedding shape based on current image size + + image_size = ( + self.image_size + if isinstance(self.image_size, Tuple) + else (self.image_size, self.image_size) + ) + + expected_pos_embed_shape = ( + 1, + (image_size[0] // self.patch_size) * (image_size[1] // self.patch_size) + 1, + self.hidden_dim, + ) + + # Check if positional embeddings need interpolation + if state_dict["encoder.pos_embedding"].shape != expected_pos_embed_shape: + # Extract the positional embeddings from the state dict + pretrained_pos_embed = state_dict["encoder.pos_embedding"] + + # Interpolate to match the current image size + print("Interpolating positional embeddings to match the new image size.") + with torch.no_grad(): + pos_embed_interpolated = self.interpolate_pos_embeddings( + pretrained_pos_embed, (image_size[0], image_size[1]) + ) + state_dict["encoder.pos_embedding"] = pos_embed_interpolated + + # Load the (potentially modified) state dict + self.load_state_dict(state_dict, strict=False) + + # Optionally freeze parameters + if freeze: + for param in self.parameters(): + param.requires_grad = False + class MaskedAutoencoderViT(L.LightningModule): """ diff --git a/minerva/pipelines/base.py b/minerva/pipelines/base.py index f626a3f..88167d7 100644 --- a/minerva/pipelines/base.py +++ b/minerva/pipelines/base.py @@ -159,11 +159,11 @@ def pipeline_info(self) -> Dict[str, str]: The dictionary with the pipeline information """ return { - "class_name": self.__class__.__name__, - "created_time": self._created_at, + "class_name": str(self.__class__.__name__), + "created_time": str(self._created_at), "pipeline_id": self.pipeline_id, "log_dir": str(self.log_dir), - "run_count": self._run_count, + "run_count": str(self._run_count), } @property diff --git a/minerva/pipelines/hyperopt_hyperparameter_search.py b/minerva/pipelines/hyperopt_hyperparameter_search.py new file mode 100644 index 0000000..0cafb0c --- /dev/null +++ b/minerva/pipelines/hyperopt_hyperparameter_search.py @@ -0,0 +1,154 @@ +from copy import deepcopy +from typing import Any, Dict, Literal, Optional + +import lightning.pytorch as L +from lightning.pytorch.strategies import Strategy +from ray import tune +from ray.train import CheckpointConfig, RunConfig, ScalingConfig +from ray.train.lightning import RayDDPStrategy, RayLightningEnvironment, prepare_trainer +from ray.train.torch import TorchTrainer +from ray.tune.schedulers import ASHAScheduler, TrialScheduler +from ray.tune.search import ConcurrencyLimiter +from ray.tune.search.hyperopt import HyperOptSearch +from ray.tune.stopper import TrialPlateauStopper + +from minerva.callbacks.HyperSearchCallbacks import TrainerReportOnIntervalCallback +from minerva.pipelines.base import Pipeline +from minerva.utils.typing import PathLike + + +class HyperoptHyperParameterSearch(Pipeline): + + def __init__( + self, + model: type, + search_space: Dict[str, Any], + log_dir: Optional[PathLike] = None, + save_run_status: bool = True, + ): + super().__init__(log_dir=log_dir, save_run_status=save_run_status) + self.model = model + self.search_space = search_space + + def _search( + self, + data: L.LightningDataModule, + ckpt_path: Optional[PathLike], + devices: Optional[str] = "auto", + accelerator: Optional[str] = "auto", 
+ strategy: Optional[Strategy] = None, + callbacks: Optional[Any] = None, + plugins: Optional[Any] = None, + num_nodes: int = 1, + debug_mode: Optional[bool] = False, + scaling_config: Optional[ScalingConfig] = None, + run_config: Optional[RunConfig] = None, + tuner_metric: Optional[str] = "val_loss", + tuner_mode: Optional[str] = "min", + num_samples: Optional[int] = -1, + scheduler: Optional[TrialScheduler] = None, + max_concurrent: Optional[int] = 4, + initial_parameters: Optional[Dict[str, Any]] = None, + max_epochs: Optional[int] = None, + num_results: Optional[int] = 5, + std: Optional[float] = 0.01, + grace_period: Optional[int] = 50, + ) -> Any: + + def _tuner_train_func(config): + dm = deepcopy(data) + model = self.model.create_from_dict(config) + trainer = L.Trainer( + max_epochs=max_epochs or 500, + devices=devices or "auto", + accelerator=accelerator or "auto", + strategy=strategy or RayDDPStrategy(find_unused_parameters=True), + callbacks=callbacks or [TrainerReportOnIntervalCallback(500)], + plugins=plugins or [RayLightningEnvironment()], + enable_progress_bar=False, + num_nodes=num_nodes, + enable_checkpointing=False if debug_mode else None, + ) + trainer = prepare_trainer(trainer) + trainer.fit(model, dm, ckpt_path=ckpt_path) + + scheduler = scheduler or ASHAScheduler( + time_attr="training_iteration", + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + max_t=500, + grace_period=100, + ) + + scaling_config = scaling_config or ScalingConfig( + num_workers=1, use_gpu=True, resources_per_worker={"GPU": 1} + ) + + run_config = run_config or RunConfig( + checkpoint_config=CheckpointConfig( + num_to_keep=1, + checkpoint_score_attribute="val_loss", + checkpoint_score_order="min", + ), + stop=TrialPlateauStopper( + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + num_results=num_results or 5, + std=std or 0.01, + grace_period=grace_period or 50, + ), + ) + + ray_trainer = TorchTrainer( + _tuner_train_func, + scaling_config=scaling_config, + run_config=run_config, + ) + + algo = ConcurrencyLimiter( + HyperOptSearch(initial_parameters), max_concurrent=max_concurrent or 4 + ) + + tuner = tune.Tuner( + ray_trainer, + param_space={"train_loop_config": self.search_space}, + tune_config=tune.TuneConfig( + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + num_samples=num_samples or -1, + search_alg=algo, + ), + ) + return tuner.fit() + + def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]) -> Any: + # TODO fix this + return self.trainer.test(self.model, data, ckpt_path=ckpt_path) + + def _run( + self, + data: L.LightningDataModule, + task: Optional[Literal["search", "test", "predict"]], + ckpt_path: Optional[PathLike] = None, + config: Dict[str, Any] = {}, + **kwargs, + ) -> Any: + if task == "search": + return self._search(data, ckpt_path, **config) + elif task == "test": + return self._test(data, ckpt_path) + elif task is None: + search = self._search(data, ckpt_path, **config) + test = self._test(data, ckpt_path) + return search, test + + +def main(): + from jsonargparse import CLI + + print("Hyper Searching 🔍") + CLI(HyperoptHyperParameterSearch, as_positional=False) + + +if __name__ == "__main__": + main() diff --git a/minerva/pipelines/ray_hyperparameter_search.py b/minerva/pipelines/ray_hyperparameter_search.py new file mode 100644 index 0000000..c8b96c2 --- /dev/null +++ b/minerva/pipelines/ray_hyperparameter_search.py @@ -0,0 +1,133 @@ +from copy import deepcopy +from typing import Any, Dict, Literal, Optional + 
+import lightning.pytorch as L +from lightning.pytorch.strategies import Strategy +from ray import tune +from ray.train import CheckpointConfig, RunConfig, ScalingConfig +from ray.train.lightning import RayDDPStrategy, RayLightningEnvironment, prepare_trainer +from ray.train.torch import TorchTrainer +from ray.tune.schedulers import ASHAScheduler, TrialScheduler + +from minerva.callbacks.HyperSearchCallbacks import TrainerReportKeepOnlyLastCallback +from minerva.pipelines.base import Pipeline +from minerva.utils.typing import PathLike + + +class RayHyperParameterSearch(Pipeline): + + def __init__( + self, + model: type, + search_space: Dict[str, Any], + log_dir: Optional[PathLike] = None, + save_run_status: bool = True, + ): + super().__init__(log_dir=log_dir, save_run_status=save_run_status) + self.model = model + self.search_space = search_space + + def _search( + self, + data: L.LightningDataModule, + ckpt_path: Optional[PathLike], + devices: Optional[str] = "auto", + accelerator: Optional[str] = "auto", + strategy: Optional[Strategy] = None, + callbacks: Optional[Any] = None, + plugins: Optional[Any] = None, + num_nodes: int = 1, + debug_mode: Optional[bool] = False, + scaling_config: Optional[ScalingConfig] = None, + run_config: Optional[RunConfig] = None, + tuner_metric: Optional[str] = "val_loss", + tuner_mode: Optional[str] = "min", + num_samples: Optional[int] = 10, + scheduler: Optional[TrialScheduler] = None, + ) -> Any: + + def _tuner_train_func(config): + dm = deepcopy(data) + model = self.model.create_from_dict(config_dict=config) + trainer = L.Trainer( + devices=devices or "auto", + accelerator=accelerator or "auto", + strategy=strategy or RayDDPStrategy(find_unused_parameters=True), + callbacks=callbacks or [TrainerReportKeepOnlyLastCallback()], + plugins=plugins or [RayLightningEnvironment()], + enable_progress_bar=False, + num_nodes=num_nodes, + enable_checkpointing=False if debug_mode else None, + ) + trainer = prepare_trainer(trainer) + trainer.fit(model, dm, ckpt_path=ckpt_path) + + scheduler = scheduler or ASHAScheduler( + time_attr="training_iteration", + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + max_t=2, + grace_period=1, + brackets=1, + ) + + scaling_config = scaling_config or ScalingConfig( + num_workers=1, use_gpu=True, resources_per_worker={"GPU": 1} + ) + + run_config = run_config or RunConfig( + checkpoint_config=CheckpointConfig( + num_to_keep=1, + checkpoint_score_attribute="val_loss", + checkpoint_score_order="min", + checkpoint_frequency=10, + ) + ) + + ray_trainer = TorchTrainer( + _tuner_train_func, + scaling_config=scaling_config, + run_config=run_config, + ) + tuner = tune.Tuner( + ray_trainer, + param_space={"train_loop_config": self.search_space}, + tune_config=tune.TuneConfig( + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + num_samples=num_samples or 10, + scheduler=scheduler, + ), + ) + return tuner.fit() + + def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]) -> Any: + # TODO fix this + return self.trainer.test(self.model, data, ckpt_path=ckpt_path) + + def _run( + self, + data: L.LightningDataModule, + task: Optional[Literal["search", "test", "predict"]], + ckpt_path: Optional[PathLike] = None, + **kwargs, + ) -> Any: + if task == "search": + return self._search(data, ckpt_path, **kwargs) + elif task == "test": + return self._test(data, ckpt_path) + elif task is None: + search = self._search(data, ckpt_path, **kwargs) + test = self._test(data, ckpt_path) + return search, test + + 
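+# Example usage (a minimal sketch: `MySupervisedModel`, the data module, the
+# search-space keys and the assumption that `Pipeline.run` forwards keyword
+# arguments to `_search` are illustrative, not part of this module):
+#
+#     from ray import tune
+#
+#     pipeline = RayHyperParameterSearch(
+#         model=MySupervisedModel,   # must implement `create_from_dict(config_dict)`
+#         search_space={"lr": tune.loguniform(1e-5, 1e-2)},
+#         log_dir="logs/ray_search",
+#     )
+#     results = pipeline.run(
+#         data=my_data_module,
+#         task="search",
+#         tuner_metric="val_loss",
+#         tuner_mode="min",
+#     )
+
+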
+def main(): + from jsonargparse import CLI + + print("Hyper Searching 🔍") + CLI(RayHyperParameterSearch, as_positional=False) + + +if __name__ == "__main__": + main() diff --git a/minerva/transforms/context_transform.py b/minerva/transforms/context_transform.py new file mode 100644 index 0000000..8c09564 --- /dev/null +++ b/minerva/transforms/context_transform.py @@ -0,0 +1,70 @@ +from typing import Any + +import numpy as np + +from minerva.transforms.transform import _Transform + + +class ClassRatioCrop(_Transform): + + def __init__( + self, + target_h_size: int, + target_w_size: int, + cat_max_ratio: float = 0.75, + max_attempts: int = 10, + ) -> None: + """Crop the input data to a target size, while keeping the ratio of classes in the image. + + Parameters + ---------- + target_h_size : int + The target height of the crop. + target_w_size : int + The target width of the crop. + cat_max_ratio : float, optional + The maximum ratio of pixels of a single class in the crop, by default 0.75 + max_attempts : int, optional + The maximum number of attempts to crop the image, by default 10 + """ + self.target_h_size = target_h_size + self.target_w_size = target_w_size + self.cat_max_ratio = cat_max_ratio + self.max_attempts = max_attempts + self.crop_coords = None + + def __call__(self, x: np.ndarray) -> np.ndarray: + h, w = x.shape[:2] + + if self.crop_coords is None: + if not issubclass(x.dtype.type, np.integer): + raise ValueError( + "You must provide a mask first to use this functionality. For that you enable support_context_transforms if your dataset supports it, or use a different dataset that does supports it." + ) + + for _ in range(self.max_attempts): + # Randomly select the top-left corner for the crop + top = np.random.randint(0, h - self.target_h_size + 1) + left = np.random.randint(0, w - self.target_w_size + 1) + + # Extract the crop from both image and label + cropped_image = x[ + top : top + self.target_h_size, left : left + self.target_w_size + ] + + # Calculate the proportion of the most frequent class in the crop + _, counts = np.unique(cropped_image, return_counts=True) + class_ratios = counts / (self.target_h_size * self.target_w_size) + + if np.max(class_ratios) <= self.cat_max_ratio: + self.crop_coords = (top, left) + return cropped_image + + # If no valid crop was found, return the last crop (without meeting the ratio constraint) + self.crop_coords = (top, left) + return cropped_image + + else: + top, left = self.crop_coords + self.crop_coords = None + return x[top : top + self.target_h_size, left : left + self.target_w_size] diff --git a/minerva/transforms/random_transform.py b/minerva/transforms/random_transform.py new file mode 100644 index 0000000..6107845 --- /dev/null +++ b/minerva/transforms/random_transform.py @@ -0,0 +1,120 @@ +import random +from typing import List, Optional, Tuple, Union + +import numpy as np + +from minerva.transforms.transform import Flip, Resize, _Transform + + +class EmptyTransform(_Transform): + """A transform that does nothing to the input data.""" + + def __call__(self, data): + return data + + +class _RandomSyncedTransform(_Transform): + """Orchestrate the application of a type of random transform to a list of data, ensuring that the same random state is used for all of them.""" + + def __init__(self, num_samples: int, seed: Optional[int] = None): + """Orchestrate the application of a type of random transform to a list of data, ensuring that the same random state is used for all of them. 
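+
+        The random parameters are drawn on the first call and reused for the
+        following ``num_samples - 1`` calls (for example, an image and its
+        corresponding label mask), after which the internal counter resets and
+        a new random transform is drawn for the next group of calls.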
+
+        Parameters
+        ----------
+        num_samples : int
+            The number of samples that will be transformed.
+        seed : Optional[int], optional
+            The seed that will be used to generate the random state, by default None.
+        """
+        self.num_samples = num_samples
+        self.transformations_executed = 0
+        self.rng = np.random.default_rng(seed)
+        self.transform = EmptyTransform()
+
+    def __call__(self, data):
+        if self.transformations_executed == 0:
+            self.transform = self.select_transform(data)
+            self.transformations_executed += 1
+            return self.transform(data)
+        else:
+            if self.transformations_executed == self.num_samples - 1:
+                self.transformations_executed = 0
+            else:
+                self.transformations_executed += 1
+            return self.transform(data)
+
+    def select_transform(self, data) -> _Transform:
+        raise NotImplementedError(
+            "This method should be implemented by the child class."
+        )
+
+
+class RandomFlip(_RandomSyncedTransform):
+
+    def __init__(
+        self,
+        num_samples: int,
+        possible_axis: Union[int, List[int]] = 0,
+        seed: Optional[int] = None,
+    ):
+        """A transform that flips the input data along a random axis.
+
+        Parameters
+        ----------
+        num_samples : int
+            The number of samples that will be transformed.
+        possible_axis : Union[int, List[int]], optional
+            Possible axis to be transformed, will be chosen at random, by default 0
+        seed : Optional[int], optional
+            A seed to ensure deterministic run, by default None
+        """
+        super().__init__(num_samples, seed)
+        self.possible_axis = possible_axis
+
+    def select_transform(self, data):
+        """Selects the transform to be applied to the data."""
+
+        if isinstance(self.possible_axis, int):
+            flip_axis = self.rng.choice([True, False])
+            if flip_axis:
+                return Flip(axis=self.possible_axis)
+
+        else:
+            flip_axis = [
+                bool(self.rng.choice([True, False]))
+                for _ in range(len(self.possible_axis))
+            ]
+            if True in flip_axis:
+                chosen_axis = [
+                    axis for axis, flip in zip(self.possible_axis, flip_axis) if flip
+                ]
+                return Flip(axis=chosen_axis)
+
+        return EmptyTransform()
+
+
+class RandomResize(_RandomSyncedTransform):
+
+    def __init__(
+        self,
+        target_scale: Tuple[int, int],
+        ratio_range: Tuple[float, float],
+        num_samples: int,
+        seed: Optional[int] = None,
+    ):
+        super().__init__(num_samples, seed)
+        self.target_scale = target_scale
+        self.ratio_range = ratio_range
+        self.resize: Optional[_Transform] = None
+
+    def select_transform(self, data):
+        # Apply a random scaling factor within the ratio range to the target scale
+        scale_factor = self.rng.uniform(*self.ratio_range)
+        new_width = int(self.target_scale[1] * scale_factor)
+        new_height = int(self.target_scale[0] * scale_factor)
+
+        return Resize(new_height, new_width)
diff --git a/minerva/transforms/transform.py b/minerva/transforms/transform.py
index 9d2dd28..48744ad 100644
--- a/minerva/transforms/transform.py
+++ b/minerva/transforms/transform.py
@@ -1,6 +1,7 @@
 from itertools import product
-from typing import Any, List, Sequence, Union
+from typing import Any, List, Literal, Sequence, Tuple, Union
 
+import cv2
 import numpy as np
 import torch
 from perlin_noise import PerlinNoise
@@ -69,16 +70,16 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
         """
 
         if isinstance(self.axis, int):
-            return np.flip(x, axis=self.axis)
+            return np.flip(x, axis=self.axis).copy()
 
         assert (
             len(self.axis) <= x.ndim
-        ), "Axis list has more dimentions than input data. The lenth of axis needs to be less or equal to input dimentions."
+ ), "Axis list has more dimensions than input data. The length of axis needs to be less or equal to input dimensions." for axis in self.axis: x = np.flip(x, axis=axis) - return x + return x.copy() class PerlinMasker(_Transform): @@ -175,19 +176,193 @@ def __call__(self, x: np.ndarray) -> np.ndarray: class Padding(_Transform): - def __init__(self, target_h_size: int, target_w_size: int): + def __init__( + self, + target_h_size: int, + target_w_size: int, + padding_mode: Literal["reflect", "constant"] = "reflect", + constant_value: int = 0, + mask_value: int = 255, + ): self.target_h_size = target_h_size self.target_w_size = target_w_size + self.padding_mode = padding_mode + self.constant_value = constant_value + self.mask_value = mask_value def __call__(self, x: np.ndarray) -> np.ndarray: h, w = x.shape[:2] pad_h = max(0, self.target_h_size - h) pad_w = max(0, self.target_w_size - w) + is_label = True if x.dtype == np.uint8 else False + if len(x.shape) == 2: - padded = np.pad(x, ((0, pad_h), (0, pad_w)), mode="reflect") + if self.padding_mode == "reflect": + padded = np.pad(x, ((0, pad_h), (0, pad_w)), mode="reflect") + elif self.padding_mode == "constant": + if is_label: + padded = np.pad( + x, + ((0, pad_h), (0, pad_w)), + mode="constant", + constant_values=self.mask_value, + ) + else: + padded = np.pad( + x, + ((0, pad_h), (0, pad_w)), + mode="constant", + constant_values=self.constant_value, + ) padded = np.expand_dims(padded, axis=2) - else: - padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") - padded = np.transpose(padded, (2, 0, 1)) + else: + if self.padding_mode == "reflect": + padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") + elif self.padding_mode == "constant": + if is_label: + padded = np.pad( + x, + ((0, pad_h), (0, pad_w), (0, 0)), + mode="constant", + constant_values=self.mask_value, + ) + else: + padded = np.pad( + x, + ((0, pad_h), (0, pad_w), (0, 0)), + mode="constant", + constant_values=self.constant_value, + ) return padded + + +class Normalize(_Transform): + def __init__(self, mean, std, to_rgb=False, normalize_labels=False): + """ + Initialize the Normalize transform. + + Args: + means (list or tuple): A list or tuple containing the mean for each channel. + stds (list or tuple): A list or tuple containing the standard deviation for each channel. + to_rgb (bool): If True, convert the data from BGR to RGB. + """ + assert len(mean) == len( + std + ), "Means and standard deviations must have the same length." + self.mean = mean + self.std = std + self.to_rgb = to_rgb + self.normalize_labels = normalize_labels + + def __call__(self, data): + """ + Normalize the input data using the provided means and standard deviations. + + Args: + data (numpy.ndarray): Input data array of shape (C, H, W) where C is the number of channels. + + Returns: + numpy.ndarray: Normalized data. + """ + + is_label = True if data.dtype == np.uint8 else False + + if is_label and self.normalize_labels: + # Convert from gray scale (1 channel) to RGB (3 channels) if to_rgb is True + if self.to_rgb and data.shape[0] == 1: + data = np.repeat(data, 3, axis=0) + + assert data.shape[0] == len( + self.mean + ), f"Number of channels in data does not match the number of provided mean/std. 
{data.shape}" + + # Normalize each channel + for i in range(len(self.mean)): + data[i, :, :] = (data[i, :, :] - self.mean[i]) / self.std[i] + + return data + + +class Crop(_Transform): + def __init__( + self, + target_h_size: int, + target_w_size: int, + start_coord: Tuple[int, int] = (0, 0), + ): + self.target_h_size = target_h_size + self.target_w_size = target_w_size + self.start_coord = start_coord + + def __call__(self, x: np.ndarray) -> np.ndarray: + h, w = x.shape[:2] + start_h = (h - self.target_h_size) // 2 + start_w = (w - self.target_w_size) // 2 + end_h = start_h + self.target_h_size + end_w = start_w + self.target_w_size + if len(x.shape) == 2: + cropped = x[start_h:end_h, start_w:end_w] + cropped = np.expand_dims(cropped, axis=2) + else: + cropped = x[start_h:end_h, start_w:end_w] + + return cropped + + +class Transpose(_Transform): + """Reorder the axes of numpy arrays.""" + + def __init__(self, axes: Sequence[int]): + """Reorder the axes of numpy arrays. + + Parameters + ---------- + axes : int + The order of the new axes + """ + self.axes = axes + + def __call__(self, x: np.ndarray) -> np.ndarray: + """Reorder the axes of numpy arrays.""" + + if len(x.shape) == 2: + x = np.expand_dims(x, axis=2) + return np.transpose(x, self.axes) + + +class Resize(_Transform): + + def __init__( + self, + target_h_size: int, + target_w_size: int, + keep_aspect_ratio: bool = False, + ): + self.target_h_size = target_h_size + self.target_w_size = target_w_size + self.keep_aspect_ratio = keep_aspect_ratio + + def __call__(self, x: np.ndarray) -> np.ndarray: + original_height, original_width = x.shape[:2] + + if not self.keep_aspect_ratio: + # Direct resize without keeping the aspect ratio + return cv2.resize( + x, + (self.target_w_size, self.target_h_size), + interpolation=cv2.INTER_NEAREST, + ) + + # Calculate scaling factors for both dimensions + width_scale = self.target_w_size / original_width + height_scale = self.target_h_size / original_height + + # Choose the smaller scale to keep aspect ratio, and round down + scale = min(width_scale, height_scale) + + # Compute new dimensions, rounding down to match MMsegmentation's behavior + new_width = int(original_width * scale) + new_height = int(original_height * scale) + + return cv2.resize(x, (new_width, new_height), interpolation=cv2.INTER_NEAREST) diff --git a/minerva/utils/position_embedding.py b/minerva/utils/position_embedding.py index 0be7959..c0be963 100644 --- a/minerva/utils/position_embedding.py +++ b/minerva/utils/position_embedding.py @@ -1,7 +1,7 @@ from functools import partial +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch import torch.nn as nn from timm.models.vision_transformer import Block, PatchEmbed diff --git a/pyproject.toml b/pyproject.toml index 4248306..e29a038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,10 @@ version = "0.2.2-beta" dependencies = [ "gitpython", + "jsonargparse", + "ray[tune]", "jsonargparse>=4.27", - "lightning>=2.1.9", + "lightning==2.2.0", "numpy>=1.23.5", "pandas>=2.2.2", "perlin-noise>=1.12", @@ -45,13 +47,16 @@ dependencies = [ "tifffile>=2024", "timm>=0.9", "torch>=2.0.8", - "torchmetrics>=1.3.0", + "torchmetrics==1.3.1", "torchvision>=0.15", - "zarr>=2.17" + "opencv-python>=4.10.0.84", + "zarr>=2.17", + "hyperopt>=0.2.5", ] -[tool.setuptools] -packages = ["minerva"] +[tool.setuptools.packages.find] +where = ["."] +include = ["minerva*"] [project.optional-dependencies] dev = ["mock", "pytest", "black", "isort"] @@ -63,7 +68,7 
@@ docs = [ "sphinx-rtd-theme", "sphinx-autodoc-typehints", "sphinx-argparse", - "sphinx-autoapi" + "sphinx-autoapi", ] [project.urls] diff --git a/tests/engines/test_patch_inferencer_engine.py b/tests/engines/test_patch_inferencer_engine.py index ea508cc..97f6d1e 100644 --- a/tests/engines/test_patch_inferencer_engine.py +++ b/tests/engines/test_patch_inferencer_engine.py @@ -1,6 +1,7 @@ import torch -import lightning as L -from minerva.engines.patch_inferencer_engine import WeightedAvgPatchInferencer, VotingPatchInferencer + +from minerva.engines.patch_inferencer_engine import PatchInferencer +from minerva.models.nets.base import SimpleSupervisedModel pyramid = [ [ @@ -25,40 +26,41 @@ classes = [ [ [ - [0.50, 0.25, 0.25], - [0.50, 0.25, 0.25], - [0.25, 0.50, 0.25], - [0.25, 0.25, 0.50], - ], - [ - [0.50, 0.25, 0.25], - [0.50, 0.25, 0.25], - [0.25, 0.50, 0.25], - [0.25, 0.25, 0.50], - ], + [0.5, 0.5, 0.25, 0.25], + [0.5, 0.5, 0.25, 0.25], + [0.5, 0.5, 0.25, 0.25], + [0.5, 0.5, 0.25, 0.25], + ] + ], + [ [ - [0.50, 0.25, 0.25], - [0.50, 0.25, 0.25], - [0.25, 0.50, 0.25], - [0.25, 0.25, 0.50], - ], + [0.25, 0.25, 0.5, 0.25], + [0.25, 0.25, 0.5, 0.25], + [0.25, 0.25, 0.5, 0.25], + [0.25, 0.25, 0.5, 0.25], + ] + ], + [ [ - [0.50, 0.25, 0.25], - [0.50, 0.25, 0.25], - [0.25, 0.50, 0.25], - [0.25, 0.25, 0.50], - ], - ] + [0.25, 0.25, 0.25, 0.5], + [0.25, 0.25, 0.25, 0.5], + [0.25, 0.25, 0.25, 0.5], + [0.25, 0.25, 0.25, 0.5], + ] + ], ] def weight_function(shape: tuple) -> torch.Tensor: - assert shape == (1, 5, 5), "Reference shape must be (1, 5, 5)" + assert shape == (1, 1, 5, 5), "Reference shape must be (1, 1, 5, 5)" return torch.Tensor(weights) -class Pyramid5(L.LightningModule): +class Pyramid5(SimpleSupervisedModel): # Pyramid model, returns a pyramid independent of input + def __init__(self): + super().__init__(None, None, None) + def forward(self, x: torch.Tensor) -> torch.Tensor: assert x.shape[1:] == ( 1, @@ -68,8 +70,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.Tensor(pyramid).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) -class Classifier(L.LightningModule): +class Classifier(SimpleSupervisedModel): # Classifier model, returns same classification result for 4x4 windows with 3 classes + + def __init__(self): + super().__init__(None, None, None) + def forward(self, x: torch.Tensor) -> torch.Tensor: assert x.shape[1:] == ( 1, @@ -79,94 +85,147 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.Tensor(classes).unsqueeze(0).repeat(x.shape[0], 1, 1, 1, 1) -def test_weighted_avg_patch_inferencer_basic(): - # Test WeightedAvgPatchInferencer basic usage +def test_patch_inferencer_regression_basic(): + # Test PatchInferer basic usage in regression task model = Pyramid5() - inferencer_no_pad = WeightedAvgPatchInferencer(model=model, input_shape=(1, 5, 5)) + inferencer_no_pad = PatchInferencer(model=model, input_shape=(1, 5, 5)) - model_input = torch.zeros((2, 20, 20)) + model_input_batch = torch.zeros((1, 2, 20, 20)) - output = inferencer_no_pad(model_input) - assert model_input.shape == output.shape, "Input and Output don't have same shape" - expected_output = torch.Tensor(pyramid).repeat(2, 4, 4) + output = inferencer_no_pad(model_input_batch) + assert ( + model_input_batch.shape == output.shape + ), "Input and Output don't have same shape" + expected_output = torch.Tensor(pyramid).repeat(2, 4, 4).unsqueeze(0) assert torch.equal( output, expected_output ), "Output does not match with expected output" -def test_weighted_avg_patch_inferencer(): - # Test 
WeightedAvgPatchInferencer with offsets, padding and custom weight function +def test_patch_inferencer_regression(): + # Test PatchInferer with offsets, padding and custom weight function in regression task model = Pyramid5() - inferencer = WeightedAvgPatchInferencer( + inferencer = PatchInferencer( model=model, weight_function=weight_function, input_shape=(1, 5, 5), offsets=[ - (0, -1, 0), - (0, -2, 0), - (0, -3, 0), - (0, -4, 0), - (0, 0, -1), - (0, -1, -1), - (0, -2, -1), - (0, -3, -1), - (0, -4, -1), - (0, 0, -2), - (0, -1, -2), - (0, -2, -2), - (0, -3, -2), - (0, -4, -2), - (0, 0, -3), - (0, -1, -3), - (0, -2, -3), - (0, -3, -3), - (0, -4, -3), - (0, 0, -4), - (0, -1, -4), - (0, -2, -4), - (0, -3, -4), - (0, -4, -4), + (0, 1, 0), + (0, 2, 0), + (0, 3, 0), + (0, 4, 0), + (0, 0, 1), + (0, 1, 1), + (0, 2, 1), + (0, 3, 1), + (0, 4, 1), + (0, 0, 2), + (0, 1, 2), + (0, 2, 2), + (0, 3, 2), + (0, 4, 2), + (0, 0, 3), + (0, 1, 3), + (0, 2, 3), + (0, 3, 3), + (0, 4, 3), + (0, 0, 4), + (0, 1, 4), + (0, 2, 4), + (0, 3, 4), + (0, 4, 4), ], - padding={"pad": (0, 4, 4)}, + padding={"pad": (0, 24, 24)}, ) - model_input = torch.zeros((2, 20, 20)) + model_input_batch = torch.zeros((1, 2, 20, 20)) - output = inferencer(model_input) - assert model_input.shape == output.shape, "Input and Output don't have same shape" - # Offsets used in this test result in each point of the output to be the combination of all the 25 points - # that make up the 5x5 pyramid, using the weights to combine them in a weighted average + output = inferencer(model_input_batch) + assert ( + model_input_batch.shape == output.shape + ), "Input and Output don't have same shape" + # Offsets used in this test result in each point of the output to be the combination of all the 25 points + # that make up the 5x5 pyramid, using the weights to combine them in a weighted average expected_value = torch.sum( torch.Tensor(pyramid) * torch.Tensor(weights) ) / torch.sum(torch.Tensor(weights)) - expected_output = torch.full((2, 20, 20), fill_value=expected_value) + expected_output_middle = torch.full((1, 2, 16, 16), fill_value=expected_value) assert torch.equal( - output, expected_output - ), "Output does not match with expected output" + output[:, :, 0, :], torch.full((1, 2, 20), fill_value=1) + ), "Output upper border region does not match with expected values" + + assert torch.equal( + output[:, :, :, 0], torch.full((1, 2, 20), fill_value=1) + ), "Output left border region does not match with expected values" + assert torch.equal( + output[:, :, 4:, 4:], expected_output_middle + ), "Output middle region does not match with expected values" -def test_voting_patch_inferencer_basic(): - # Test VotingPatchInferencer basic usage +def test_patch_inferencer_classification_basic(): + # Test PatchInference basic usage in classification task model = Classifier() - inferencer = VotingPatchInferencer( - model=model, input_shape=(1, 4, 4), num_classes=3, voting="soft" + inferencer = PatchInferencer( + model=model, input_shape=(1, 4, 4), output_shape=(3, 1, 4, 4) ) - model_input = torch.zeros((2, 20, 20)) + model_input_batch = torch.zeros((1, 2, 20, 20)) + + output = inferencer(model_input_batch) + + assert (1, 3, 2, 20, 20) == output.shape, "Output doen't have expected shape" + expected_classification = ( + torch.Tensor( + [ + [0, 0, 1, 2], + [0, 0, 1, 2], + [0, 0, 1, 2], + [0, 0, 1, 2], + ] + ) + .repeat(2, 5, 5) + .unsqueeze(0) + ) - output = inferencer(model_input) - assert model_input.shape == output.shape, "Input and Output don't have same shape" - 
expected_output = torch.Tensor( - [ - [0, 0, 1, 2], - [0, 0, 1, 2], - [0, 0, 1, 2], - [0, 0, 1, 2], - ] - ).repeat(2, 5, 5) + predicted_classes = torch.argmax(output, dim=1, keepdim=False) assert torch.equal( - output, expected_output - ), "Output does not match with expected output" + predicted_classes, expected_classification + ), "Predicted classes don't match with expected classification" + + +def test_patch_inferencer_classification(): + # Test PatchInference with offset in classification task + model = Classifier() + inferencer = PatchInferencer( + model=model, + input_shape=(1, 4, 4), + output_shape=(3, 1, 4, 4), + offsets=[(0, 0, 2)], + ) + + model_input_batch = torch.zeros((1, 2, 8, 8)) + + output = inferencer(model_input_batch) + + assert (1, 3, 2, 8, 8) == output.shape, "Output doen't have expected shape" + expected_classification = ( + torch.Tensor( + [ + [0, 0, 0, 0, 0, 0, 1, 2], + [0, 0, 0, 0, 0, 0, 1, 2], + [0, 0, 0, 0, 0, 0, 1, 2], + [0, 0, 0, 0, 0, 0, 1, 2], + ] + ) + .repeat(2, 2, 1) + .unsqueeze(0) + ) + + predicted_classes = torch.argmax(output, dim=1, keepdim=False) + + assert torch.equal( + predicted_classes, expected_classification + ), "Predicted classes don't match with expected classification" diff --git a/tests/transforms/test_random_flip.py b/tests/transforms/test_random_flip.py new file mode 100644 index 0000000..51d10c7 --- /dev/null +++ b/tests/transforms/test_random_flip.py @@ -0,0 +1,79 @@ +import numpy as np +import pytest + +from minerva.transforms.random_transform import RandomFlip + + +def test_random_flip_single_axis_with_flip(): + # Create a dummy input + x = np.random.rand(10, 20) + + # Apply the flip transform along the first axis + flip_transform = RandomFlip(possible_axis=0, num_samples=1, seed=0) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # Check if the flipped data is different from the input + assert np.allclose(flipped_x, np.flip(x, axis=0)) + + +def test_random_flip_single_axis_without_flip(): + # Create a dummy input + x = np.random.rand(10, 20) + + # Apply the flip transform along the first axis + flip_transform = RandomFlip(possible_axis=0, num_samples=1, seed=1) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # Check if the flipped data is different from the input + assert np.allclose(flipped_x, x) + + +def test_random_flip_first_axis(): + # Create a dummy input + x = np.random.rand(10, 20, 30) + + # Apply the flip transform along multiple axes + flip_transform = RandomFlip(possible_axis=[0, 1], num_samples=1, seed=0) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # check if only the first axis is flipped + assert np.allclose(flipped_x, np.flip(x, axis=0)) + + +def test_random_flip_second_axis(): + # Create a dummy input + x = np.random.rand(10, 20, 30) + + # Apply the flip transform along multiple axes + flip_transform = RandomFlip(possible_axis=[0, 1], num_samples=1, seed=1) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # check if the second axis is flipped + assert np.allclose(flipped_x, np.flip(x, axis=1)) + + +def test_random_flip_two_axis(): + # Create a dummy input + x = np.random.rand(10, 20, 30) + + # Apply the flip transform along multiple axes + flip_transform = 
RandomFlip(possible_axis=[0, 1], num_samples=1, seed=2)
+    flipped_x = flip_transform(x)
+
+    # Check if the flipped data has the same shape as the input
+    assert flipped_x.shape == x.shape
+
+    # check if both axes are flipped
+    assert np.allclose(flipped_x, np.flip(x, axis=(0, 1)))
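+
+
+# A minimal sketch of the synchronization that `num_samples` provides: with
+# num_samples=2, the flip drawn on the first call (the image) is reused for the
+# second call (the mask) before the internal counter resets.
+def test_random_flip_synced_image_and_mask():
+    image = np.random.rand(10, 20)
+    mask = np.random.randint(0, 2, size=(10, 20))
+
+    # Same seed as test_random_flip_single_axis_with_flip, so axis 0 is flipped
+    flip_transform = RandomFlip(possible_axis=0, num_samples=2, seed=0)
+    flipped_image = flip_transform(image)
+    flipped_mask = flip_transform(mask)
+
+    assert np.allclose(flipped_image, np.flip(image, axis=0))
+    assert np.array_equal(flipped_mask, np.flip(mask, axis=0))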