-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[checkpointio]support distributed checkpoint io for model saving. #6181
base: feature/dist-ckp-io
Are you sure you want to change the base?
Conversation
307d8f1
to
2a15001
Compare
|
||
MODEL_META_PREFIX = "pytorch_model-meta-dist-" | ||
MODEL_WEIGHT_PREFIX = "pytorch_model-dist-" | ||
MODEL_SHARD_SUUFIX = ".index.json" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SHARD_META_SUFFIX?
colossalai/checkpoint_io/__init__.py
Outdated
@@ -10,4 +11,5 @@ | |||
"GeneralCheckpointIO", | |||
"HybridParallelCheckpointIO", | |||
"MoECheckpointIO", | |||
"DistributedCheckpointIO", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should not be an independent checkpoint io class. It should provide some utils functions for each current checkpoint io class.
hi all, take a look at this please. This bug is quite annoying for me. |
e8659ea
to
51c208c
Compare
ok |
e77d1e3
to
e3f9de3
Compare
for more information, see https://pre-commit.ci
return destination | ||
|
||
|
||
def load_state_dict_into_dist_model( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this function for? Is it for loading whole state dict? Default model.load_state_dict()
has already implemented this feature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The parallelmodule will perform the gather tensor operation.
tp_partition_dim = search_tp_partition_dim( | ||
current_shape=param.shape, original_shape=original_shape, tp_size=tp_size | ||
) | ||
model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not use list directly?
def create_model_metadata( | ||
model: nn.Module, | ||
prefix: str = "", | ||
tp_size=None, | ||
tp_rank=None, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that this function is only intended for TP. What about Gemini? If it's only designed for TP, then move it to hybrid parallel checkpoint io file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DP support can be added in the future.
if key not in covered_shards or rank not in covered_shards[key]: | ||
continue | ||
if dtype == None: | ||
dtype = weight.dtype | ||
covered_shards[key][rank]["weight"] = weight | ||
state_dict = {} | ||
for key, shards in covered_shards.items(): | ||
state = assemble_tensor_from_shards_partial( | ||
shards, model_metadata[key]["offsets"], model_metadata[key]["lengths"], dtype=dtype | ||
) | ||
state_dict[key] = state | ||
|
||
if not low_cpu_mem_mode: | ||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) | ||
|
||
load_state_dict_into_dist_model(model=model, state_dict=state_dict) | ||
|
||
# Update master params if mixed-precision training is enabled. | ||
model_before_wrapping.update_master_params() | ||
|
||
|
||
def save_dist_sharded_model( | ||
model: ModelWrapper, | ||
model_metadata: Dict, | ||
checkpoint: str, | ||
prefix: Optional[str] = None, | ||
size_per_shard: int = 1024, | ||
use_safetensors: bool = False, | ||
use_async: bool = False, | ||
dist_id: int = 0, | ||
pinned_state_dicts=None, | ||
) -> None: | ||
""" | ||
Save sharded model checkpoint under the given checkpointing path. | ||
The following files will be created under the path: | ||
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. | ||
- Multiple files that store state tensors of models. | ||
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin". | ||
If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin" | ||
|
||
|
||
Args: | ||
model (nn.Module): Model on local device to be saved. | ||
checkpoint (str): Checkpointing path which should be a directory path. | ||
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. | ||
prefix (str, optional): Perfix of file to save. Defaults to None. | ||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. | ||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. | ||
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. | ||
""" | ||
|
||
model = model.unwrap() | ||
|
||
if os.path.isfile(checkpoint): | ||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") | ||
return | ||
|
||
Path(checkpoint).mkdir(parents=True, exist_ok=True) | ||
# Devices along the same dp_group share the same copies of model. | ||
# So only let the device with dp_rank == 0 and sp_rank == 0 save the model. | ||
|
||
if use_async: | ||
if id(model) not in pinned_state_dicts: | ||
pinned_state_dicts[id(model)] = {} | ||
pinned_state_dicts = pinned_state_dicts[id(model)] | ||
else: | ||
pinned_state_dicts = None | ||
state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts) | ||
weights_name, _ = get_model_base_filenames(prefix, use_safetensors) | ||
index_file = CheckpointIndexFile(checkpoint) | ||
|
||
# Manage filenames of sharded weights and index file for each pipeline stage. | ||
weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin") | ||
weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors") | ||
metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}") | ||
async_writers = [] | ||
if use_async: | ||
total_size, writers = async_save_state_dict_shards( | ||
sharded_state_dict=state_dict_shard, | ||
checkpoint=checkpoint, | ||
index_file=index_file, | ||
base_filename=weights_name, | ||
is_master=True, | ||
state_preprocess=False, | ||
) | ||
async_writers.extend(writers) | ||
else: | ||
total_size = save_state_dict_shards( | ||
sharded_state_dict=state_dict_shard, | ||
checkpoint=checkpoint, | ||
index_file=index_file, | ||
base_filename=weights_name, | ||
is_master=True, | ||
use_safetensors=use_safetensors, | ||
use_pp_format=True, | ||
) | ||
for k, _ in model_metadata.items(): | ||
model_metadata[k]["file"] = index_file.get_checkpoint_file(k) | ||
|
||
save_metadata(model_metadata, metadata_file, total_size=total_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's only designed for hybrid parallel, then move it to hybrid parallel checkpoint io file. AND too many redundant codes. Please try to reuse some common code snippets.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The format of metadata is uniform, and save_metadata is generic.
DON'T merge to main. Create a new feature branch on the org repo and merge to it. |
6a8a917
to
c5b0882
Compare
if isinstance(v, torch.Tensor): | ||
v = v.tolist() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This case won't occur now, right?
def create_model_metadata( | ||
model: ModelWrapper, | ||
prefix: str = "", | ||
tp_size: int = None, | ||
tp_rank: int = None, | ||
zero_size: int = None, | ||
zero_rank: int = None, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is not general and should be provided by each checkpoint io class
9984a64
to
8e6902c
Compare
📌 Checklist before creating the PR
[doc/gemini/tensor/...]: A concise description
pip install pre-commit && pre-commit install
🚨 Issue number
📝 What does this PR do?
💥 Checklist before requesting a review
⭐️ Do you enjoy contributing to Colossal-AI?
Tell us more if you don't enjoy contributing to Colossal-AI.