From 1f78391b427aff2df2bc9a7c28896d7a9e9820ed Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 5 Dec 2024 00:57:05 -0800 Subject: [PATCH] Build up orthogonal losses for neighboring subtrees. Need to add crossing subtrees and need to fix the loss --- stanza/models/constituency/parser_training.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index 13ffbd3837..3bddf2202c 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -668,13 +668,31 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te orthogonal_loss = 0.0 if epoch >= args['orthogonal_initial_epoch'] and orthogonal_loss_function is not None: gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False) - gold_constituents = [x.constituents for x in gold_results] - # TODO - #for gold_con in gold_constituents: - # print(len(gold_con)) - # for con in gold_con: - # print(con) - # raise ValueError + left_orthogonal_losses = [] + right_orthogonal_losses = [] + def build_losses(con_values, tree): + # this can skip preterminals + # but a preterminal in the middle of a phrase has a high chance of being + # a conjunction, a punctuation, or other non-function word anyway + subtrees = [x for x in tree.children if not x.is_preterminal()] + for subtree in subtrees: + build_losses(con_values, subtree) + for subtree_idx in range(len(subtrees)-1): + left = str(subtrees[subtree_idx]) + right = str(subtrees[subtree_idx+1]) + if left in con_values and right in con_values: + left_orthogonal_losses.append(con_values[left]) + right_orthogonal_losses.append(con_values[right]) + for result in gold_results: + gold_constituents = result.constituents + con_values = {} + for con in gold_constituents: + con_values[str(con.value)] = con.tree_hx + build_losses(con_values, result.gold) + left_inputs = torch.cat(left_orthogonal_losses, axis=0) + right_inputs = torch.cat(right_orthogonal_losses, axis=0) + target = -torch.ones(left_inputs.shape[0]).to(left_inputs.device) + orthogonal_loss = orthogonal_loss_function(left_inputs, right_inputs, target) * args['orthogonal_learning_rate'] errors = process_outputs(errors)