forked from NVIDIA/Megatron-LM
Add support for signal-based dynamic checkpointing
Commit dbe6c72 (parent d6380fd): 6 changed files with 115 additions and 1 deletion.
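This commit adds a small utility module for signal-based checkpointing: a DistributedSignalHandler context manager that records delivery of a chosen signal (SIGTERM by default) on each rank, plus an all_gather_item helper that lets every rank learn whether any rank was signaled, so a distributed training job can checkpoint and shut down cleanly when preempted.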
New file (@@ -0,0 +1,81 @@):
import signal

import torch
def get_world_size():
    """Return the distributed world size, or 1 if not running distributed."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    return world_size
def get_device(local_rank=None):
    """Pick the device that collectives should use for the current backend."""
    backend = torch.distributed.get_backend()
    if backend == 'nccl':
        if local_rank is None:
            device = torch.device('cuda')
        else:
            device = torch.device(f'cuda:{local_rank}')
    elif backend == 'gloo':
        device = torch.device('cpu')
    else:
        raise RuntimeError(f'Unsupported distributed backend: {backend}')
    return device
def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
    """All-gather a single Python scalar, returning one value per rank."""
    if not torch.distributed.is_available() or \
       not torch.distributed.is_initialized():
        return [item]

    device = get_device(local_rank)

    if group is not None:
        group_size = group.size()
    else:
        group_size = get_world_size()

    tensor = torch.tensor([item], device=device, dtype=dtype)
    output_tensors = [
        torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
        for _ in range(group_size)
    ]
    torch.distributed.all_gather(output_tensors, tensor, group, async_op)
    output = [elem.item() for elem in output_tensors]
    return output
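# Illustration (not part of the original diff): with two ranks where each
# contributes its own rank id, all_gather_item(rank, dtype=torch.int64)
# returns [0, 1] on both ranks; if torch.distributed is not initialized,
# the call simply returns [rank].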
class DistributedSignalHandler:
    def __init__(self, sig=signal.SIGTERM):
        self.sig = sig

    def signals_received(self):
        # Gather every rank's flag so all ranks agree on whether any
        # process in the job has been signaled.
        all_received = all_gather_item(
            self._signal_received, dtype=torch.int32
        )
        return all_received

    def __enter__(self):
        self._signal_received = False
        self.released = False
        self.original_handler = signal.getsignal(self.sig)

        def handler(signum, frame):
            self._signal_received = True

        signal.signal(self.sig, handler)

        return self

    def __exit__(self, type, value, tb):
        self.release()

    def release(self):
        # Restore the original handler; return False if already released.
        if self.released:
            return False

        signal.signal(self.sig, self.original_handler)
        self.released = True
        return True
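The class is a context manager: entering installs a handler that only records the signal, and exiting restores the original handler. A minimal usage sketch follows (not part of the diff; train_step and save_checkpoint are hypothetical no-op stand-ins for a real training step and checkpoint routine):

import signal

def train_step():
    pass  # hypothetical stand-in for one training step

def save_checkpoint(iteration):
    pass  # hypothetical stand-in for the real checkpointing routine

def train(num_iterations):
    with DistributedSignalHandler(signal.SIGTERM) as handler:
        for iteration in range(num_iterations):
            train_step()
            # All ranks see a SIGTERM delivered to any one of them
            # (e.g. by a cluster scheduler preempting the job), so they
            # checkpoint and stop together instead of desynchronizing.
            if any(handler.signals_received()):
                save_checkpoint(iteration)
                break

Because signals_received() is a collective over all ranks, every process reaches the same decision on the same iteration, which is what makes the checkpoint-and-exit dynamic rather than tied to a fixed save interval.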