From 85d1115285bc95e70094f7a8ca90f6b2999f5e1b Mon Sep 17 00:00:00 2001
From: Eliot Behr
Date: Mon, 27 Jan 2025 16:02:58 -0500
Subject: [PATCH 1/3] feat: Add bottleneck embedding extraction functionality

- Add BottleneckEnsemblePredictor class for extracting embeddings
- Add extraction script for easy command-line usage
- Support ensemble predictions across all folds
- Fall back to checkpoint_best.pth if final not found
---
 nnunetv2/inference/bottleneck_predictor.py | 47 ++++++++++++++++++++++
 scripts/extract_embeddings.py              | 17 ++++++++
 2 files changed, 64 insertions(+)
 create mode 100644 nnunetv2/inference/bottleneck_predictor.py
 create mode 100644 scripts/extract_embeddings.py

diff --git a/nnunetv2/inference/bottleneck_predictor.py b/nnunetv2/inference/bottleneck_predictor.py
new file mode 100644
index 000000000..f7832d427
--- /dev/null
+++ b/nnunetv2/inference/bottleneck_predictor.py
@@ -0,0 +1,47 @@
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+import numpy as np
+import torch
+import os
+
+class BottleneckEnsemblePredictor(nnUNetPredictor):
+    def initialize_from_trained_model_folder(self, model_training_output_dir: str,
+                                              use_folds: tuple[int, ...],
+                                              checkpoint_name: str = "checkpoint_final.pth"):
+        # Try checkpoint_final.pth first, fall back to checkpoint_best.pth if not found
+        if not all(os.path.exists(os.path.join(model_training_output_dir, f'fold_{i}', checkpoint_name))
+                   for i in use_folds):
+            print(f"Warning: {checkpoint_name} not found in all folds, falling back to checkpoint_best.pth")
+            checkpoint_name = "checkpoint_best.pth"
+
+        return super().initialize_from_trained_model_folder(
+            model_training_output_dir,
+            use_folds,
+            checkpoint_name
+        )
+
+    def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict):
+        self.network.eval()
+        all_fold_features = []
+
+        with torch.no_grad():
+            x = torch.from_numpy(input_image).cuda(self.device, non_blocking=True)
+
+            for network in self.networks_and_mirrors:
+                net = network[0]
+                features = net.encoder(x)  # Get bottleneck features
+                all_fold_features.append(features.cpu().numpy())
+
+        # Average across folds
+        ensemble_features = np.mean(all_fold_features, axis=0)
+        return ensemble_features
+
+    def predict_from_files(self, input_folder: str, output_folder: str, *args, **kwargs):
+        if not os.path.exists(output_folder):
+            os.makedirs(output_folder)
+
+        # Call parent class but capture embeddings
+        embeddings = super().predict_from_files(input_folder, output_folder, *args, **kwargs)
+
+        # Save embeddings
+        np.save(os.path.join(output_folder, 'bottleneck_embeddings.npy'), embeddings)
+        return embeddings
diff --git a/scripts/extract_embeddings.py b/scripts/extract_embeddings.py
new file mode 100644
index 000000000..b22d05c21
--- /dev/null
+++ b/scripts/extract_embeddings.py
@@ -0,0 +1,17 @@
+from bottleneck_predictor import BottleneckEnsemblePredictor
+
+def main():
+    predictor = BottleneckEnsemblePredictor()
+    predictor.initialize_from_trained_model_folder(
+        model_folder,  # Path to your model folder
+        use_folds=(0,1,2,3,4),  # Use all folds for ensemble
+        checkpoint_name="checkpoint_final.pth"  # Will fall back to checkpoint_best.pth if not found
+    )
+
+    embeddings = predictor.predict_from_files(
+        input_folder="path/to/input/images",
+        output_folder="path/to/output/embeddings",
+    )
+
+if __name__ == "__main__":
+    main()

From 75c65ddc3f75289215f79764dd8ee849f0bcc27c Mon Sep 17 00:00:00 2001
From: Eliot Behr
Date: Mon, 27 Jan 2025 17:00:00 -0500
Subject: [PATCH 2/3] bottleneck.py and extract_embeddings.py testable

---
 nnunetv2/inference/bottleneck_predictor.py | 139 +++++++++++++++++----
 scripts/extract_embeddings.py              |   2 +-
 2 files changed, 114 insertions(+), 27 deletions(-)

diff --git a/nnunetv2/inference/bottleneck_predictor.py b/nnunetv2/inference/bottleneck_predictor.py
index f7832d427..e2e14bf0f 100644
--- a/nnunetv2/inference/bottleneck_predictor.py
+++ b/nnunetv2/inference/bottleneck_predictor.py
@@ -2,46 +2,133 @@
 import numpy as np
 import torch
 import os
+from typing import Union, Tuple
+from batchgenerators.utilities.file_and_folder_operations import join
 
-class BottleneckEnsemblePredictor(nnUNetPredictor):
+class BottleneckPredictor(nnUNetPredictor):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.network = None
+
     def initialize_from_trained_model_folder(self, model_training_output_dir: str,
-                                              use_folds: tuple[int, ...],
+                                              use_folds: Union[Tuple[Union[int, str]], None],
                                               checkpoint_name: str = "checkpoint_final.pth"):
-        # Try checkpoint_final.pth first, fall back to checkpoint_best.pth if not found
-        if not all(os.path.exists(os.path.join(model_training_output_dir, f'fold_{i}', checkpoint_name))
-                   for i in use_folds):
+        """
+        Initialize model from the trained model folder, with fallback to best checkpoint
+        """
+        # First try the specified checkpoint
+        checkpoint_exists = False
+        if use_folds is not None:
+            checkpoint_exists = all(
+                os.path.exists(join(model_training_output_dir, f'fold_{i}', checkpoint_name))
+                for i in use_folds if i != 'all'
+            )
+
+        # Fallback to checkpoint_best if final not found
+        if not checkpoint_exists:
             print(f"Warning: {checkpoint_name} not found in all folds, falling back to checkpoint_best.pth")
             checkpoint_name = "checkpoint_best.pth"
-
+
         return super().initialize_from_trained_model_folder(
             model_training_output_dir,
             use_folds,
             checkpoint_name
         )
 
-    def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict):
+    def predict_from_preprocessed_data(self, preprocessed_data: Union[str, np.ndarray]) -> np.ndarray:
+        """
+        Predict bottleneck features from preprocessed data
+        Args:
+            preprocessed_data: Either path to .npy file or numpy array
+        Returns:
+            bottleneck features as numpy array
+        """
         self.network.eval()
-        all_fold_features = []
 
-        with torch.no_grad():
-            x = torch.from_numpy(input_image).cuda(self.device, non_blocking=True)
-
-            for network in self.networks_and_mirrors:
-                net = network[0]
-                features = net.encoder(x)  # Get bottleneck features
-                all_fold_features.append(features.cpu().numpy())
+        # Load data if path provided
+        if isinstance(preprocessed_data, str):
+            data = np.load(preprocessed_data)
+        else:
+            data = preprocessed_data
 
-        # Average across folds
-        ensemble_features = np.mean(all_fold_features, axis=0)
-        return ensemble_features
-
-    def predict_from_files(self, input_folder: str, output_folder: str, *args, **kwargs):
-        if not os.path.exists(output_folder):
+        # Convert to torch tensor
+        with torch.no_grad():
+            data = torch.from_numpy(data).to(self.device)
+            if len(data.shape) == 3:
+                data = data.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims
+            elif len(data.shape) == 4:
+                data = data.unsqueeze(0)  # Add batch dim only
+
+            # Get bottleneck features
+            features = self.network.encoder(data)
+            return features.cpu().numpy()
+
+    def predict_from_preprocessed_folder(self,
+                                         input_folder: str,
+                                         output_folder: str = None,
+                                         save_embeddings: bool = True) -> dict:
+        """
+        Predict bottleneck features for all .npy files in a folder
+        Args:
+            input_folder: Folder containing preprocessed .npy files
+            output_folder: Where to save embeddings (optional)
+            save_embeddings: Whether to save embeddings to disk
+        Returns:
+            Dictionary mapping filenames to bottleneck features
+        """
+        if output_folder is not None and not os.path.exists(output_folder):
             os.makedirs(output_folder)
 
-        # Call parent class but capture embeddings
-        embeddings = super().predict_from_files(input_folder, output_folder, *args, **kwargs)
+        results = {}
 
-        # Save embeddings
-        np.save(os.path.join(output_folder, 'bottleneck_embeddings.npy'), embeddings)
-        return embeddings
+        # Process all .npy files
+        for filename in os.listdir(input_folder):
+            if filename.endswith('.npy'):
+                filepath = join(input_folder, filename)
+                print(f"Processing {filename}...")
+
+                # Get embeddings
+                embeddings = self.predict_from_preprocessed_data(filepath)
+                results[filename] = embeddings
+
+                # Save if requested
+                if save_embeddings and output_folder is not None:
+                    output_path = join(output_folder, f"{filename.split('.')[0]}_embeddings.npy")
+                    np.save(output_path, embeddings)
+
+        return results
+
+# Example usage:
+if __name__ == "__main__":
+    # Set environment variables
+    os.environ['nnUNet_preprocessed'] = r"C:\Users\Eliot Behr\VS\Data\HST18\nnUNet_preprocessed"
+    os.environ['nnUNet_results'] = r"C:\Users\Eliot Behr\VS\Data\HST18\nnUNet_results"
+
+    # Initialize predictor
+    predictor = BottleneckPredictor(
+        tile_step_size=0.5,
+        use_gaussian=True,
+        use_mirroring=True,
+        perform_everything_on_device=True,
+        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
+        verbose=False
+    )
+
+    # Initialize from model folder
+    model_folder = join(os.environ['nnUNet_results'],
+                        'Dataset001_BrainTumor/nnUNetTrainer__nnUNetPlans__2d')
+    predictor.initialize_from_trained_model_folder(
+        model_folder,
+        use_folds=(0,),  # Using fold 0
+    )
+
+    # Predict from preprocessed folder
+    preprocessed_folder = join(os.environ['nnUNet_preprocessed'],
+                               'Dataset001_BrainTumor/nnUNetPlans_2d')
+    output_folder = join(os.environ['nnUNet_results'],
+                         'Dataset001_BrainTumor/bottleneck_features')
+
+    embeddings = predictor.predict_from_preprocessed_folder(
+        preprocessed_folder,
+        output_folder
+    )
diff --git a/scripts/extract_embeddings.py b/scripts/extract_embeddings.py
index b22d05c21..a2fdf7aea 100644
--- a/scripts/extract_embeddings.py
+++ b/scripts/extract_embeddings.py
@@ -1,4 +1,4 @@
-from bottleneck_predictor import BottleneckEnsemblePredictor
+from nnunetv2.inference.bottleneck_predictor import BottleneckEnsemblePredictor
 
 def main():
     predictor = BottleneckEnsemblePredictor()

From 5743f28adddb27160009ee93f9366497566660e2 Mon Sep 17 00:00:00 2001
From: Eliot Behr
Date: Mon, 27 Jan 2025 17:01:32 -0500
Subject: [PATCH 3/3] fixed extract_embeddings.py script

---
 scripts/extract_embeddings.py | 80 +++++++++++++++++++++++++++++++----
 1 file changed, 72 insertions(+), 8 deletions(-)

diff --git a/scripts/extract_embeddings.py b/scripts/extract_embeddings.py
index a2fdf7aea..2c50da8d8 100644
--- a/scripts/extract_embeddings.py
+++ b/scripts/extract_embeddings.py
@@ -1,17 +1,81 @@
-from nnunetv2.inference.bottleneck_predictor import BottleneckEnsemblePredictor
+import os
+import argparse
+from nnunetv2.inference.bottleneck_predictor import BottleneckPredictor
+from batchgenerators.utilities.file_and_folder_operations import join
+import torch
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Extract bottleneck embeddings from preprocessed nnUNet data')
+
+    # Required arguments
+    parser.add_argument('-d', '--dataset_id', type=str, required=True,
+                        help='Dataset ID and name (e.g. Dataset001_BrainTumor)')
+    parser.add_argument('-c', '--configuration', type=str, required=True,
+                        help='Model configuration (e.g. 2d, 3d_fullres)')
+
+    # Optional arguments
+    parser.add_argument('-f', '--folds', nargs='+', type=int, default=[0],
+                        help='Folds to use for prediction (e.g. 0 1 2 3 4)')
+    parser.add_argument('--checkpoint', type=str, default='checkpoint_final.pth',
+                        help='Checkpoint name to use (default: checkpoint_final.pth)')
+    parser.add_argument('--no_save', action='store_false', dest='save_embeddings',
+                        help='Do not save embeddings to disk')
+    parser.add_argument('--device', type=str, default='cuda',
+                        help='Device to use (cuda or cpu)')
+
+    return parser.parse_args()
 
 def main():
-    predictor = BottleneckEnsemblePredictor()
+    args = parse_args()
+
+    # Verify environment variables are set
+    if not all(os.environ.get(var) for var in ['nnUNet_preprocessed', 'nnUNet_results']):
+        raise RuntimeError(
+            "Environment variables nnUNet_preprocessed and nnUNet_results must be set. "
+            "Please see nnunetv2/documentation/setting_up_paths.md"
+        )
+
+    # Initialize predictor
+    predictor = BottleneckPredictor(
+        tile_step_size=0.5,
+        use_gaussian=True,
+        use_mirroring=True,
+        perform_everything_on_device=True,
+        device=torch.device(args.device if torch.cuda.is_available() else 'cpu'),
+        verbose=False
+    )
+
+    # Setup paths
+    model_folder = join(os.environ['nnUNet_results'],
+                        args.dataset_id,
+                        f'nnUNetTrainer__nnUNetPlans__{args.configuration}')
+    preprocessed_folder = join(os.environ['nnUNet_preprocessed'],
+                               args.dataset_id,
+                               f'nnUNetPlans_{args.configuration}')
+    output_folder = join(os.environ['nnUNet_results'],
+                         args.dataset_id,
+                         'bottleneck_features')
+
+    # Initialize from model folder
+    print(f"Loading model from {model_folder}")
+    print(f"Using folds: {args.folds}")
     predictor.initialize_from_trained_model_folder(
-        model_folder,  # Path to your model folder
-        use_folds=(0,1,2,3,4),  # Use all folds for ensemble
-        checkpoint_name="checkpoint_final.pth"  # Will fall back to checkpoint_best.pth if not found
+        model_folder,
+        use_folds=tuple(args.folds),
+        checkpoint_name=args.checkpoint
     )
 
-    embeddings = predictor.predict_from_files(
-        input_folder="path/to/input/images",
-        output_folder="path/to/output/embeddings",
+    # Extract embeddings
+    print(f"Processing preprocessed data from {preprocessed_folder}")
+    print(f"Saving results to {output_folder}")
+    embeddings = predictor.predict_from_preprocessed_folder(
+        preprocessed_folder,
+        output_folder if args.save_embeddings else None,
+        save_embeddings=args.save_embeddings
     )
+
+    print("Extraction complete!")
+    return embeddings
 
 if __name__ == "__main__":
     main()
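
Usage note (a sketch, not part of the patches themselves): with the argparse interface added in PATCH 3/3, the extraction script would be invoked from the command line roughly as below. The dataset name and configuration are illustrative placeholders reused from the examples in the patches; nnUNet_preprocessed and nnUNet_results must point at your own nnU-Net folders.

    export nnUNet_preprocessed=/path/to/nnUNet_preprocessed
    export nnUNet_results=/path/to/nnUNet_results
    python scripts/extract_embeddings.py -d Dataset001_BrainTumor -c 2d -f 0

Passing --no_save computes the embeddings without writing them to disk, and --device cpu forces CPU inference (the script also falls back to CPU automatically when CUDA is unavailable).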