Skip to content

Commit

Permalink
Rearrange the logging a bit - log everything in one place after each …
Browse files Browse the repository at this point in the history
…epoch
  • Loading branch information
AngledLuffa committed Jan 16, 2025
1 parent 10099c1 commit 6587190
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,22 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
trainer.save(args['save_name'], save_optimizer=False)
if epoch_stats.nans > 0:
tlogger.warning("Had to ignore %d batches with NaN", epoch_stats.nans)
# TODO: refactor the logging?
total_correct = sum(v for _, v in epoch_stats.transitions_correct.items())
correct_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_correct[x]) for x in epoch_stats.transitions_correct])
tlogger.info("Transitions correct: %d\n %s", total_correct, correct_transitions_str)
total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items())
incorrect_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_incorrect[x]) for x in epoch_stats.transitions_incorrect])
tlogger.info("Transitions incorrect: %d\n %s", total_incorrect, incorrect_transitions_str)
if len(epoch_stats.repairs_used) > 0:
tlogger.info("Oracle repairs:\n %s", "\n ".join("%s (%s): %d" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common()))
if epoch_stats.fake_transitions_used > 0:
tlogger.info("Fake transitions used: %d", epoch_stats.fake_transitions_used)

stats_log_lines = [
"Epoch %d finished" % trainer.epochs_trained,
"Transitions correct: %s" % epoch_stats.transitions_correct,
"Transitions incorrect: %s" % epoch_stats.transitions_incorrect,
"Transitions correct: %d" % total_correct,
"Transitions incorrect: %s" % total_incorrect,
"Total loss for epoch: %.5f" % epoch_stats.epoch_loss,
"Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
"Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1)
Expand Down Expand Up @@ -549,19 +561,6 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m
optimizer.zero_grad()
epoch_stats = epoch_stats + batch_stats


# TODO: refactor the logging?
total_correct = sum(v for _, v in epoch_stats.transitions_correct.items())
correct_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_correct[x]) for x in epoch_stats.transitions_correct])
tlogger.info("Transitions correct: %d\n %s", total_correct, correct_transitions_str)
total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items())
incorrect_transitions_str = "\n ".join(["%s: %d" % (x, epoch_stats.transitions_incorrect[x]) for x in epoch_stats.transitions_incorrect])
tlogger.info("Transitions incorrect: %d\n %s", total_incorrect, incorrect_transitions_str)
if len(epoch_stats.repairs_used) > 0:
tlogger.info("Oracle repairs:\n %s", "\n ".join("%s (%s): %d" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common()))
if epoch_stats.fake_transitions_used > 0:
tlogger.info("Fake transitions used: %d", epoch_stats.fake_transitions_used)

return epoch_stats

def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args):
Expand Down

0 comments on commit 6587190

Please sign in to comment.