added checkpointing to support LLMs #114
Conversation
This PR addresses #88.
@zhenghh04 I can confirm with the profiler that this change to checkpointing accurately represents the checkpointing in DeepSpeed. Additionally, indexed_binary and mmap_indexed_binary are the two modes used in Megatron-DeepSpeed for data reading, and the calls are accurate. You can merge this if it looks good to you.
@hariharan-devarajan could you take a look at the conflict, and make sure that the checkpointing writes are performed with the storage API functions, which also apply to S3 storage.
```yaml
train:
  epochs: 1
  computation_time: 0.064296
```
Is the computation time from running the real workload?
This is based on the configuration used in the PR #88
We have to validate this after merging the PR.
Agreed
```python
if self.model_state:
    fname = os.path.join(self.checkpoint_folder, f"model-{epoch}-{step_number}-{my_rank}.pt")
    with open(fname, "wb") as f:
        torch.save(self.model_state, f)
if self.optimization_state:
    fname = os.path.join(self.checkpoint_folder, f"optimizer-{epoch}-{step_number}-{my_rank}.pt")
    with open(fname, "wb") as f:
        torch.save(self.optimization_state, f)

if self.layer_state and self.args.num_layers > 0:
    for layer in range(self.args.num_layers):
        fname = os.path.join(self.checkpoint_folder, f"layer-{layer}-{epoch}-{step_number}-{my_rank}.pt")
        with open(fname, "wb") as f:
            torch.save(self.layer_state, f)
```
Make sure the conflict is resolved.
Fixed
@zhenghh04 The original code uses TensorFlow and PyTorch APIs to save. This is needed because we are storing complex tensors. How would this work with S3? I think we need an fsspec-type interface for abstracting storage, not manual abstraction. Thoughts?
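One way to make framework save calls work against both POSIX and S3 paths is to route the file open through a storage-aware layer such as fsspec (e.g. `fsspec.open("s3://bucket/model.pt", "wb")`), since `torch.save` accepts any writable file-like object. The sketch below only illustrates that shape: `storage_open` and `save_state` are hypothetical names, and `pickle` stands in for `torch.save` so the example runs without PyTorch installed.

```python
import os
import pickle
import tempfile

def storage_open(uri, mode="wb"):
    # Hypothetical dispatch point: for a real S3 path you would return a
    # file-like object from fsspec.open(uri, mode).open(); here we only
    # handle local paths as a stand-in.
    if uri.startswith("s3://"):
        raise NotImplementedError("install fsspec/s3fs for object storage")
    return open(uri, mode)

def save_state(state, uri):
    # torch.save(state, f) would go here; pickle stands in for the sketch.
    with storage_open(uri, "wb") as f:
        pickle.dump(state, f)

fname = os.path.join(tempfile.mkdtemp(), "model-0-0-0.pt")
save_state({"a": [1, 2, 3]}, fname)
with storage_open(fname, "rb") as f:
    print(pickle.load(f))  # → {'a': [1, 2, 3]}
```

Because the checkpoint code only ever sees a file-like object, swapping local disk for object storage becomes a change in the URI rather than in the save logic.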
Thank you for all the implementation work. The feature implemented here is very useful. Please address the issues raised above.
@@ -17,6 +17,16 @@
```python
from enum import Enum

class CheckpointType(Enum):
```
Maybe rename this to IOType instead of CheckpointType?
CheckpointType sounds like it should distinguish different kinds of checkpoints; we could use it, for example, to checkpoint only the model, only the optimization state, etc.
How about CheckpointIOType? Plain IOType might be confused with reading.
I named it CheckpointLocationType, with values RANK_ZERO and ALL_RANKS.
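Based on the thread, the resulting enum could be sketched as follows (the member names come from the comment; the values and module layout are assumptions):

```python
from enum import Enum

class CheckpointLocationType(Enum):
    RANK_ZERO = "rank_zero"   # rank 0 aggregates and writes all data
    ALL_RANKS = "all_ranks"   # every rank writes its own shard

print(CheckpointLocationType.RANK_ZERO.name)  # → RANK_ZERO
```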
```yaml
train:
  epochs: 1
  computation_time: 0.064296
```
We have to validate this after merging the PR.
""" | ||
super().generate() | ||
np.random.seed(10) | ||
GB=1024**3 |
Please change this to GB=1073741824.
FIXED
```python
sample_size = dim1 * dim2
total_size = sample_size * self.num_samples
write_size = total_size
MEMORY_SIZE = 2 * GB
```
Should we allow the user to configure this via an environment variable, with a default value of 2 GB?
Under dataset, I will add a configuration called generation_buffer_size. What do you think?
Done
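A minimal sketch of how such a setting could be resolved, assuming a hypothetical `dataset.generation_buffer_size` config key and a hypothetical `DLIO_GENERATION_BUFFER_SIZE` environment-variable override, with the 2 GB default discussed above:

```python
import os

GB = 1073741824  # absolute value, per the review comment above

def generation_buffer_size(config):
    # Precedence (assumed): environment variable, then config key, then default.
    default = 2 * GB
    env = os.environ.get("DLIO_GENERATION_BUFFER_SIZE")
    if env is not None:
        return int(env)
    return int(config.get("dataset", {}).get("generation_buffer_size", default))

print(generation_buffer_size({}))  # → 2147483648
```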
```python
if self.args.checkpoint_type == CheckpointType.COLLECTIVE:
    rank_to_checkpoint = 0
if rank_to_checkpoint == self.args.my_rank:
    num_ranks = 1
    if self.args.checkpoint_type == CheckpointType.COLLECTIVE:
        num_ranks = self.args.comm_size
```
What does COLLECTIVE mean here? Is every rank writing data?
Lines 62-63 and lines 58-59 are inconsistent with each other.
In the context of checkpointing, collective basically means that all data is collected by rank zero and written. I am open to a better word to describe it. Maybe Aggregated and Per-Process?
Called it RANK_ZERO.
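The semantics agreed in this exchange could be sketched as follows (`checkpoint_plan` is a hypothetical helper; the mode names are the ones from the thread): under RANK_ZERO, only rank 0 writes, and it writes the data of all ranks; under ALL_RANKS, every rank writes its own per-rank share.

```python
def checkpoint_plan(location, my_rank, comm_size):
    # Returns (does this rank write?, multiplier on the per-rank data size).
    if location == "RANK_ZERO":
        writes = (my_rank == 0)
        num_ranks = comm_size  # rank 0 writes comm_size times the per-rank data
    else:  # ALL_RANKS
        writes = True
        num_ranks = 1  # each rank writes only its own data
    return writes, num_ranks

print(checkpoint_plan("RANK_ZERO", 0, 8))  # → (True, 8)
print(checkpoint_plan("RANK_ZERO", 3, 8))  # → (False, 8)
print(checkpoint_plan("ALL_RANKS", 3, 8))  # → (True, 1)
```

Either way, the total bytes written per checkpoint are the same; the modes differ only in which ranks perform the writes.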
```python
if self.args.checkpoint_type == CheckpointType.COLLECTIVE:
    num_ranks = self.args.comm_size
if self.args.model_size > 0:
    self.model_state = {"a": self._get_tensor(self.args.model_size * num_ranks)}
```
model_size is the size of the model, right?
It is confusing there to have model_size * num_ranks.
model_size is the size of the model per GPU.
We could instead define it as the absolute model size of the application, in which case:
- for the per-GPU case we would need to divide it;
- whereas if it stays per GPU, we have to multiply it for the collective case.
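A worked example of the two conventions, with illustrative numbers (4 GB per GPU, 8 ranks):

```python
GB = 1073741824
model_size = 4 * GB   # per-GPU model size (illustrative)
comm_size = 8

# If model_size is per GPU, the collective (rank-zero, aggregated) write
# must be multiplied by the number of ranks:
aggregated_write = model_size * comm_size
print(aggregated_write // GB)  # → 32

# If model_size were instead the absolute application model size,
# the per-GPU share would be obtained by dividing:
per_gpu_share = aggregated_write // comm_size
print(per_gpu_share // GB)  # → 4
```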
Explained correctly in the docs.
Changed GB to an absolute value.
This PR looks good now.
But we need to validate the DLRM and Megatron-DeepSpeed config files. I'll create two issues to keep track of this.
Changes to support Microsoft's Megatron Deepspeed.