From af06d162cfdb2b53193c3d61b788cb5fc2f03efc Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 25 Dec 2024 17:03:25 +0800 Subject: [PATCH] [checkpointio] support non blocking pin load (#6172) * [checkpointio] support non blocking pin load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/booster/booster.py | 32 ++- colossalai/booster/plugin/gemini_plugin.py | 53 +++- .../booster/plugin/low_level_zero_plugin.py | 48 +++- colossalai/booster/plugin/torch_ddp_plugin.py | 39 ++- .../booster/plugin/torch_fsdp_plugin.py | 19 +- .../checkpoint_io/checkpoint_io_base.py | 74 +++++- .../checkpoint_io/general_checkpoint_io.py | 35 ++- .../hybrid_parallel_checkpoint_io.py | 39 ++- colossalai/checkpoint_io/moe_checkpoint.py | 18 +- colossalai/checkpoint_io/utils.py | 38 ++- .../test_gemini_checkpoint_io.py | 14 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 9 +- .../test_low_level_zero_checkpoint_io.py | 7 +- .../test_safetensors_async_io.py | 226 ++++++++++-------- .../test_torch_ddp_checkpoint_io.py | 7 +- 15 files changed, 484 insertions(+), 174 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 43a3b75317ba..13694a3c3136 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -288,7 +288,14 @@ def enable_lora( return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config) - def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: + def load_model( + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + strict: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ) -> None: """Load model from checkpoint. Args: @@ -298,8 +305,12 @@ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, str strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Defaults to True. + low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True. + num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1. """ - self.checkpoint_io.load_model(model, checkpoint, strict) + self.checkpoint_io.load_model( + model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) def save_model( self, @@ -338,18 +349,25 @@ def save_model( use_async=use_async, ) - def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: + def load_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ) -> None: """Load optimizer from checkpoint. Args: optimizer (Optimizer): An optimizer boosted by Booster. checkpoint (str): Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path. - prefix (str, optional): A prefix added to parameter and buffer - names to compose the keys in state_dict. Defaults to None. - size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. + low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. 
If false, it will use RAM cache to accelerate loading. Default: True. + num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1. """ - self.checkpoint_io.load_optimizer(optimizer, checkpoint) + self.checkpoint_io.load_optimizer( + optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) def save_optimizer( self, diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 441670a0aaea..ba43a5066d60 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,4 +1,3 @@ -import gc import os import random from pathlib import Path @@ -97,13 +96,22 @@ def save_unsharded_model( else: save_state_dict(state_dict, checkpoint, use_safetensors) - def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): + def load_unsharded_model( + self, + model: GeminiDDP, + checkpoint: str, + strict: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load model from checkpoint with automatic unwrapping. The model should be unwrapped in self.load_model via ModelWrapper.unwrap. """ assert isinstance(model, GeminiDDP), "Please boost the model before loading!" - super().load_unsharded_model(model, checkpoint, strict=strict) + super().load_unsharded_model( + model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) def save_unsharded_optimizer( self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False @@ -131,13 +139,17 @@ def save_unsharded_optimizer( else: save_state_dict(state_dict, checkpoint, use_safetensors=False) - def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): + def load_unsharded_optimizer( + self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): """ Loading unsharded optimizer from checkpoint file. For each process, only loading optimizer states of parameters it controls. """ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" - super().load_unsharded_optimizer(optimizer, checkpoint) + super().load_unsharded_optimizer( + optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) def save_sharded_model( self, @@ -206,13 +218,27 @@ def save_sharded_model( ) def load_sharded_model( - self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False + self, + model: GeminiDDP, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, ): """ Load shard model, load model from multiple files. """ assert isinstance(model, GeminiDDP), "Please boost the model before loading!" 
- return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) + return super().load_sharded_model( + model, + checkpoint_index_file, + strict, + use_safetensors, + load_sub_module=False, + low_cpu_mem_mode=low_cpu_mem_mode, + num_threads=num_threads, + ) def save_sharded_optimizer( self, @@ -289,7 +315,14 @@ def save_sharded_optimizer( ranks=[0], ) - def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): + def load_sharded_optimizer( + self, + optimizer: GeminiOptimizer, + checkpoint_index_file: Path, + prefix: str, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Loading sharded optimizer from checkpoint folder, with index file given. For each process, only loading optimizer states of parameters it controls. @@ -322,9 +355,9 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi state_dict_shard = load_flat(shard_file) else: state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) + if not low_cpu_mem_mode: + state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads) optimizer.load_param_states(state_dict_shard) - del state_dict_shard - gc.collect() optimizer.optimizer_loading_epilogue() diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 6e91bd8ed117..0bb4ae9ede5b 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -20,6 +20,7 @@ from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO from colossalai.checkpoint_io.utils import ( + create_pinned_state_dict, get_optimizer_base_filenames, get_shard_filename, load_param_groups_into_optimizer, @@ -145,7 +146,9 @@ def save_unsharded_optimizer( else: save_state_dict(state_dict, checkpoint, use_safetensors=False) - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + def load_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): use_async = checkpoint.endswith(".safetensors") if use_async: from colossalai.utils.safetensors import load_flat @@ -153,6 +156,8 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str) checkpoint = load_flat(checkpoint) else: checkpoint = load_state_dict(checkpoint) + if not low_cpu_mem_mode: + checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads) optimizer.load_state_dict(checkpoint) def save_sharded_optimizer( @@ -239,7 +244,14 @@ def save_sharded_optimizer( ranks=[0], ) - def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): + def load_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + index_file_path: str, + prefix: str, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """Load sharded optimizer with the given path to index file. 
Args: @@ -283,14 +295,28 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) v_list = v.split(v.numel() // self.coordinator.world_size) - state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone() + state_dict[param_idx][k] = v_list[self.coordinator.rank].detach() + if low_cpu_mem_mode: + state_dict[param_idx][k] = state_dict[param_idx][k].clone() + + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) load_states_into_optimizer(optimizer, state_dict, id_map) sharded_optimizer_loading_epilogue(optimizer) - def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): + def load_unsharded_model( + self, + model: ModelWrapper, + checkpoint: str, + strict: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" model._force_wait_all_gather() - super().load_unsharded_model(model, checkpoint, strict) + super().load_unsharded_model( + model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) model.update_master_params() def load_sharded_model( @@ -300,10 +326,20 @@ def load_sharded_model( strict: bool = False, use_safetensors: bool = False, load_sub_module: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, ): assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" model._force_wait_all_gather() - super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) + super().load_sharded_model( + model, + checkpoint_index_file, + strict, + use_safetensors, + load_sub_module, + low_cpu_mem_mode=low_cpu_mem_mode, + num_threads=num_threads, + ) model.update_master_params() def save_unsharded_model( diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 90d406eefaa3..acec7e82d437 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -26,12 +26,21 @@ def __init__(self) -> None: self.coordinator = DistCoordinator() self.logger = get_dist_logger() - def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): + def load_unsharded_model( + self, + model: ModelWrapper, + checkpoint: str, + strict: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load model from checkpoint. """ assert isinstance(model, ModelWrapper), "Please boost the model before loading!" - super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict) + super().load_unsharded_model( + model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) def save_unsharded_model( self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False @@ -45,12 +54,16 @@ def save_unsharded_model( model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async ) - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + def load_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): """ Load optimizer from checkpoint. """ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" 
- super().load_unsharded_optimizer(optimizer, checkpoint) + super().load_unsharded_optimizer( + optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) def save_unsharded_optimizer( self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False @@ -101,12 +114,22 @@ def load_sharded_model( strict: bool = False, use_safetensors: bool = False, load_sub_module: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, ): """ Load model from sharded checkpoint. """ assert isinstance(model, ModelWrapper), "Please boost the model before loading!" - super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module) + super().load_sharded_model( + model.unwrap(), + checkpoint_index_file, + strict, + use_safetensors, + load_sub_module, + low_cpu_mem_mode=low_cpu_mem_mode, + num_threads=num_threads, + ) def save_sharded_optimizer( self, @@ -131,12 +154,16 @@ def load_sharded_optimizer( optimizer: Optimizer, index_file_path: str, prefix: Optional[str] = None, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, ): """ Load optimizer from sharded checkpoint. """ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" - super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix) + super().load_sharded_optimizer( + optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) def save_lora_as_pretrained( self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 1d792757b9de..182fb3c7bdaa 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -43,13 +43,17 @@ def __init__(self) -> None: self.coordinator = DistCoordinator() self.logger = get_dist_logger() - def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): + def load_unsharded_model( + self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" model = model.unwrap() checkpoint = utils.load_state_dict(checkpoint) model.load_state_dict(checkpoint) - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): + def load_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" if checkpoint.endswith(".safetensors"): checkpoint = load_flat(checkpoint, seperator=".") @@ -232,6 +236,8 @@ def load_sharded_model( strict: bool = False, use_safetensors: bool = False, load_sub_module: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, ): """ Load model to checkpoint but only on master process. @@ -354,7 +360,14 @@ def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: f"index located at {save_index_file}." ) - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int): + def load_sharded_optimizer( + self, + optimizer: Optimizer, + index_file_path: str, + size_per_shard: int, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load optimizer to checkpoint but only on master process. 
""" diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index c67020e979ac..40024f8a86dc 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -85,7 +85,12 @@ def __del__(self): self._sync_io() def load_model( - self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + strict: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, ) -> Union[nn.Module, ModelWrapper]: """ Load model from checkpoint. @@ -100,6 +105,8 @@ def load_model( Distributed tensors cannot be loaded directly unless gathered offline via our CLI. strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. + low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True. + num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1. """ # since we only support loaded sharded and unsharded weight format # containing no distributed tensors, dtensor -> full tensor conversion @@ -111,17 +118,25 @@ def load_model( origin_model = model if index_file_exists: - self.load_sharded_model(model, index_file_path, strict) + self.load_sharded_model( + model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) else: path = Path(checkpoint, SAFE_WEIGHTS_NAME) if path.is_file(): - self.load_unsharded_model(model, str(path), strict) + self.load_unsharded_model( + model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) else: path = Path(checkpoint, WEIGHTS_NAME) if path.is_file(): - self.load_unsharded_model(model, str(path), strict) + self.load_unsharded_model( + model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) else: - self.load_unsharded_model(model, checkpoint, strict) + self.load_unsharded_model( + model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) return origin_model @@ -178,7 +193,14 @@ def save_model( else: self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) - def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024): + def load_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + prefix: str = None, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load optimizer from checkpoint. @@ -187,7 +209,8 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the prefix (str, optional): A prefix added to parameter and buffer names to compose the keys in state_dict. Defaults to None. - size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. + low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True. + num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1. 
""" index_file_exists, index_file_path = has_index_file(checkpoint) @@ -198,9 +221,13 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No if index_file_exists: # the existence of index file means it is a sharded checkpoint - self.load_sharded_optimizer(optimizer, index_file_path, prefix) + self.load_sharded_optimizer( + optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) else: - self.load_unsharded_optimizer(optimizer, checkpoint) + self.load_unsharded_optimizer( + optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads + ) def save_optimizer( self, @@ -238,7 +265,9 @@ def save_optimizer( # Abstract methods for model loading/saving implementation # ======================================================== @abstractmethod - def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool): + def load_sharded_model( + self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): """ Load model from sharded checkpoint. @@ -247,10 +276,14 @@ def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: boo index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. + low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True. + num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1. """ @abstractmethod - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + def load_unsharded_model( + self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): """ Load model from unsharded checkpoint. @@ -259,6 +292,8 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. + low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True. + num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1. """ @abstractmethod @@ -303,7 +338,14 @@ def save_unsharded_model( # ======================================================== @abstractmethod - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + def load_sharded_optimizer( + self, + optimizer: Optimizer, + index_file_path: str, + prefix: str, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load optimizer from sharded checkpoint. @@ -311,16 +353,22 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre optimizer (Optimizer): optimizer to be loaded. index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. prefix (str): prefix for the optimizer checkpoint. + low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. 
Default: True. + num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1. """ @abstractmethod - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + def load_unsharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): """ Load optimizer from unsharded checkpoint. Args: optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. + low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True. + num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1. """ @abstractmethod diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index f6bf1bb4a71d..78404f908cab 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -1,4 +1,3 @@ -import gc import logging import os from functools import reduce @@ -40,8 +39,17 @@ class GeneralCheckpointIO(CheckpointIO): Checkpoint IO """ - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + def load_unsharded_model( + self, + model: nn.Module, + checkpoint: str, + strict: bool, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): checkpoint = load_state_dict(checkpoint) + if not low_cpu_mem_mode: + checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads) model.load_state_dict(checkpoint, strict=strict) def save_unsharded_model( @@ -60,7 +68,14 @@ def save_unsharded_model( # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + def load_sharded_optimizer( + self, + optimizer: Optimizer, + index_file_path: str, + prefix: str, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load sharded optimizer with the given path to index file. """ @@ -84,6 +99,8 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre state_dict = load_flat(shard_file) else: state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) load_states_into_optimizer(optimizer, state_dict, id_map) sharded_optimizer_loading_epilogue(optimizer) @@ -158,11 +175,15 @@ def save_sharded_optimizer( f"index located at {save_index_file}." 
) - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + def load_unsharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): if checkpoint.endswith(".safetensors"): checkpoint = load_flat(checkpoint) else: checkpoint = load_state_dict(checkpoint) + if not low_cpu_mem_mode: + checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads) optimizer.load_state_dict(checkpoint) def save_unsharded_optimizer( @@ -256,6 +277,8 @@ def load_sharded_model( strict: bool = False, use_safetensors: bool = False, load_sub_module: bool = True, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, ): """ load shard model, load model from multiple files @@ -274,9 +297,9 @@ def load_sharded_model( for shard_file in checkpoint_files: state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module) - del state_dict - gc.collect() if strict: remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 0a2e598ca619..154d5cb5e5f3 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -355,7 +355,14 @@ def save_sharded_model( f"index located at {final_index_file_path}." ) - def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False): + def load_sharded_model( + self, + model: ModelWrapper, + checkpoint_index_file: Path, + strict: bool = False, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load sharded model with the given path to index file of checkpoint folder. @@ -403,6 +410,8 @@ def _load(name: str): file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) load_state_dict_into_model( model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True @@ -632,7 +641,14 @@ def save_sharded_optimizer( f"index located at {final_index_file_path}." ) - def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + def load_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint_index_file: str, + prefix: str = "", + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load sharded optimizer with the given path to index file of checkpoint folder. 
@@ -706,6 +722,8 @@ def _get_param_id_from_optimizer_param( state_dict = load_flat(file_path) else: state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) loaded_file.add(filename) @@ -789,7 +807,14 @@ def save_unsharded_model( else: save_state_dict(complete_state_dict, checkpoint, use_safetensors) - def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False): + def load_unsharded_model( + self, + model: ModelWrapper, + checkpoint: str, + strict: bool = False, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load model from a single file with the given path of checkpoint. @@ -812,6 +837,8 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer, # model.load_state_dict can be directly called. state_dict = load_state_dict(checkpoint) + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) model.load_state_dict(state_dict, strict=strict) # Update master params if mixed-precision training is enabled. @@ -912,7 +939,9 @@ def save_unsharded_optimizer( else: save_state_dict(state_dict, checkpoint, use_safetensors=False) - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + def load_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1 + ): """ Load optimizer from a file with given path. @@ -940,6 +969,8 @@ def _get_param_id_from_optimizer_param( state_dict = load_flat(checkpoint) else: state_dict = load_state_dict(checkpoint) + if not low_cpu_mem_mode: + state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) # Load param_groups. updated_groups = [] diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index f6aefd33a9f5..04655dec5b77 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -510,7 +510,14 @@ def save_sharded_optimizer( f"index located at {final_index_file_path}." ) - def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + def load_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint_index_file: str, + prefix: str = "", + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load sharded optimizer with the given path to index file of checkpoint folder. @@ -795,7 +802,14 @@ def save_unsharded_optimizer( dist.barrier() # Copied from colossalai.moe - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False): + def load_unsharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + strict: bool = False, + low_cpu_mem_mode: bool = True, + num_threads: int = 1, + ): """ Load optimizer from a file with given path. 
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 71422f4c2dcc..7b322b657211 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,18 +1,20 @@ # coding=utf-8 +import concurrent.futures import os import re from collections import abc as container_abcs from collections import defaultdict from itertools import chain from pathlib import Path -from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple +from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union import torch import torch.nn as nn from packaging.version import Version from torch.optim import Optimizer -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from colossalai.accelerator import get_accelerator from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, @@ -791,7 +793,7 @@ def cast(param, value, key=None): if key != "step": if param.is_floating_point(): value = value.to(param.dtype) - value = value.to(param.device) + value = value.to(param.device, non_blocking=True) return value elif isinstance(value, dict): return {k: cast(param, v, key=k) for k, v in value.items()} @@ -811,6 +813,7 @@ def cast(param, value, key=None): elif not strict: new_states[k] = v + get_accelerator().synchronize() optimizer.state.update(new_states) @@ -945,8 +948,27 @@ def get_shard_filename(weights_name: str, idx: int): return shard_file -def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]): - pin_mem = dict() - for name, tensor in state_dict.items(): - pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu") - return pin_mem +def _pin_tensor(tensor: torch.Tensor, empty: bool = True) -> torch.Tensor: + if empty: + return torch.empty_like(tensor, pin_memory=True, device="cpu") + return tensor.pin_memory() + + +def create_pinned_state_dict( + state_dict: Union[Dict[str, torch.Tensor], Dict[int, Dict[str, torch.Tensor]]], + empty: bool = True, + num_threads: int = 1, +) -> Dict[str, torch.Tensor]: + if num_threads == 1: + return tree_map(lambda x: _pin_tensor(x, empty=empty) if isinstance(x, torch.Tensor) else x, state_dict) + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + elems, spec = tree_flatten(state_dict) + future_to_idx = {} + for i, elem in enumerate(elems): + if isinstance(elem, torch.Tensor): + future_to_idx[executor.submit(_pin_tensor, elem, empty)] = i + for future in concurrent.futures.as_completed(future_to_idx): + idx = future_to_idx[future] + elems[idx] = future.result() + return tree_unflatten(elems, spec) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index a6d65cae5953..53dd3c8dd3ba 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -90,8 +90,16 @@ def exam_state_dict_with_origin( @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) @parameterize("use_async", [False, True]) +@parameterize("low_cpu_mem_mode", [True, False]) def exam_state_dict( - placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool + placement_config, + shard: bool, + model_name: str, + size_per_shard: int, + tp_size: int, + zero_size: int, + use_async: bool, + low_cpu_mem_mode: bool, ): (model_fn, data_gen_fn, output_transform_fn, _, _) = 
next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() @@ -147,12 +155,12 @@ def exam_state_dict( booster.checkpoint_io._sync_io() dist.barrier() - booster.load_model(new_model, model_ckpt_path) + booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode) check_state_dict_equal( model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True ) - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode) check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False)) for group in new_optimizer.param_groups: assert group["lr"] == 0.1 diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 81d184f7681a..a338d98f4746 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -43,8 +43,11 @@ @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) @parameterize("use_async", [False, True]) +@parameterize("low_cpu_mem_mode", [False, True]) @clear_cache_before_run() -def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool): +def exam_state_dict( + shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool +): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( iter(model_zoo.get_sub_registry(model_name).values()) ) @@ -102,9 +105,9 @@ def _preprocess_data(data): new_optimizer = Adam(new_model.parameters(), lr=1e-3) new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - booster.load_model(new_model, model_ckpt_path) + booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict()) - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode) check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict()) dist.barrier() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 05dfcce4f674..9f0180d52882 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -29,7 +29,8 @@ @parameterize("shard", [False, True]) @parameterize("offload", [False, True]) @parameterize("use_async", [False, True]) -def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool): +@parameterize("low_cpu_mem_mode", [False, True]) +def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool, low_cpu_mem_mode: bool): plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) model = resnet18() @@ -70,7 +71,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us new_optimizer = HybridAdam((new_model.parameters()), lr=0.001) new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer) - booster.load_model(new_model, model_ckpt_path) + 
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode) check_state_dict_equal(model.state_dict(), new_model.state_dict()) # check master weight assert isinstance(new_optimizer, LowLevelZeroOptimizer) @@ -85,7 +86,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) ) - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict()) torch.cuda.empty_cache() diff --git a/tests/test_checkpoint_io/test_safetensors_async_io.py b/tests/test_checkpoint_io/test_safetensors_async_io.py index 7de73b46bd79..a0b51eae3c88 100644 --- a/tests/test_checkpoint_io/test_safetensors_async_io.py +++ b/tests/test_checkpoint_io/test_safetensors_async_io.py @@ -1,108 +1,144 @@ import tempfile +import pytest import torch from safetensors.torch import load_file +from colossalai.checkpoint_io.utils import create_pinned_state_dict from colossalai.testing import check_state_dict_equal, clear_cache_before_run from colossalai.utils import get_current_device from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested +def gen_optim_state_dict(): + return { + "state": { + 0: { + "step": torch.tensor(1.0), + "exp_avg": torch.rand((1024, 1024)), + "exp_avg_sq": torch.rand((1024, 1024)), + }, + 1: { + "step": torch.tensor(1.0), + "exp_avg": torch.rand((1024, 1024)), + "exp_avg_sq": torch.rand((1024, 1024)), + }, + 2: { + "step": torch.tensor(1.0), + "exp_avg": torch.rand((1024, 1024)), + "exp_avg_sq": torch.rand((1024, 1024)), + }, + }, + "param_groups": [ + { + "lr": 0.001, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "bias_correction": True, + "params": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + ], + } + ], + } + + +def gen_model_state_dict(): + return { + "module.weight0": torch.rand((1024, 1024)), + "module.weight1": torch.rand((1024, 1024)), + "module.weight2": torch.rand((1024, 1024)), + } + + +@pytest.mark.parametrize("empty", [True, False]) +@pytest.mark.parametrize("num_threads", [1, 4]) +def test_create_pin(empty: bool, num_threads: int): + model_state_dict = gen_model_state_dict() + model_state_dict_pinned = create_pinned_state_dict(model_state_dict, empty=empty, num_threads=num_threads) + for k in model_state_dict.keys(): + assert model_state_dict_pinned[k].is_pinned() + if not empty: + assert torch.equal(model_state_dict_pinned[k], model_state_dict[k]) + optim_state_dict = gen_optim_state_dict() + optim_state_dict_pinned = create_pinned_state_dict(optim_state_dict, empty=empty, num_threads=num_threads) + for k in optim_state_dict.keys(): + if k == "state": + for idx in optim_state_dict[k].keys(): + for kk in optim_state_dict[k][idx].keys(): + assert optim_state_dict_pinned[k][idx][kk].is_pinned() + if not empty: + assert torch.equal(optim_state_dict_pinned[k][idx][kk], optim_state_dict[k][idx][kk]) + else: + assert optim_state_dict[k] == optim_state_dict_pinned[k] + + @clear_cache_before_run() def test_save_load(): with 
tempfile.TemporaryDirectory() as tempdir: - optimizer_state_dict = { - "state": { - 0: { - "step": torch.tensor(1.0), - "exp_avg": torch.rand((1024, 1024)), - "exp_avg_sq": torch.rand((1024, 1024)), - }, - 1: { - "step": torch.tensor(1.0), - "exp_avg": torch.rand((1024, 1024)), - "exp_avg_sq": torch.rand((1024, 1024)), - }, - 2: { - "step": torch.tensor(1.0), - "exp_avg": torch.rand((1024, 1024)), - "exp_avg_sq": torch.rand((1024, 1024)), - }, - }, - "param_groups": [ - { - "lr": 0.001, - "betas": (0.9, 0.999), - "eps": 1e-08, - "weight_decay": 0, - "bias_correction": True, - "params": [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39, - 40, - 41, - 42, - 43, - 44, - 45, - 46, - 47, - 48, - 49, - 50, - 51, - 52, - 53, - 54, - 55, - 56, - 57, - 58, - 59, - 60, - 61, - ], - } - ], - } + optimizer_state_dict = gen_optim_state_dict() optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors" f_writer = save_nested(optimizer_saved_path, optimizer_state_dict) @@ -120,11 +156,7 @@ def test_save_load(): load_state_dict_shard = load_flat(optimizer_shard_saved_path) check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"]) - model_state_dict = { - "module.weight0": torch.rand((1024, 1024)), - "module.weight1": torch.rand((1024, 1024)), - "module.weight2": torch.rand((1024, 1024)), - } + model_state_dict = gen_model_state_dict() model_saved_path = f"{tempdir}/save_model.safetensors" f_writer = save(model_saved_path, model_state_dict) f_writer.sync_before_step() diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index b90ea0960c8d..f3d1085bfba2 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -15,7 +15,8 @@ @parameterize("shard", [False, True]) @parameterize("size_per_shard", [16, 128]) @parameterize("use_async", [False, True]) -def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool): +@parameterize("low_cpu_mem_mode", [False, True]) +def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool, low_cpu_mem_mode: bool): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) model = resnet18() @@ -61,10 +62,10 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bo new_model, new_optimizer, lr_scheduler=new_scheduler ) - booster.load_model(new_model, model_ckpt_path) + booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode) check_state_dict_equal(model.state_dict(), new_model.state_dict()) - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode) check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict())
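
Not part of the patch itself: a minimal usage sketch of the loading options introduced above, assuming a model and optimizer already boosted by a Booster and checkpoints previously saved at the illustrative paths below. Per the docstrings added in this commit, low_cpu_mem_mode=False trades extra RAM for speed: the loaded state dict is pinned in page-locked host memory (with num_threads workers pinning tensors in parallel) so the subsequent host-to-device copies can run non-blocking.

# Illustrative sketch only (not from the patch): exercises the new
# low_cpu_mem_mode / num_threads arguments of Booster.load_model and
# Booster.load_optimizer. Paths and hyperparameters are made up; run under
# torchrun (or `colossalai run`) so the distributed environment is set up.
import torch
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()

booster = Booster(plugin=TorchDDPPlugin())
model = resnet18()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# low_cpu_mem_mode=False loads checkpoint tensors into pinned (page-locked)
# CPU buffers acting as a RAM cache, so optimizer states are moved to the
# device with non_blocking copies; num_threads controls how many threads pin
# tensors concurrently (only relevant when low_cpu_mem_mode is disabled).
booster.load_model(model, "./ckpt/model", low_cpu_mem_mode=False, num_threads=4)
booster.load_optimizer(optimizer, "./ckpt/optimizer", low_cpu_mem_mode=False, num_threads=4)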
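
Likewise an illustration rather than part of the patch: a standalone sketch of the reworked create_pinned_state_dict helper from colossalai/checkpoint_io/utils.py, whose empty and num_threads arguments back both the pinned-load path above and the existing async save path.

import torch

from colossalai.checkpoint_io.utils import create_pinned_state_dict

# Pinned-memory allocation requires a CUDA-capable (or other accelerator) runtime.
state_dict = {"weight": torch.rand(1024, 1024), "bias": torch.rand(1024)}

# empty=False copies the source values into pinned CPU tensors; num_threads > 1
# pins the flattened tensors concurrently through a thread pool.
pinned = create_pinned_state_dict(state_dict, empty=False, num_threads=4)
assert all(t.is_pinned() for t in pinned.values())
assert all(torch.equal(pinned[k], state_dict[k]) for k in state_dict)

# empty=True (the default) only allocates pinned buffers of matching shape and
# dtype, which is what the asynchronous save path uses as a staging area.
staging = create_pinned_state_dict(state_dict)
assert all(t.is_pinned() for t in staging.values())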