Skip to content

Commit

Permalink
trying to fix parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
jlunder00 committed Nov 24, 2024
1 parent c54b22f commit 0d668d2
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 55 deletions.
6 changes: 3 additions & 3 deletions TMN_DataGen/parsers/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[De
trees = []
total_sentences = len(sentences)

if self.verbose == 'normal':
if self.verbose == 'normal' or self.verbose == 'debug':
self.logger.info(f"Processing {total_sentences} sentences total...")

# Create batches
for i in range(0, total_sentences, self.batch_size):
batch = sentences[i:min(i + self.batch_size, total_sentences)]
if show_progress and self.verbose == 'normal':
if show_progress and self.verbose == 'normal' or self.verbose == 'debug':
self.logger.info(f"Processing batch {i//self.batch_size + 1}...")

batch_trees = self.parse_batch(batch)
Expand All @@ -75,7 +75,7 @@ def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[De

trees.extend(batch_trees)

if show_progress and self.verbose == 'normal':
if show_progress and self.verbose == 'normal' or self.verbose == 'debug':
self.logger.info("Done!")

return trees
Expand Down
159 changes: 110 additions & 49 deletions TMN_DataGen/parsers/diaparser_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..tree.dependency_tree import DependencyTree
from ..utils.logging_config import logger
from diaparser.parsers import Parser
from typing import List, Any, Optional, Tuple
from typing import List, Dict, Any, Optional, Tuple
from omegaconf import DictConfig
import numpy as np

Expand All @@ -15,79 +15,140 @@ def __init__(self, config: Optional[DictConfig] = None):
model_name = self.config.get('model_name', 'en_ewt.electra-base')
self.model = Parser.load(model_name)

def _process_prediction(self, dataset) -> Tuple[List[str], List[int], List[str]]:
"""Process diaparser prediction output into words, heads, and relations."""
sentence = dataset.sentences[0] # CoNLLSentence object

# Access values directly - these are already lists
words = list(sentence.values[1]) # Get all words
heads = list(sentence.values[6]) # Get all head indices
rels = list(sentence.values[7]) # Get all dependency relations
def _process_prediction(self, dataset) -> Dict[str, List]:
"""
Process diaparser output into aligned lists of token information.
DiaParser provides output in CoNLL-X format with these indices:
1: FORM - Word form/token
2: LEMMA - Lemma
3: UPOS - Universal POS tag
6: HEAD - Head token id
7: DEPREL - Dependency relation
Returns:
Dict with keys: words, lemmas, pos_tags, heads, rels
All lists are aligned by token position
"""
sentence = dataset.sentences[0]

if self.verbose == 'debug':
self.logger.debug(f"Raw parser output:")
self.logger.debug(f"Words: {words}")
self.logger.debug(f"Heads: {heads}")
self.logger.debug(f"Relations: {rels}")
self.logger.debug(f"Processing CoNLL format sentence:")
self.logger.debug(f"Raw values: {sentence.values}")

def ensure_list(val) -> List[str]:
"""Convert various input formats to list"""
if isinstance(val, str):
return val.split()
elif isinstance(val, tuple):
# DiaParser sometimes returns tuples
return ensure_list(val[0])
return list(val)

return words, heads, rels
try:
token_data = {
'words': ensure_list(sentence.values[1]),
'lemmas': ensure_list(sentence.values[2]),
'pos_tags': ensure_list(sentence.values[3]),
'heads': [int(h) for h in ensure_list(sentence.values[6])],
'rels': ensure_list(sentence.values[7])
}

# Verify all lists have same length
list_lens = [len(lst) for lst in token_data.values()]
if len(set(list_lens)) != 1:
raise ValueError(
f"Inconsistent token list lengths: {list_lens}"
)

if self.verbose == 'debug':
self.logger.debug("Processed token data:")
for key, value in token_data.items():
self.logger.debug(f"{key}: {value}")

return token_data

except Exception as e:
self.logger.error(f"Error processing parser output: {e}")
self.logger.debug(f"Values: {sentence.values}")
raise

def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
trees = []

if self.verbose == 'debug':
self.logger.debug(f"Parsing batch of {len(sentences)} sentences")
elif self.verbose == 'normal':
self.logger.info(f"Processing {len(sentences)} sentences...")


for sentence in sentences:
# Get prediction dataset
dataset = self.model.predict([sentence])
words, heads, rels = self._process_prediction(dataset)

if self.verbose == 'debug':
self.logger.debug(f"Creating nodes for sentence: {sentence}")
self.logger.debug(f"\nParsing sentence: {sentence}")

# Create nodes
nodes = [
Node(
word=word,
lemma=word.lower(), # Simple lemmatization for now
pos_tag="", # We could add POS tags if needed
idx=idx,
features={'original_text': word}
)
for idx, word in enumerate(words)
]
# Get DiaParser output
dataset = self.model.predict([sentence])
token_data = self._process_prediction(dataset)

# Connect nodes
root = None
for idx, (head_idx, dep_label) in enumerate(zip(heads, rels)):
if self.verbose == 'debug':
self.logger.debug(f"Processing node {idx}: word={words[idx]}, "
f"head={head_idx}, relation={dep_label}")
# Step 1: Create all nodes
nodes = []
for i in range(len(token_data['words'])):
node = Node(
word=token_data['words'][i],
lemma=token_data['lemmas'][i],
pos_tag=token_data['pos_tags'][i],
idx=i,
features={
'original_text': token_data['words'][i]
}
)
nodes.append(node)

if self.verbose == 'debug':
self.logger.debug(f"Created node {i}: {node}")

# Step 2: Connect nodes using head indices
root = None
for i, (node, head_idx, rel) in enumerate(zip(nodes,
token_data['heads'],
token_data['rels'])):
if head_idx == 0: # Root node
root = nodes[idx]
root = node
if self.verbose == 'debug':
self.logger.debug(f"Found root node: {root.word}")
self.logger.debug(f"Found root node: {node}")
else:
parent = nodes[head_idx - 1] # diaparser uses 1-based indices
parent.add_child(nodes[idx], dep_label)
# Head indices are 1-based in CoNLL format
parent = nodes[head_idx - 1]
parent.add_child(node, rel)
if self.verbose == 'debug':
self.logger.debug(f"Added {nodes[idx].word} as child of "
f"{parent.word} with label {dep_label}")

self.logger.debug(f"Connected node {node} to parent {parent}")

# Step 3: Verify we found a root and built valid tree
if root is None:
raise ValueError(f"No root node found in parse: {sentence}")

tree = DependencyTree(root, config=self.config)

if self.verbose == 'normal':
# Verify all nodes are reachable and structure is valid
tree_nodes = tree.root.get_subtree_nodes()
if len(tree_nodes) != len(nodes):
raise ValueError(
f"Tree structure incomplete: only {len(tree_nodes)} of {len(nodes)} "
f"nodes reachable from root"
)

# Verify tree structure is valid
if not root.verify_tree_structure():
raise ValueError(
f"Invalid tree structure detected for sentence: {sentence}"
)

if self.verbose in ('normal', 'debug'):
from ..utils.viz_utils import print_tree_text
self.logger.info("\nParsed tree structure:")
self.logger.info(print_tree_text(tree, self.config))

trees.append(tree)

return trees

def parse_single(self, sentence: str) -> DependencyTree:
    """Parse one sentence and return its dependency tree.

    Wraps the sentence in a one-element batch, hands it to
    ``parse_batch`` and returns the first tree of the result.
    """
    batch_result = self.parse_batch([sentence])
    return batch_result[0]

3 changes: 3 additions & 0 deletions TMN_DataGen/tree/dependency_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class DependencyTree:
def __init__(self, root: Node, config: Optional[DictConfig] = None):
    """Store the root node and configuration, rejecting invalid trees.

    Args:
        root: Root node of the tree.
        config: Optional configuration; a falsy value is replaced by an
            empty dict.

    Raises:
        ValueError: If ``root.verify_tree_structure()`` reports failure.
    """
    self.root = root
    self.config = config if config else {}
    # Fail fast so a structurally broken tree never leaves the constructor
    structure_ok = self.root.verify_tree_structure()
    if not structure_ok:
        raise ValueError("Invalid tree structure detected")

def modify_subtree(self, condition_fn, modification_fn):
"""Apply modification to nodes that meet condition"""
Expand Down
37 changes: 36 additions & 1 deletion TMN_DataGen/tree/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#node.py

from typing import List, Tuple, Optional, Dict, Any
import numpy as np

Expand Down Expand Up @@ -69,6 +68,41 @@ def traverse_levelorder(self):
def get_subtree_nodes(self):
    """Return a list of every node produced by ``traverse_preorder()``."""
    return [*self.traverse_preorder()]

def verify_tree_structure(self) -> bool:
    """
    Verify that the subtree rooted at this node is a well-formed tree.

    The walk follows ``children`` links and checks, for every
    ``(child, dep_type)`` entry:
    - ``child.parent`` is the node that lists it as a child
    - ``child.dependency_to_parent`` equals ``dep_type``
    - ``child.idx`` has not been seen before (this catches cycles and
      nodes that are reachable more than once)

    Nodes are identified by ``idx``, so two distinct nodes that share an
    ``idx`` are reported as invalid.

    The traversal uses an explicit stack rather than recursion, so very
    deep child chains cannot exhaust the interpreter's recursion limit.

    Returns:
        True if every check passes for all reachable nodes, False otherwise.
    """
    seen_indices = {self.idx}
    pending = [self]

    while pending:
        current = pending.pop()
        for child, dep_type in current.children:
            # Parent pointer and stored label must mirror the children list
            if child.parent is not current:
                return False
            if child.dependency_to_parent != dep_type:
                return False
            # A repeated idx means a cycle or a node listed twice
            if child.idx in seen_indices:
                return False
            seen_indices.add(child.idx)
            pending.append(child)

    return True

def __str__(self) -> str:
    """Short debugging form: word, POS tag and, when set, the dependency label."""
    relation = self.dependency_to_parent
    suffix = ""
    if relation:
        # Only nodes with a dependency label get the arrow suffix
        suffix = f" --{relation}-->"
    return f"{self.word}({self.pos_tag}){suffix}"

def __repr__(self) -> str:
    """Debugging representation showing the node's word, POS tag and index."""
    word, pos, idx = self.word, self.pos_tag, self.idx
    return f"Node(word='{word}', pos='{pos}', idx={idx})"

def to_dict(self) -> Dict:
"""Convert node to dictionary representation"""
Expand All @@ -93,3 +127,4 @@ def from_dict(cls, data: Dict) -> 'Node':
)
node.dependency_to_parent = data['dependency_to_parent']
return node

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setup(
name="TMN_DataGen",
version='0.3.7',
version='0.3.10',
description="Tree Matching Network Data Generator",
author="toast",
packages=find_packages(),
Expand Down
38 changes: 37 additions & 1 deletion tests/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,42 @@ def base_config():
}
})

