[checkpointio] support load-pin overlap (#6177)
* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ver217 and pre-commit-ci[bot] authored Jan 7, 2025
1 parent: 479067e · commit: ee81366
Showing 6 changed files with 56 additions and 32 deletions.
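For context, the "load-pin overlap" added here means that checkpoint shard files are read from disk on background threads while the consumer pins the shard it already holds, so file I/O and host-memory pinning run concurrently instead of back to back. A minimal, self-contained sketch of the pattern follows; the load_shard and pin helpers are hypothetical stand-ins for illustration, not ColossalAI APIs.

import concurrent.futures
from typing import Dict, Generator, List

import torch


def load_shard(path: str) -> Dict[str, torch.Tensor]:
    # Hypothetical stand-in for reading one checkpoint shard from disk.
    return torch.load(path, map_location="cpu")


def pin(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Hypothetical stand-in for pinning host memory (cf. create_pinned_state_dict below).
    return {k: v.pin_memory() for k, v in state_dict.items()}


def iter_shards(paths: List[str], prefetch: int = 3) -> Generator[Dict[str, torch.Tensor], None, None]:
    # Submit all shard reads to a small thread pool and yield each shard as soon as
    # it finishes loading; the remaining reads keep running in the background.
    with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
        futures = [executor.submit(load_shard, p) for p in paths]
        for future in concurrent.futures.as_completed(futures):
            yield future.result()


# Consumer: pinning the current shard overlaps with loading of the next ones.
# for shard in iter_shards(checkpoint_files):
#     shard = pin(shard)
#     ...  # hand the pinned shard to the model/optimizer loader

This is the shape of the new load_state_dict_shards helper added to colossalai/checkpoint_io/utils.py in this commit; the plugins then pin and consume each yielded shard.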
colossalai/booster/plugin/gemini_plugin.py: 11 changes (4 additions & 7 deletions)
@@ -20,7 +20,7 @@
     create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    load_shard_state_dict,
+    load_state_dict_shards,
     save_config_file,
     save_state_dict,
     save_state_dict_shards,
@@ -29,7 +29,6 @@
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 
@@ -350,11 +349,9 @@ def load_sharded_optimizer(
 
         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file)
-            else:
-                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in load_state_dict_shards(
+            checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
+        ):
             if not low_cpu_mem_mode:
                 state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
             optimizer.load_param_states(state_dict_shard)
colossalai/booster/plugin/low_level_zero_plugin.py: 10 changes (2 additions & 8 deletions)
@@ -24,8 +24,8 @@
     get_optimizer_base_filenames,
     get_shard_filename,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -276,13 +276,7 @@ def load_sharded_optimizer(
 
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                from colossalai.utils.safetensors import load_flat
-
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():
colossalai/booster/plugin/torch_fsdp_plugin.py: 10 changes (3 additions & 7 deletions)
@@ -255,8 +255,8 @@ def load_sharded_model(
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
         fsdp_state_dict = {}
-        for shard_file in checkpoint_files:
-            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+        for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
+            fsdp_state_dict.update(state_dict)
 
         with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
             model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
@@ -388,11 +388,7 @@ def load_sharded_optimizer(
         # Load param
         fsdp_optim_state = {}
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file, seperator=".")
-            else:
-                state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
             fsdp_optim_state.update(state_dict_shard)
 
         fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
colossalai/checkpoint_io/general_checkpoint_io.py: 11 changes (3 additions & 8 deletions)
@@ -18,9 +18,9 @@
     get_optimizer_base_filenames,
     is_safetensors_available,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_config_file,
     save_param_groups,
@@ -94,11 +94,7 @@ def load_sharded_optimizer(
 
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)
@@ -295,8 +291,7 @@ def load_sharded_model(
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []
 
-        for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+        for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
colossalai/checkpoint_io/utils.py: 36 changes (34 additions & 2 deletions)
@@ -6,7 +6,7 @@
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,7 +21,7 @@
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import _flatten_optim_state_dict
+from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -972,3 +972,35 @@ def create_pinned_state_dict(
                 idx = future_to_idx[future]
                 elems[idx] = future.result()
     return tree_unflatten(elems, spec)
+
+
+def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
+    if is_optim:
+        if path.endswith(".safetensors"):
+            state_dict = load_flat(path)
+        else:
+            state_dict = load_shard_state_dict(Path(path), use_safetensors=False)
+    else:
+        state_dict = load_shard_state_dict(Path(path), use_safetensors)
+    return state_dict
+
+
+def load_state_dict_shards(
+    checkpoint_files: List[str],
+    is_optim: bool,
+    use_safetensors: bool,
+    low_cpu_mem_mode: bool = True,
+    prefetch: int = 3,
+) -> Generator[dict, None, None]:
+    if low_cpu_mem_mode:
+        for shard_file in checkpoint_files:
+            state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
+            yield state_dict
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
+            futures = []
+            for shard_file in checkpoint_files:
+                future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
+                futures.append(future)
+            for future in concurrent.futures.as_completed(futures):
+                yield future.result()
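As a usage sketch of the new generator, mirroring the plugin call sites above (the import path and the surrounding optimizer, checkpoint_files, and id_map arguments are assumptions based on the plugin code shown in this diff):

from colossalai.checkpoint_io.utils import (
    create_pinned_state_dict,
    load_state_dict_shards,
    load_states_into_optimizer,
)


def load_optimizer_shards(optimizer, checkpoint_files, id_map, num_threads=4):
    # Hypothetical wrapper showing the call pattern used above: with
    # low_cpu_mem_mode=False, up to `prefetch` shards are read on background
    # threads while the current shard is pinned and loaded here.
    for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode=False):
        state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
        load_states_into_optimizer(optimizer, state_dict, id_map)

Note that the non-low-memory path yields shards in completion order rather than file order, which appears fine for the callers above because each shard's states carry their own parameter keys.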
tests/conftest.py: 10 changes (10 additions & 0 deletions)
@@ -0,0 +1,10 @@
+import gc
+
+from colossalai.accelerator import get_accelerator
+
+
+def pytest_runtest_setup(item):
+    # called for running each test in 'a' directory
+    accelerator = get_accelerator()
+    accelerator.empty_cache()
+    gc.collect()
