From ab911129097dfc5c944918efdac7dd0b2d96aca2 Mon Sep 17 00:00:00 2001
From: Han Zhu <1106766460@qq.com>
Date: Thu, 9 Jan 2025 15:05:38 +0800
Subject: [PATCH] Improve infinity-check (#1862)

1. Attach the inf-check hooks if the grad scale is getting too small.
2. Add try-catch to avoid OOM in the inf-check hooks.
3. Set warmup_start=0.1 to reduce chances of divergence
---
 egs/librispeech/ASR/zipformer/train.py | 25 ++++++++++++++++++-------
 icefall/hooks.py                       | 26 ++++++++++++++++++--------
 2 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index c074c32ec7..f8864d58b7 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -1165,23 +1165,34 @@ def save_bad_model(suffix: str = ""):
                     rank=rank,
                 )

-        if batch_idx % 100 == 0 and params.use_autocast:
-            # If the grad scale was less than 1, try increasing it. The _growth_interval
-            # of the grad scaler is configurable, but we can't configure it to have different
-            # behavior depending on the current grad scale.
+        if params.use_autocast:
             cur_grad_scale = scaler._scale.item()

-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
-                scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
                     save_bad_model(suffix="-first-warning")
                     saved_bad_model = True
+                if not params.inf_check:
+                    register_inf_check_hooks(model)
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
+
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
                     raise_grad_scale_is_too_small_error(cur_grad_scale)

+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            if (
+                batch_idx % 25 == 0
+                and cur_grad_scale < 2.0
+                or batch_idx % 100 == 0
+                and cur_grad_scale < 8.0
+                or batch_idx % 400 == 0
+                and cur_grad_scale < 32.0
+            ):
+                scaler.update(cur_grad_scale * 2.0)
+
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
             cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
@@ -1335,7 +1346,7 @@ def run(rank, world_size, args):
         clipping_scale=2.0,
     )

-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1)

     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
diff --git a/icefall/hooks.py b/icefall/hooks.py
index 83f2750faf..85583acbe2 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -39,24 +39,34 @@ def register_inf_check_hooks(model: nn.Module) -> None:
         # default param _name is a way to capture the current value of the variable "name".
         def forward_hook(_module, _input, _output, _name=name):
             if isinstance(_output, Tensor):
-                if not torch.isfinite(_output.to(torch.float32).sum()):
-                    logging.warning(f"The sum of {_name}.output is not finite")
+                try:
+                    if not torch.isfinite(_output.to(torch.float32).sum()):
+                        logging.warning(f"The sum of {_name}.output is not finite")
+                except RuntimeError:  # e.g. CUDA out of memory
+                    pass
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if isinstance(o, tuple):
                         o = o[0]
                     if not isinstance(o, Tensor):
                         continue
-                    if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(f"The sum of {_name}.output[{i}] is not finite")
+                    try:
+                        if not torch.isfinite(o.to(torch.float32).sum()):
+                            logging.warning(
+                                f"The sum of {_name}.output[{i}] is not finite"
+                            )
+                    except RuntimeError:  # e.g. CUDA out of memory
+                        pass

         # default param _name is a way to capture the current value of the variable "name".
         def backward_hook(_module, _input, _output, _name=name):
             if isinstance(_output, Tensor):
-                if not torch.isfinite(_output.to(torch.float32).sum()):
-                    logging.warning(
-                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
-                    )
+                try:
+                    if not torch.isfinite(_output.to(torch.float32).sum()):
+                        logging.warning(f"The sum of {_name}.grad is not finite")
+                except RuntimeError:  # e.g. CUDA out of memory
+                    pass
+
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if isinstance(o, tuple):
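Usage note (not part of the patch): the minimal sketch below shows what the inf-check hooks report once register_inf_check_hooks() is attached, as the patched train.py now does when the grad scale drops below 0.01. The ToyModel module and the injected NaN are made up for illustration; only register_inf_check_hooks is assumed to come from icefall, and the exact wording of the logged warning may differ.

import logging

import torch
import torch.nn as nn

from icefall.hooks import register_inf_check_hooks

logging.basicConfig(level=logging.WARNING)


class ToyModel(nn.Module):
    """Hypothetical module that emits a NaN so the hooks have something to report."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        y = self.linear(x)
        y = y + float("nan")  # inject a non-finite value into the output
        return y


model = ToyModel()
register_inf_check_hooks(model)  # same call the patched train.py makes

model(torch.randn(2, 4))
# Expected: the forward hook logs a warning along the lines of
#   "The sum of <top-level>.output is not finite"
# and, with this patch, a RuntimeError raised inside the check (e.g. CUDA OOM)
# is swallowed instead of crashing training.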