all tests passing
jlunder00 committed Nov 24, 2024
1 parent 6331618 commit 3b88d6b
Showing 5 changed files with 82 additions and 64 deletions.
71 changes: 19 additions & 52 deletions TMN_DataGen/dataset_generator.py
@@ -14,45 +14,20 @@ class DatasetGenerator:
def __init__(self, config: Optional[DictConfig] = None):
self.config = config or {}
parser_type = self.config.get('parser', {}).get('type', 'diaparser')
self.parser = {
# Create an instance of the parser, not just store the class
parser_class = {
"diaparser": DiaParserTreeParser,
"spacy": SpacyTreeParser,
"multi": MultiParser
}[parser_type]
self.parser = parser_class(self.config) # Initialize parser instance

self.label_map = {
'entails': 1,
'contradicts': -1,
'neutral': 0
}

# def generate_dataset(self,
# sentence_pairs: List[Tuple[str, str]],
# labels: List[str],
# output_path: str,
# show_progress: bool = True):
# """Generate and save dataset in GMN-compatible format"""
# # Parse all sentences
# all_sentences = []
# for premise, hypothesis in sentence_pairs:
# all_sentences.extend([premise, hypothesis])
#
# print("Parsing sentences...")
# all_trees = self.parser.parse_all(all_sentences, show_progress)
#
# # Pair trees
# tree_pairs = []
# for i in range(0, len(all_trees), 2):
# tree_pairs.append((all_trees[i], all_trees[i+1]))
#
# # Convert to GMN format
# print("Converting to GMN format...")
# dataset = self._convert_to_gmn_format(tree_pairs, labels)
#
# # Save dataset
# print(f"Saving dataset to {output_path}")
# self._save_dataset(dataset, output_path)

def generate_dataset(self, sentence_pairs: List[Tuple[str, str]],
labels: List[str],
output_path: str,
@@ -77,29 +52,21 @@ def generate_dataset(self, sentence_pairs: List[Tuple[str, str]],
# Save dataset
with open(output_path, 'wb') as f:
pickle.dump(dataset, f)

def _convert_to_gmn_format(self,
tree_pairs: List[Tuple[DependencyTree, DependencyTree]],
labels: List[str]) -> Dict:
"""Convert tree pairs to GMN-compatible format"""
graph_pairs = []
numeric_labels = []

for (tree1, tree2), label in zip(tree_pairs, labels):
graph1 = tree1.to_graph_data()
graph2 = tree2.to_graph_data()
graph_pairs.append((graph1, graph2))
numeric_labels.append(self.label_map[label])

return {
'graph_pairs': graph_pairs,
'labels': numeric_labels
}

def _save_dataset(self, dataset: Dict, output_path: str):
"""Save dataset to file"""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
def _convert_to_gmn_format(self,
tree_pairs: List[Tuple[DependencyTree, DependencyTree]],
labels: List[str]) -> Dict:
"""Convert tree pairs to GMN-compatible format"""
graph_pairs = []
numeric_labels = []

with open(output_path, 'wb') as f:
pickle.dump(dataset, f)
for (tree1, tree2), label in zip(tree_pairs, labels):
graph1 = tree1.to_graph_data()
graph2 = tree2.to_graph_data()
graph_pairs.append((graph1, graph2))
numeric_labels.append(self.label_map[label])

return {
'graph_pairs': graph_pairs,
'labels': numeric_labels
}
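Note on the change above: DatasetGenerator now builds a parser instance in __init__ instead of keeping the bare class around. A minimal usage sketch, assuming the class is imported from TMN_DataGen/dataset_generator.py as laid out in this diff; the example pairs, labels, and output filename are illustrative, and the label strings must match the generator's label_map:

from TMN_DataGen.dataset_generator import DatasetGenerator

# Illustrative premise/hypothesis pairs with entailment-style labels.
pairs = [
    ("A man is playing a guitar.", "A person is making music."),
    ("A man is playing a guitar.", "Nobody is making any sound."),
]
labels = ["entails", "contradicts"]

generator = DatasetGenerator()  # empty config falls back to the diaparser backend
generator.generate_dataset(pairs, labels, output_path="dataset.pkl")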
40 changes: 32 additions & 8 deletions TMN_DataGen/parsers/base_parser.py
@@ -51,14 +51,12 @@ def parse_single(self, sentence: str) -> DependencyTree:

def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
"""Parse all sentences with batching and progress bar"""
# First ensure sentences is a list
if not isinstance(sentences, list):
sentences = list(sentences) # Try to convert to list if possible

if not sentences: # Handle empty input
return []

if not all(isinstance(s, str) for s in sentences):
raise TypeError("All elements must be strings")
if hasattr(sentences, '__iter__'): # Check if it's iterable
sentences = list(sentences)
else:
raise TypeError("sentences must be a list of strings")

trees = []
total_sentences = len(sentences)
@@ -72,9 +70,35 @@ def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:

if show_progress:
print("\nDone!")

return trees

# def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
# """Parse all sentences with batching and progress bar"""
# if not isinstance(sentences, list):
# sentences = list(sentences) # Try to convert to list if possible
#
# if not sentences: # Handle empty input
# return []
#
# if not all(isinstance(s, str) for s in sentences):
# raise TypeError("All elements must be strings")
#
# trees = []
# total_sentences = len(sentences)
#
# # Create batches
# for i in range(0, total_sentences, self.batch_size):
# batch = sentences[i:min(i + self.batch_size, total_sentences)]
# if show_progress:
# print(f"\rProcessing {i+1}/{total_sentences} sentences...", end="")
# trees.extend(self.parse_batch(batch))
#
# if show_progress:
# print("\nDone!")
#
# return trees

# def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
# """Parse all sentences with batching and progress bar"""
# trees = []
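Note on the change above: parse_all now accepts any iterable of sentences and materializes it into a list before batching, rather than insisting on a list up front. A standalone sketch of just that check, with a made-up helper name, to show what gets accepted and rejected:

def normalize_sentences(sentences):
    # Same shape as the new check in parse_all: convert any iterable, reject anything else.
    if hasattr(sentences, '__iter__'):
        return list(sentences)
    raise TypeError("sentences must be a list of strings")

print(normalize_sentences(("A cat sat.", "A dog ran.")))   # tuple -> ['A cat sat.', 'A dog ran.']
print(normalize_sentences(s for s in ["One.", "Two."]))    # generator -> ['One.', 'Two.']
# normalize_sentences(42) raises TypeError, since an int is not iterable.

One caveat of the hasattr check: a single string is itself iterable, so a bare string would be split into one-character items rather than rejected.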
24 changes: 24 additions & 0 deletions TMN_DataGen/parsers/multi_parser.py
@@ -61,6 +61,30 @@ def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:

return combined_trees

def _enhance_tree(self, base_tree: DependencyTree, parser_trees: Dict[str, DependencyTree]):
base_nodes = {node.idx: node for node in base_tree.root.get_subtree_nodes()}

for feature, parser_name in self.feature_sources.items():
if parser_name not in parser_trees:
continue

other_tree = parser_trees[parser_name]
other_nodes = {node.idx: node for node in other_tree.root.get_subtree_nodes()}

for idx, base_node in base_nodes.items():
if idx in other_nodes:
other_node = other_nodes[idx]
if feature == 'pos_tags':
base_node.pos_tag = other_node.pos_tag
elif feature == 'lemmas':
base_node.lemma = other_node.lemma
elif feature == 'morph':
# Ensure morph_features exists in features dict
if 'morph_features' not in base_node.features:
base_node.features['morph_features'] = {}
# Copy morph features from spacy
base_node.features['morph_features'].update(other_node.features.get('morph_features', {}))

def _enhance_tree(self, base_tree: DependencyTree,
parser_trees: Dict[str, DependencyTree]):
"""Enhance base tree with features from other parsers"""
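Note on the addition above: _enhance_tree copies selected per-token features (POS tags, lemmas, morphology) from secondary parses onto the base tree, matching nodes by token index and using feature_sources to decide which parser supplies which feature. A toy sketch of the same merge pattern with plain dictionaries; the config values and tokens are illustrative, not taken from the repo:

# Which parser supplies which feature, mirroring the feature_sources config.
feature_sources = {'pos_tags': 'spacy', 'lemmas': 'spacy'}

base_nodes = {0: {'word': 'Dogs', 'pos_tag': None, 'lemma': None},
              1: {'word': 'bark', 'pos_tag': None, 'lemma': None}}
other_trees = {'spacy': {0: {'pos_tag': 'NOUN', 'lemma': 'dog'},
                         1: {'pos_tag': 'VERB', 'lemma': 'bark'}}}

for feature, parser_name in feature_sources.items():
    other_nodes = other_trees.get(parser_name, {})
    for idx, node in base_nodes.items():
        if idx in other_nodes:
            if feature == 'pos_tags':
                node['pos_tag'] = other_nodes[idx]['pos_tag']
            elif feature == 'lemmas':
                node['lemma'] = other_nodes[idx]['lemma']

print(base_nodes[0])  # {'word': 'Dogs', 'pos_tag': 'NOUN', 'lemma': 'dog'}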
9 changes: 6 additions & 3 deletions TMN_DataGen/parsers/spacy_impl.py
@@ -1,5 +1,4 @@
#spacy_impl.py

from .base_parser import BaseTreeParser
from ..tree.node import Node
from ..tree.dependency_tree import DependencyTree
@@ -23,14 +22,18 @@ def parse_single(self, sentence: str) -> DependencyTree:
return self._convert_to_tree(doc)

def _convert_to_tree(self, doc: Any) -> DependencyTree:
# Create nodes
nodes = [
Node(
word=token.text,
lemma=token.lemma_,
pos_tag=token.pos_,
idx=token.i,
features={'original_text': token.text}
features={
'original_text': token.text,
'morph_features': dict(feature.split('=')
for feature in str(token.morph).split('|')
if feature != '') # Handle empty morph case
}
)
for token in doc
]
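Note on the change above: the spaCy converter now unpacks token.morph into a plain dict next to the original text. The expression relies on spaCy rendering morphology as a Key=Value|Key=Value string; a standalone sketch with literal morph strings (no spaCy required), including the empty case the guard handles:

def parse_morph(morph: str) -> dict:
    # Same expression as in _convert_to_tree, pulled out for illustration.
    return dict(feature.split('=') for feature in morph.split('|') if feature != '')

print(parse_morph("Case=Nom|Number=Plur"))  # {'Case': 'Nom', 'Number': 'Plur'}
print(parse_morph(""))                      # {} for tokens with no morphology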
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@

setup(
name="TMN_DataGen",
version='0.2.6',
version='0.2.10',
description="Tree Matching Network Data Generator",
author="toast",
packages=find_packages(),
