diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py
index 59bb6d090..86f904a36 100644
--- a/stanza/models/constituency/parser_training.py
+++ b/stanza/models/constituency/parser_training.py
@@ -447,6 +447,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
             "Epoch %d finished" % trainer.epochs_trained,
             "Transitions correct: %d" % total_correct,
             "Transitions incorrect: %d" % total_incorrect,
+            "Transition loss for epoch: %.5f" % epoch_stats.transition_loss,
+            "Similarity loss for epoch: %.5f" % epoch_stats.similarity_loss,
             "Total loss for epoch: %.5f" % epoch_stats.total_loss,
             "Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
             "Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1)
@@ -597,18 +599,20 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
         gold_results = model.analyze_trees([x.tree for x in training_batch], keep_output_layers=True)
         errors = [error_analysis_in_order.analyze_tree(result.gold, result.predictions[0].tree) for result in reparsed_results]

-        similarities_inputs = []
-        similarities_targets = []
+        similarity_inputs = []
+        similarity_targets = []
         for reparsed, gold, first_error in zip(reparsed_results, gold_results, errors):
             if reparsed.predictions[0].tree == reparsed.gold:
                 continue
             error_type, gold_index, pred_index = error_analysis_in_order.analyze_tree(reparsed.gold, reparsed.predictions[0].tree)
             if gold_index is None or pred_index is None:
                 continue
-            similarities_inputs.append(reparsed.output_layers[pred_index])
-            similarities_targets.append(gold.output_layers[gold_index])
-        if len(similarities_inputs) > 0:
-            similarity_loss = similarity_loss_function(torch.stack(similarities_inputs), torch.stack(similarities_targets))
+            similarity_inputs.append(reparsed.output_layers[pred_index])
+            similarity_targets.append(gold.output_layers[gold_index])
+        if len(similarity_inputs) > 0:
+            similarity_inputs = torch.stack(similarity_inputs)
+            similarity_targets = torch.stack(similarity_targets)
+            similarity_loss = similarity_loss_function(similarity_inputs, similarity_targets) * args['similarity_learning_rate']

         all_errors = []
         all_answers = []
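
For context, a minimal standalone sketch of the similarity-loss step this patch rearranges. The real `similarity_loss_function` is defined outside this hunk; a mean cosine-distance loss is assumed here purely for illustration, the toy tensors stand in for `reparsed.output_layers[pred_index]` / `gold.output_layers[gold_index]`, and the fixed 0.1 factor stands in for `args['similarity_learning_rate']`.

    # Illustration only -- not part of the patch.
    import torch
    import torch.nn.functional as F

    def similarity_loss_function(inputs, targets):
        # Hypothetical stand-in: mean cosine distance between predicted and
        # gold output-layer vectors of shape [num_errors, hidden_dim].
        return (1.0 - F.cosine_similarity(inputs, targets, dim=-1)).mean()

    # Toy vectors standing in for the output layers gathered in the loop above.
    similarity_inputs = torch.stack([torch.randn(256, requires_grad=True) for _ in range(4)])
    similarity_targets = torch.stack([torch.randn(256) for _ in range(4)])

    # args['similarity_learning_rate'] is assumed to be a scalar weight on this loss term.
    similarity_loss = similarity_loss_function(similarity_inputs, similarity_targets) * 0.1
    similarity_loss.backward()  # differentiable w.r.t. the predicted vectors

Stacking once into `similarity_inputs` / `similarity_targets` before calling the loss, as the patch does, also makes the scaled loss term easier to log per epoch alongside the transition loss.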