diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index 4790fc859..6a03f74f2 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -619,22 +619,27 @@ def contrast_trees(reparsed, gold): if reparsed.is_preterminal() or gold.is_preterminal(): return - if (len(reparsed.children) == len(gold.children) and - all(x.start_index == y.start_index and x.end_index == y.end_index for x, y in zip(reparsed.children, gold.children))): - for x, y in zip(reparsed.children, gold.children): - contrast_trees(x, y) - return - - # TODO: instead compare all subtrees? - # preterminals don't have values returned by the tree analysis functions above - if (reparsed.children[0].end_index != gold.children[0].end_index and - not reparsed.children[0].is_preterminal() and not gold.children[0].is_preterminal()): - reparsed_negatives.append(reparsed_hx[str(reparsed.children[0])]) - gold_negatives.append(gold_hx[str(gold.children[0])]) - if (reparsed.children[-1].start_index != gold.children[-1].start_index and - not reparsed.children[-1].is_preterminal() and not gold.children[-1].is_preterminal()): - reparsed_negatives.append(reparsed_hx[str(reparsed.children[-1])]) - gold_negatives.append(gold_hx[str(gold.children[-1])]) + reparsed_idx = 0 + gold_idx = 0 + while reparsed_idx < len(reparsed.children) and gold_idx < len(gold.children): + reparsed_child = reparsed.children[reparsed_idx] + gold_child = gold.children[gold_idx] + if not reparsed_child.is_preterminal() and not gold_child.is_preterminal(): + # TODO: check that comparing labels is helpful + if (reparsed_child.label == gold_child.label and + reparsed_child.start_index == gold_child.start_index and + reparsed_child.end_index == gold_child.end_index): + contrast_trees(reparsed_child, gold_child) + else: + reparsed_negatives.append(reparsed_hx[str(reparsed_child)]) + gold_negatives.append(gold_hx[str(gold_child)]) + if reparsed_child.end_index == gold_child.end_index: + reparsed_idx += 1 + gold_idx += 1 + elif reparsed_child.end_index < gold_child.end_index: + reparsed_idx += 1 + else: + gold_idx += 1 reparsed_tree.mark_spans() gold_tree.mark_spans()