Skip to content

Commit

Permalink
Log the similarity loss. Properly scale it by the learning rate.
Browse files (browse the repository at this point in the history)
  • Loading branch information
AngledLuffa committed Feb 23, 2025
1 parent 4dc6692 commit e6c8aab
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions stanza/models/constituency/parser_training.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -447,6 +447,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
"Epoch %d finished" % trainer.epochs_trained,
"Transitions correct: %d" % total_correct,
"Transitions incorrect: %d" % total_incorrect,
"Transition loss for epoch: %.5f" % epoch_stats.transition_loss,
"Similarity loss for epoch: %.5f" % epoch_stats.similarity_loss,
"Total loss for epoch: %.5f" % epoch_stats.total_loss,
"Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
"Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1)
Expand Down Expand Up @@ -597,18 +599,20 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
gold_results = model.analyze_trees([x.tree for x in training_batch], keep_output_layers=True)
errors = [error_analysis_in_order.analyze_tree(result.gold, result.predictions[0].tree) for result in reparsed_results]

similarities_inputs = []
similarities_targets = []
similarity_inputs = []
similarity_targets = []
for reparsed, gold, first_error in zip(reparsed_results, gold_results, errors):
if reparsed.predictions[0].tree == reparsed.gold:
continue
error_type, gold_index, pred_index = error_analysis_in_order.analyze_tree(reparsed.gold, reparsed.predictions[0].tree)
if gold_index is None or pred_index is None:
continue
similarities_inputs.append(reparsed.output_layers[pred_index])
similarities_targets.append(gold.output_layers[gold_index])
if len(similarities_inputs) > 0:
similarity_loss = similarity_loss_function(torch.stack(similarities_inputs), torch.stack(similarities_targets))
similarity_inputs.append(reparsed.output_layers[pred_index])
similarity_targets.append(gold.output_layers[gold_index])
if len(similarity_inputs) > 0:
similarity_inputs = torch.stack(similarity_inputs)
similarity_targets = torch.stack(similarity_targets)
similarity_loss = similarity_loss_function(similarity_inputs, similarity_targets) * args['similarity_learning_rate']

all_errors = []
all_answers = []
Expand Down

0 comments on commit e6c8aab

Please sign in to comment.