From 9d647fbc7ee763923f72ca9af2f32c1bc1955d07 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 8 Dec 2024 20:28:22 -0800 Subject: [PATCH] Try dotting all pairs of vectors for orthogonality, not just neighbors --- stanza/models/constituency/parser_training.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index 927bf89d8a..6abf4121c6 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -1,5 +1,6 @@ from collections import Counter, namedtuple import copy +import itertools import logging import os import random @@ -689,13 +690,13 @@ def build_losses(con_values, tree): subtrees = [x for x in tree.children if not x.is_preterminal()] for subtree in subtrees: build_losses(con_values, subtree) - for subtree_idx in range(len(subtrees)-1): - left = str(subtrees[subtree_idx]) - right = str(subtrees[subtree_idx+1]) + for left, right in itertools.combinations(subtrees, 2): + left = str(left) + right = str(right) if left in con_values and right in con_values: left_value = con_values[left].squeeze(0) right_value = con_values[right].squeeze(0) - mse = torch.dot(left_value, right_value) + mse = torch.dot(left_value, right_value) / (len(subtrees) - 1) orthogonal_losses.append(mse) for result in gold_results: gold_constituents = result.constituents