From 65871905e2ad3036e54f58b6a80bf34baf80c2b4 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 15 Jan 2025 17:17:05 -0800 Subject: [PATCH] Rearrange the logging a bit - log everything in one place after each epoch --- stanza/models/constituency/parser_training.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index 5c53d61eb..50ac36b92 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -425,10 +425,22 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d trainer.save(args['save_name'], save_optimizer=False) if epoch_stats.nans > 0: tlogger.warning("Had to ignore %d batches with NaN", epoch_stats.nans) + # TODO: refactor the logging? + total_correct = sum(v for _, v in epoch_stats.transitions_correct.items()) + correct_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_correct[x]) for x in epoch_stats.transitions_correct]) + tlogger.info("Transitions correct: %d\n %s", total_correct, correct_transitions_str) + total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items()) + incorrect_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_incorrect[x]) for x in epoch_stats.transitions_incorrect]) + tlogger.info("Transitions incorrect: %d\n %s", total_incorrect, incorrect_transitions_str) + if len(epoch_stats.repairs_used) > 0: + tlogger.info("Oracle repairs:\n %s", "\n ".join("%s (%s): %d" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common())) + if epoch_stats.fake_transitions_used > 0: + tlogger.info("Fake transitions used: %d", epoch_stats.fake_transitions_used) + stats_log_lines = [ "Epoch %d finished" % trainer.epochs_trained, - "Transitions correct: %s" % epoch_stats.transitions_correct, - "Transitions incorrect: %s" % epoch_stats.transitions_incorrect, + "Transitions correct: %d" % total_correct, + "Transitions incorrect: %s" % total_incorrect, "Total loss for epoch: %.5f" % epoch_stats.epoch_loss, "Dev score (%5d): %8f" % (trainer.epochs_trained, f1), "Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1) @@ -549,19 +561,6 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m optimizer.zero_grad() epoch_stats = epoch_stats + batch_stats - - # TODO: refactor the logging? - total_correct = sum(v for _, v in epoch_stats.transitions_correct.items()) - correct_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_correct[x]) for x in epoch_stats.transitions_correct]) - tlogger.info("Transitions correct: %d\n %s", total_correct, correct_transitions_str) - total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items()) - incorrect_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_incorrect[x]) for x in epoch_stats.transitions_incorrect]) - tlogger.info("Transitions incorrect: %d\n %s", total_incorrect, incorrect_transitions_str) - if len(epoch_stats.repairs_used) > 0: - tlogger.info("Oracle repairs:\n %s", "\n ".join("%s (%s): %d" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common())) - if epoch_stats.fake_transitions_used > 0: - tlogger.info("Fake transitions used: %d", epoch_stats.fake_transitions_used) - return epoch_stats def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args):