From 5bf876aecfe2bf8c9817bf33c779ee7746814943 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Tue, 17 Dec 2024 00:33:38 -0800
Subject: [PATCH] Try using normalized vectors with an MSELoss instead of
 Cosine.  The CosineEmbeddingLoss was occasionally hitting nan...

---
 stanza/models/constituency/parser_training.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py
index 369e156a09..c3006a3b39 100644
--- a/stanza/models/constituency/parser_training.py
+++ b/stanza/models/constituency/parser_training.py
@@ -363,7 +363,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
         model_loss_function.to(device)
 
     if args['contrastive_learning_rate'] > 0:
-        contrastive_loss_function = nn.CosineEmbeddingLoss(margin=args['contrastive_margin'])
+        contrastive_loss_function = nn.MSELoss(reduction='sum')
         contrastive_loss_function.to(device)
     else:
         contrastive_loss_function = None
@@ -598,16 +598,14 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
                 gold_trees = [x.constituents.value.value.value for x in gold_states]
                 gold_tree_hx = [x.constituents.value.value.tree_hx for x in gold_states]
 
-                reparsed_negatives = [hx for hx, reparsed_tree, gold_tree in zip(reparsed_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
-                gold_negatives = [hx for hx, reparsed_tree, gold_tree in zip(gold_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
+                reparsed_negatives = [nn.functional.normalize(hx) for hx, reparsed_tree, gold_tree in zip(reparsed_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
+                gold_negatives = [nn.functional.normalize(hx) for hx, reparsed_tree, gold_tree in zip(gold_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
 
                 if len(reparsed_negatives) > 0:
-                    reparsed_negatives = torch.cat(reparsed_negatives, dim=0)
-                    gold_negatives = torch.cat(gold_negatives, dim=0)
-
+                    mse = torch.stack([torch.dot(x.squeeze(0), y.squeeze(0)) for x, y in zip(reparsed_negatives, gold_negatives)])
                     device = next(model.parameters()).device
-                    target = -torch.ones(reparsed_negatives.shape[0]).to(device)
-                    contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(reparsed_negatives, gold_negatives, target)
+                    target = torch.zeros(mse.shape[0]).to(device)
+                    contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(mse, target)
 
                 # now we add the state to the trees in the batch
                 # the state is built as a bulk operation
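
Note: the following standalone sketch is not part of the patch; it illustrates the substitution the patch makes. Because the vectors are L2-normalized first, the dot product of each pair is exactly their cosine similarity, bounded in [-1, 1], so the MSELoss against a zero target penalizes the squared cosine similarity of each mismatched pair. This avoids the norm division inside CosineEmbeddingLoss, which is where the nan was presumably arising. The hidden_size, the batch of random (1, hidden_size) tensors standing in for the tree_hx vectors, and the variable names are all hypothetical.

    import torch
    import torch.nn as nn

    hidden_size = 128
    # hypothetical stand-ins for the (1, hidden_size) tree_hx encodings
    # of reparsed trees and their gold counterparts
    reparsed_hx = [torch.randn(1, hidden_size) for _ in range(4)]
    gold_hx = [torch.randn(1, hidden_size) for _ in range(4)]

    # same loss object as the patch constructs in iterate_training
    loss_fn = nn.MSELoss(reduction='sum')

    # L2-normalize each vector; unit vectors make the dot product below
    # a cosine similarity, so its magnitude can never blow up
    reparsed_norm = [nn.functional.normalize(hx) for hx in reparsed_hx]
    gold_norm = [nn.functional.normalize(hx) for hx in gold_hx]

    # one cosine similarity per (reparsed, gold) pair
    cos_sim = torch.stack([torch.dot(x.squeeze(0), y.squeeze(0))
                           for x, y in zip(reparsed_norm, gold_norm)])

    # regressing toward zero pushes mismatched encodings toward
    # orthogonality (CosineEmbeddingLoss with target -1 instead pushed
    # them toward opposite directions, subject to the margin)
    target = torch.zeros(cos_sim.shape[0])
    loss = loss_fn(cos_sim, target)
    print(loss)

One consequence of this formulation worth noting: with reduction='sum' the penalty scales with the number of mismatched pairs in the batch, whereas CosineEmbeddingLoss defaults to averaging over pairs.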