Skip to content

Commit

Permalink
Add support for signal-based dynamic checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
szmigacz authored and jaredcasper committed Nov 24, 2021
1 parent d6380fd commit dbe6c72
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 1 deletion.
1 change: 1 addition & 0 deletions megatron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
Expand Down
6 changes: 6 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ def _add_logging_args(parser):
group.add_argument('--log-memory-to-tensorboard',
action='store_true',
help='Enable memory logging to tensorboard.')
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')

return parser

Expand Down Expand Up @@ -472,6 +475,9 @@ def _add_training_args(parser):
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--exit-signal-handler', action='store_true',
help='Dynamically save the checkpoint and shutdown the '
'training if SIGTERM is received')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
Expand Down
81 changes: 81 additions & 0 deletions megatron/dist_signal_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import signal

import torch


def get_world_size():
    """Return the distributed world size, or 1 for a non-distributed run."""
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def get_device(local_rank=None):
    """Return the torch.device matching the active distributed backend.

    Args:
        local_rank: CUDA device index to target under the 'nccl' backend;
            if None, the current default CUDA device is used.

    Returns:
        torch.device: a CUDA device for 'nccl', CPU for 'gloo'.

    Raises:
        RuntimeError: if the backend is neither 'nccl' nor 'gloo'.
    """
    backend = torch.distributed.get_backend()
    if backend == 'nccl':
        if local_rank is None:
            device = torch.device('cuda')
        else:
            device = torch.device(f'cuda:{local_rank}')
    elif backend == 'gloo':
        device = torch.device('cpu')
    else:
        # Fix: the original raised a bare RuntimeError with no message;
        # name the backend so the failure is diagnosable.
        raise RuntimeError(f'Unsupported distributed backend: {backend}')
    return device


def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
    """All-gather a single scalar from every rank.

    Args:
        item: scalar value contributed by this rank.
        dtype: torch dtype used for the wire tensor (e.g. torch.int32).
        group: optional process group; defaults to the world group.
        async_op: launch the collective asynchronously (completion is
            awaited internally before results are read).
        local_rank: CUDA device index forwarded to get_device().

    Returns:
        list: one Python scalar per rank (just [item] when torch.distributed
        is unavailable or not initialized).
    """
    if not torch.distributed.is_available() or \
            not torch.distributed.is_initialized():
        # Non-distributed run: the "gather" is trivially this rank's item.
        return [item]

    device = get_device(local_rank)

    if group is not None:
        group_size = group.size()
    else:
        group_size = get_world_size()

    tensor = torch.tensor([item], device=device, dtype=dtype)
    output_tensors = [
        torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
        for _ in range(group_size)
    ]
    # Bug fix: the original discarded the async work handle, so with
    # async_op=True the .item() reads below could run before the collective
    # finished. Wait for completion before reading the output tensors.
    work = torch.distributed.all_gather(output_tensors, tensor, group,
                                        async_op)
    if async_op and work is not None:
        work.wait()
    output = [elem.item() for elem in output_tensors]
    return output


class DistributedSignalHandler:
    """Context manager that records receipt of a signal (SIGTERM by default).

    On __enter__ it installs a handler that only sets a flag; the actual
    shutdown decision is left to the caller, which polls
    signals_received() — an all-gather of every rank's flag — so all ranks
    can agree before checkpointing and exiting.
    """

    def __init__(self, sig=signal.SIGTERM):
        self.sig = sig
        # Fix: initialize state here so signals_received() and release()
        # are safe even if called before __enter__ (the original raised
        # AttributeError in that case).
        self._signal_received = False
        self.released = False
        self.original_handler = None

    def signals_received(self):
        """Return a list with one entry per rank: whether that rank got the signal."""
        all_received = all_gather_item(
            self._signal_received, dtype=torch.int32
        )
        return all_received

    def __enter__(self):
        self._signal_received = False
        self.released = False
        self.original_handler = signal.getsignal(self.sig)

        def handler(signum, frame):
            # Record only; never raise or exit from inside a signal handler.
            self._signal_received = True

        signal.signal(self.sig, handler)

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def release(self):
        """Restore the previous handler; return False if already released."""
        if self.released:
            return False

        if self.original_handler is not None:
            signal.signal(self.sig, self.original_handler)
        self.released = True
        return True
13 changes: 13 additions & 0 deletions megatron/global_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import torch

from megatron import dist_signal_handler
from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
from .microbatches import build_num_microbatches_calculator
Expand All @@ -31,6 +32,7 @@
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
_GLOBAL_SIGNAL_HANDLER = None


def get_args():
Expand Down Expand Up @@ -75,6 +77,14 @@ def get_timers():
_ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
return _GLOBAL_TIMERS

def get_signal_handler():
    """Return the global DistributedSignalHandler.

    Raises an error (via _ensure_var_is_initialized) if
    _set_signal_handler() has not been called yet.
    """
    _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
    return _GLOBAL_SIGNAL_HANDLER

def _set_signal_handler():
    """Create and install the global signal handler (at most once).

    Called from set_global_variables() when --exit-signal-handler is set.
    __enter__ is invoked directly (no `with`): the handler is installed
    immediately and deliberately stays active for the process lifetime.
    """
    global _GLOBAL_SIGNAL_HANDLER
    _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
    _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()

def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
Expand All @@ -89,6 +99,9 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
_set_adlr_autoresume(args)
_set_timers()

if args.exit_signal_handler:
_set_signal_handler()


def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
Expand Down
2 changes: 1 addition & 1 deletion megatron/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _initialize_distributed():
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
timeout=timedelta(days=7))
timeout=timedelta(minutes=10))

# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
Expand Down
13 changes: 13 additions & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_signal_handler
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
Expand Down Expand Up @@ -544,6 +545,10 @@ def add_to_logging(name):
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
if args.log_world_size_to_tensorboard:
writer.add_scalar('world-size', args.world_size, iteration)
writer.add_scalar('world-size vs samples', args.world_size,
args.consumed_train_samples)
if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
Expand Down Expand Up @@ -698,6 +703,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,

# Checkpointing
saved_checkpoint = False
if args.exit_signal_handler:
signal_handler = get_signal_handler()
if any(signal_handler.signals_received()):
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
print_datetime('exiting program after receiving SIGTERM.')
sys.exit()

if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer,
Expand Down

0 comments on commit dbe6c72

Please sign in to comment.