Commit: skip NaN batches

Guitaricet committed Aug 14, 2023
1 parent 3784fb4 commit 96aa492
Showing 2 changed files with 20 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ checkpoints
 wandb
 ignore
 experimental_data
+fine_tuning_results*
 log
 notebooks/*.pdf
 notebooks/*.png
28 changes: 19 additions & 9 deletions torchrun_main.py
@@ -719,7 +719,8 @@ def main(args):
 # global steps and others are defined above
 update_time = time.time()
 local_step = 0  # when warmed_up_model is used, local_step != global_step
-loss_info = torch.tensor([0.0, 0.0], device=device)  # loss, n_batches
+loss_info = torch.tensor([0.0, 0.0, 0.0], device=device)  # loss, n_batches, n_NaNs
+n_skipped_batches = 0

 # ##############################
 # TRAINING LOOP
@@ -747,10 +748,10 @@ def main(args):
 tokens_seen += batch["input_ids"].numel() * world_size

 loss = model(**batch, labels=batch["input_ids"]).loss
-assert not torch.isnan(loss), "Loss is nan"

 loss_info[0] += loss.detach()
 loss_info[1] += 1
+loss_info[2] += torch.isnan(loss).float()

 if global_step == 0 and global_rank == 0:
     # log loss without any optimization
@@ -768,12 +769,26 @@ def main(args):
 if args.clip_grad_norm > 0:
     torch.nn.utils.clip_grad_norm_(trainable_params, args.clip_grad_norm, error_if_nonfinite=True)

-optimizer.step()
-scheduler.step()
+dist.all_reduce(loss_info, op=dist.ReduceOp.SUM)
+_loss = loss_info[0] / loss_info[1]  # loss to log in wandb below
+
+if loss_info[2] == 0:  # no NaNs, update model
+    optimizer.step()
+    scheduler.step()
+else:
+    logger.error(f"NaN detected in loss_info, {_loss=}, skipping update")
+    n_skipped_batches += 1
+
+    if n_skipped_batches > 0.05 * args.num_training_steps:
+        logger.error("More than 5% of batches skipped due to NaNs, stopping training.")
+        break
+
 optimizer.zero_grad()
 update_step += 1
 update_time = time.time() - update_time

+loss_info = torch.zeros_like(loss_info)
+
 if local_step > args.gradient_accumulation and update_step % args.save_every == 0:
     current_model_directory = f"{args.save_dir}/model_{update_step}"
     logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
@@ -852,11 +867,6 @@ def main(args):
 tokens_seen_before = tokens_seen
 batches_in_update = args.gradient_accumulation * world_size

-# log loss without any optimization
-dist.all_reduce(loss_info, op=dist.ReduceOp.SUM)
-_loss = loss_info[0] / loss_info[1]
-loss_info = torch.tensor([0.0, 0.0], device=device)
-
 if global_rank == 0:
     wandb.log({
         "loss": _loss,
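For context, here is a minimal standalone sketch of the pattern this commit implements: count NaN losses in a shared tensor, all-reduce it so every rank agrees, apply the optimizer step only when no rank saw a NaN, and abort once skipped steps exceed 5% of the training budget. The function shape and the names `train`, `batches`, `num_training_steps`, and `device` are hypothetical stand-ins, not the repository's actual setup, and gradient accumulation is omitted.

import torch
import torch.distributed as dist


def train(model, optimizer, scheduler, batches, num_training_steps, device):
    """Skip optimizer updates whenever any rank produced a NaN loss.

    Assumes torch.distributed is already initialized and that `model`
    returns an object with a `.loss` attribute, as Hugging Face models do.
    """
    loss_info = torch.tensor([0.0, 0.0, 0.0], device=device)  # sum of losses, n_batches, n_NaNs
    n_skipped_steps = 0

    for batch, _ in zip(batches, range(num_training_steps)):
        loss = model(**batch, labels=batch["input_ids"]).loss
        loss_info[0] += loss.detach()
        loss_info[1] += 1
        loss_info[2] += torch.isnan(loss).float()
        loss.backward()  # NaN gradients are harmless as long as they are never applied

        # Sum the counters across ranks so every process agrees on whether
        # a NaN occurred anywhere: all ranks step, or all ranks skip.
        dist.all_reduce(loss_info, op=dist.ReduceOp.SUM)

        if loss_info[2] == 0:  # no NaNs on any rank: safe to update
            optimizer.step()
            scheduler.step()
        else:
            n_skipped_steps += 1
            # Tolerate occasional blow-ups, but give up if they dominate.
            if n_skipped_steps > 0.05 * num_training_steps:
                raise RuntimeError("more than 5% of steps skipped due to NaN losses")

        optimizer.zero_grad()  # discard gradients whether or not we stepped
        loss_info = torch.zeros_like(loss_info)

The all_reduce is what makes this safe under distributed training: the per-rank assert that this commit removes would crash only the worker that hit the NaN, leaving the other ranks blocked in collective ops, whereas a shared counter lets every rank take the same branch. Note that the mean loss (`loss_info[0] / loss_info[1]`) is itself NaN on a skipped step, which is what the commit's error message prints.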
