Skip to content

Commit

Permalink
Build up orthogonal losses for neighboring subtrees. Need to add crossing subtrees and need to fix the loss
Browse files (browse the repository at this point in the history)
  • Loading branch information
AngledLuffa committed Dec 5, 2024
1 parent 84de27b commit 1f78391
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,13 +668,31 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
orthogonal_loss = 0.0
if epoch >= args['orthogonal_initial_epoch'] and orthogonal_loss_function is not None:
gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False)
gold_constituents = [x.constituents for x in gold_results]
# TODO
#for gold_con in gold_constituents:
# print(len(gold_con))
# for con in gold_con:
# print(con)
# raise ValueError
left_orthogonal_losses = []
right_orthogonal_losses = []
def build_losses(con_values, tree):
    """Record the stored values of neighboring non-preterminal siblings.

    Walks ``tree`` recursively.  At every node, each pair of adjacent
    non-preterminal children whose string forms are both keys of
    ``con_values`` contributes one entry to the enclosing scope's
    ``left_orthogonal_losses`` and one to ``right_orthogonal_losses``.

    Args:
        con_values: mapping from ``str(subtree)`` to the value to collect
            for that subtree.
        tree: node exposing ``children`` and, on each child,
            ``is_preterminal()``.

    Returns:
        None.  Results are appended to the two enclosing-scope lists.
    """
    # Preterminals are dropped before pairing, so two phrases separated only
    # by a preterminal still count as neighbors.  A preterminal in the
    # middle of a phrase is most likely a conjunction, punctuation, or a
    # similar word anyway.
    phrases = [child for child in tree.children if not child.is_preterminal()]

    # Descend first: pairs found deeper in the tree are appended before the
    # pairs found at this level.
    for phrase in phrases:
        build_losses(con_values, phrase)

    for left_tree, right_tree in zip(phrases, phrases[1:]):
        left_key = str(left_tree)
        right_key = str(right_tree)
        # Only pair up subtrees which both have a recorded value.
        if left_key not in con_values or right_key not in con_values:
            continue
        left_orthogonal_losses.append(con_values[left_key])
        right_orthogonal_losses.append(con_values[right_key])
for result in gold_results:
gold_constituents = result.constituents
con_values = {}
for con in gold_constituents:
con_values[str(con.value)] = con.tree_hx
build_losses(con_values, result.gold)
left_inputs = torch.cat(left_orthogonal_losses, axis=0)
right_inputs = torch.cat(right_orthogonal_losses, axis=0)
target = -torch.ones(left_inputs.shape[0]).to(left_inputs.device)
orthogonal_loss = orthogonal_loss_function(left_inputs, right_inputs, target) * args['orthogonal_learning_rate']


errors = process_outputs(errors)
Expand Down

0 comments on commit 1f78391

Please sign in to comment.