From c9225fb3338be3f571d474a89d7fca7fa0909b3d Mon Sep 17 00:00:00 2001 From: Ray Andrew <4437323+rayandrew@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:19:14 -0500 Subject: [PATCH] enable option to disable pin_memory in pytorch (#239) * enable option to disable pin_memory in pytorch * add the docs for pytorch pin memory --- dlio_benchmark/data_loader/torch_data_loader.py | 4 ++-- dlio_benchmark/utils/config.py | 3 +++ docs/source/config.rst | 3 +++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/dlio_benchmark/data_loader/torch_data_loader.py b/dlio_benchmark/data_loader/torch_data_loader.py index 3048dab1..ffd383de 100644 --- a/dlio_benchmark/data_loader/torch_data_loader.py +++ b/dlio_benchmark/data_loader/torch_data_loader.py @@ -143,7 +143,7 @@ def read(self): batch_size=self.batch_size, sampler=sampler, num_workers=self._args.read_threads, - pin_memory=True, + pin_memory=self._args.pin_memory, drop_last=True, worker_init_fn=dataset.worker_init, **kwargs) @@ -152,7 +152,7 @@ def read(self): batch_size=self.batch_size, sampler=sampler, num_workers=self._args.read_threads, - pin_memory=True, + pin_memory=self._args.pin_memory, drop_last=True, worker_init_fn=dataset.worker_init, **kwargs) # 2 is the default value diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index 13265f41..39ed54d9 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -118,6 +118,7 @@ class ConfigArguments: data_loader_sampler: DataLoaderSampler = None reader_classname: str = None multiprocessing_context: str = "fork" + pin_memory: bool = True # derived fields required_samples: int = 1 @@ -521,6 +522,8 @@ def LoadConfig(args, config): args.preprocess_time = reader['preprocess_time'] if 'preprocess_time_stdev' in reader: args.preprocess_time_stdev = reader['preprocess_time_stdev'] + if 'pin_memory' in reader: + args.pin_memory = reader['pin_memory'] # training relevant setting if 'train' in config: diff --git a/docs/source/config.rst b/docs/source/config.rst index 95fecde8..4073509b 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -201,6 +201,9 @@ reader * - read_threads* - 1 - number of threads to load the data (for tensorflow and pytorch data loader) + * - pin_memory + - True + - whether to pin the memory for pytorch data loader * - computation_threads - 1 - number of threads to preprocess the data