def test_diaparser_detailed(base_config):
    """Check the tree DiaParser builds for a simple transitive sentence.

    Covers the root word and POS tag, the subject and object attachments,
    and the determiner attached to each of them.
    """
    overrides = {
        'parser': {
            'type': 'diaparser',
            'verbose': 'debug'
        }
    }
    parser = DiaParserTreeParser(OmegaConf.merge(base_config, overrides))

    # One simple sentence in, exactly one tree out
    text = "The cat chases the mouse."
    parsed = parser.parse_all([text])
    assert len(parsed) == 1

    tree = parsed[0]
    all_nodes = tree.root.get_subtree_nodes()

    # Basic shape: five word nodes with the main verb as root.
    # NOTE(review): the input ends with "." — this count only holds if the
    # parser output has no separate punctuation token; confirm.
    assert len(all_nodes) == 5
    assert tree.root.word == "chases"
    assert tree.root.pos_tag == "VERB"

    # Pick the first child attached with each relation of interest
    subject_matches = [child for child, label in tree.root.children if label == "nsubj"]
    subject = subject_matches[0]
    object_matches = [child for child, label in tree.root.children if label == "obj"]
    direct_object = object_matches[0]

    assert subject.word == "cat"
    assert direct_object.word == "mouse"

    # Each noun should carry exactly one child: its determiner
    assert len(subject.children) == 1
    assert len(direct_object.children) == 1
    assert subject.children[0][0].word == "The"
    assert direct_object.children[0][0].word == "the"

def test_diaparser(base_config):
config = OmegaConf.merge(base_config, {
'parser': {
Expand All @@ -23,7 +59,7 @@ def test_diaparser(base_config):
parser = DiaParserTreeParser(config)

sentence = "The cat chases the mouse."
tree = parser.parse_single(sentence)
tree = parser.parse_all([sentence])

# Check tree structure
assert tree.root is not None
Expand Down

0 comments on commit 0d668d2

Please sign in to comment.