debugging visualization tools and logging
jlunder00 committed Nov 24, 2024
1 parent 3b88d6b commit ef33a90
Showing 15 changed files with 439 additions and 88 deletions.
21 changes: 20 additions & 1 deletion TMN_DataGen/__init__.py
@@ -1,8 +1,27 @@
#TMN_DataGen/TMN_DataGen/__init__.py
# TMN_DataGen/TMN_DataGen/__init__.py
import logging

# Set up logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',  # Simple format to just show messages for tree visualization
    force=True  # Ensure our config takes precedence
)

logger = logging.getLogger(__name__)

from .tree.node import Node
from .tree.dependency_tree import DependencyTree
from .parsers.diaparser_impl import DiaParserTreeParser
from .parsers.spacy_impl import SpacyTreeParser
from .dataset_generator import DatasetGenerator
from .parsers.multi_parser import MultiParser
from .utils.feature_utils import FeatureExtractor
from .utils.viz_utils import print_tree_text, visualize_tree_graphviz, format_tree_pair

__all__ = [
    'Node', 'DependencyTree',
    'DiaParserTreeParser', 'SpacyTreeParser', 'MultiParser',
    'DatasetGenerator', 'FeatureExtractor',
    'print_tree_text', 'visualize_tree_graphviz', 'format_tree_pair'
]
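
A minimal usage sketch (not part of this diff) of working with this package-level logging setup: importing TMN_DataGen applies basicConfig(force=True) with the bare '%(message)s' format, and a caller can then raise or lower the package logger's level. The logger name below simply mirrors the package name; the level choices are illustrative assumptions.

# Sketch: tuning TMN_DataGen's import-time logging from a calling script
import logging

import TMN_DataGen  # importing the package runs logging.basicConfig(force=True)

# Silence the INFO-level tree visualizations emitted by the parsers/generator:
logging.getLogger("TMN_DataGen").setLevel(logging.WARNING)

# ...or turn them back on when debugging tree construction:
logging.getLogger("TMN_DataGen").setLevel(logging.INFO)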
38 changes: 29 additions & 9 deletions TMN_DataGen/dataset_generator.py
@@ -1,43 +1,54 @@
#dataset_generator.py
# TMN_DataGen/TMN_DataGen/dataset_generator.py
from typing import List, Tuple, Optional, Dict
from omegaconf import DictConfig
from .parsers import DiaParserTreeParser, SpacyTreeParser, MultiParser
from .tree import DependencyTree
from .utils.viz_utils import format_tree_pair
from .utils.logging_config import logger
import torch
import numpy as np
from tqdm import tqdm
import json
import pickle
from pathlib import Path


class DatasetGenerator:
    def __init__(self, config: Optional[DictConfig] = None):
        self.config = config or {}
        self.verbose = self.config.get('verbose', False)
        parser_type = self.config.get('parser', {}).get('type', 'diaparser')
        # Create an instance of the parser, not just store the class

        if 'parser' not in self.config:
            self.config['parser'] = {}
        self.config['parser']['verbose'] = self.verbose

        parser_class = {
            "diaparser": DiaParserTreeParser,
            "spacy": SpacyTreeParser,
            "multi": MultiParser
        }[parser_type]
        self.parser = parser_class(self.config)  # Initialize parser instance
        self.parser = parser_class(self.config)

        self.label_map = {
            'entails': 1,
            'contradicts': -1,
            'neutral': 0
        }
        if self.verbose:
            logger.info("Initialized DatasetGenerator with verbose output")

    def generate_dataset(self, sentence_pairs: List[Tuple[str, str]],
                         labels: List[str],
                         output_path: str,
                         show_progress: bool = True) -> None:
        """Generate dataset from sentence pairs and labels"""
        # Flatten sentences for parsing
        if self.verbose:
            logger.info("\nGenerating dataset...")
            logger.info(f"Processing {len(sentence_pairs)} sentence pairs")

        all_sentences = [s for pair in sentence_pairs for s in pair]

        # Parse all sentences
        print("Parsing sentences...")
        logger.info("Parsing sentences...")
        all_trees = self.parser.parse_all(all_sentences, show_progress)

        # Pair up trees
@@ -46,12 +57,21 @@ def generate_dataset(self, sentence_pairs: List[Tuple[str, str]],
            for i in range(0, len(all_trees), 2)
        ]

        print("Converting to GMN format...")
        if self.verbose:
            logger.info("\nGenerated tree pairs:")
            for (tree1, tree2), label in zip(tree_pairs, labels):
                logger.info("\n" + "=" * 80)
                logger.info(format_tree_pair(tree1, tree2, label))
                logger.info("=" * 80)

        logger.info("Converting to GMN format...")
        dataset = self._convert_to_gmn_format(tree_pairs, labels)

        # Save dataset
        with open(output_path, 'wb') as f:
            pickle.dump(dataset, f)

        if self.verbose:
            logger.info(f"\nDataset saved to {output_path}")

    def _convert_to_gmn_format(self,
                               tree_pairs: List[Tuple[DependencyTree, DependencyTree]],
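
A hedged end-to-end sketch of the API added above. The sentences, labels, and output file name are invented; the config keys verbose and parser.type are the ones read in __init__, and the label strings match label_map.

# Sketch: driving DatasetGenerator as configured above (illustrative values only)
from omegaconf import OmegaConf
from TMN_DataGen import DatasetGenerator

config = OmegaConf.create({
    "verbose": True,                  # enables the tree-pair visualization logging
    "parser": {"type": "diaparser"},  # or "spacy" / "multi"
})
generator = DatasetGenerator(config)

sentence_pairs = [
    ("The cat sat on the mat.", "A cat is sitting on a mat."),
    ("The dog barked.", "The cat meowed."),
]
labels = ["entails", "neutral"]       # keys of generator.label_map

# Parses all sentences, pairs the trees, converts to GMN format, and pickles the result.
generator.generate_dataset(sentence_pairs, labels, output_path="dataset.pkl")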
3 changes: 3 additions & 0 deletions TMN_DataGen/parsers/__init__.py
@@ -1,6 +1,9 @@
#parsers/__init__.py

from ..utils.logging_config import logger
from .base_parser import BaseTreeParser
from .diaparser_impl import DiaParserTreeParser
from .spacy_impl import SpacyTreeParser
from .multi_parser import MultiParser

__all__ = ['BaseTreeParser', 'DiaParserTreeParser', 'SpacyTreeParser', 'MultiParser']
81 changes: 20 additions & 61 deletions TMN_DataGen/parsers/base_parser.py
@@ -1,27 +1,28 @@
#base_parser.py

# TMN_DataGen/TMN_DataGen/parsers/base_parser.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from ..tree.dependency_tree import DependencyTree
from ..utils.viz_utils import print_tree_text
from ..utils.logging_config import logger
from omegaconf import DictConfig
import torch
from tqdm import tqdm


class BaseTreeParser(ABC):
    _instances: Dict[str, 'BaseTreeParser'] = {}

    def __new__(cls, config: Optional[DictConfig] = None):
        """Singleton pattern implementation"""
        if cls.__name__ not in cls._instances:
            cls._instances[cls.__name__] = super(BaseTreeParser, cls).__new__(cls)
        return cls._instances[cls.__name__]

    @abstractmethod
    def __init__(self, config: Optional[DictConfig] = None):
        """Initialize parser with configuration"""
        if not hasattr(self, 'initialized'):
            self.config = config or {}
            self.batch_size = self.config.get('batch_size', 32)
            self.verbose = self.config.get('verbose', False)
            self.initialized = True

    @abstractmethod
@@ -34,26 +35,10 @@ def parse_single(self, sentence: str) -> DependencyTree:
"""Parse a single sentence into a dependency tree"""
pass

# def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
# """Parse all sentences with batching and progress bar"""
# if not isinstance(sentences, list):
# raise TypeError("sentences must be a list of strings")
#
# trees = []
# iterator = range(0, len(sentences), self.batch_size)
# if show_progress:
# iterator = tqdm(iterator, desc="Parsing sentences")
#
# for i in iterator:
# batch = sentences[i:i + self.batch_size]
# trees.extend(self.parse_batch(batch))
# return trees

def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
"""Parse all sentences with batching and progress bar"""
# First ensure sentences is a list
"""Parse all sentences with batching and progress bar."""
if not isinstance(sentences, list):
if hasattr(sentences, '__iter__'): # Check if it's iterable
if hasattr(sentences, '__iter__'):
sentences = list(sentences)
else:
raise TypeError("sentences must be a list of strings")
@@ -65,48 +50,22 @@ def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
        for i in range(0, total_sentences, self.batch_size):
            batch = sentences[i:min(i + self.batch_size, total_sentences)]
            if show_progress:
                print(f"\rProcessing {i+1}/{total_sentences} sentences...", end="")
            trees.extend(self.parse_batch(batch))
                logger.info(f"Processing {i+1}/{total_sentences} sentences...")

            batch_trees = self.parse_batch(batch)
            if self.verbose:
                for sent, tree in zip(batch, batch_trees):
                    logger.info("\n" + "="*80)
                    logger.info(f"Processed sentence: {sent}")
                    logger.info("\nTree structure:")
                    logger.info(print_tree_text(tree))
                    logger.info("="*80)

            trees.extend(batch_trees)

        if show_progress:
            print("\nDone!")
            logger.info("Done!")

        return trees

    # def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
    #     """Parse all sentences with batching and progress bar"""
    #     if not isinstance(sentences, list):
    #         sentences = list(sentences)  # Try to convert to list if possible
    #
    #     if not sentences:  # Handle empty input
    #         return []
    #
    #     if not all(isinstance(s, str) for s in sentences):
    #         raise TypeError("All elements must be strings")
    #
    #     trees = []
    #     total_sentences = len(sentences)
    #
    #     # Create batches
    #     for i in range(0, total_sentences, self.batch_size):
    #         batch = sentences[i:min(i + self.batch_size, total_sentences)]
    #         if show_progress:
    #             print(f"\rProcessing {i+1}/{total_sentences} sentences...", end="")
    #         trees.extend(self.parse_batch(batch))
    #
    #     if show_progress:
    #         print("\nDone!")
    #
    #     return trees

    # def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
    #     """Parse all sentences with batching and progress bar"""
    #     trees = []
    #     iterator = range(0, len(sentences), self.batch_size)
    #     if show_progress:
    #         iterator = tqdm(iterator, desc="Parsing sentences")
    #
    #     for i in iterator:
    #         batch = sentences[i:i + self.batch_size]
    #         trees.extend(self.parse_batch(batch))
    #     return trees
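
The base class above caches one instance per subclass name and leaves parse_batch and parse_single as the hooks parse_all depends on. A skeletal subclass illustrating that contract; the class is hypothetical and assumes no further abstract methods are hidden in the collapsed hunk.

# Sketch: the minimal surface a concrete parser must provide (hypothetical subclass)
from typing import List, Optional
from omegaconf import DictConfig

from TMN_DataGen.parsers.base_parser import BaseTreeParser
from TMN_DataGen.tree.dependency_tree import DependencyTree


class ToyTreeParser(BaseTreeParser):
    """Each subclass gets exactly one cached instance via BaseTreeParser.__new__."""

    def __init__(self, config: Optional[DictConfig] = None):
        super().__init__(config)  # sets self.config, self.batch_size, self.verbose once

    def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
        # parse_all relies on getting exactly one DependencyTree back per input sentence.
        return [self.parse_single(s) for s in sentences]

    def parse_single(self, sentence: str) -> DependencyTree:
        raise NotImplementedError("build Node objects here and return DependencyTree(root)")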
61 changes: 53 additions & 8 deletions TMN_DataGen/parsers/diaparser_impl.py
@@ -3,6 +3,7 @@
from .base_parser import BaseTreeParser
from ..tree.node import Node
from ..tree.dependency_tree import DependencyTree
from ..utils.logging_config import logger
from diaparser.parsers import Parser
from typing import List, Any, Optional, Tuple
from omegaconf import DictConfig
@@ -22,28 +23,42 @@ def _process_prediction(self, dataset) -> Tuple[List[str], List[int], List[str]]
        # Access the values directly from the CoNLLSentence object
        # The sentence.values contains tuples for each field in order:
        # [id, form, lemma, upos, xpos, feats, head, deprel, deps, misc]
        words = []
        heads = []
        rels = []

        # Debug log raw parser output
        if self.verbose:
            logger.info("\nParser raw output:")
            for i, field in enumerate(sentence.values):
                logger.info(f"Field {i}: {field}")

        # Get the form (words) from values[1]
        words_tuple = sentence.values[1]
        if isinstance(words_tuple, tuple) and len(words_tuple) == 1:
            # Split the sentence into words if it's a single string
            words = words_tuple[0].split()
        else:
            words = words_tuple

        # Get the head indices from values[6]
        heads_list = sentence.values[6]
        if isinstance(heads_list, list):
            heads = heads_list

        else:
            heads = [heads_list]

        # Get the dependency relations from values[7]
        rels_list = sentence.values[7]
        if isinstance(rels_list, list):
            rels = rels_list
        else:
            rels = [rels_list]

        if self.verbose:
            logger.info("\nProcessed parser fields:")
            logger.info(f"Words: {words}")
            logger.info(f"Head indices: {heads}")
            logger.info(f"Relations: {rels}")

        return words, heads, rels

    def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
        trees = []
        for sentence in sentences:
@@ -81,6 +96,14 @@ def parse_single(self, sentence: str) -> DependencyTree:
        return self.parse_batch([sentence])[0]

    def _convert_to_tree(self, sentence: str, parse_result: Any) -> DependencyTree:
        """Convert parser output to tree structure"""
        # Debug log inputs
        if self.verbose:
            logger.info("\nConverting to tree:")
            logger.info(f"Words: {parse_result.words}")
            logger.info(f"Lemmas: {parse_result.lemmas}")
            logger.info(f"POS tags: {parse_result.pos_tags}")

        # Create nodes
        nodes = [
            Node(
@@ -98,15 +121,37 @@ def _convert_to_tree(self, sentence: str, parse_result: Any) -> DependencyTree:

        # Connect nodes
        root = None

        if self.verbose:
            logger.info("\nConnecting nodes:")
            logger.info(f"Head indices: {parse_result.head_indices}")
            logger.info(f"Dep labels: {parse_result.dep_labels}")

        for idx, (head_idx, dep_label) in enumerate(zip(parse_result.head_indices,
                                                        parse_result.dep_labels)):
            if head_idx == 0:  # Root node
                root = nodes[idx]
                if self.verbose:
                    logger.info(f"Found root: {root.word}")
            else:
                parent = nodes[head_idx - 1]  # diaparser uses 1-based indices
                parent.add_child(nodes[idx], dep_label)

        return DependencyTree(root)
                if self.verbose:
                    logger.info(f"Added {nodes[idx].word} as child of {parent.word} with label {dep_label}")

        tree = DependencyTree(root)

        # Debug final tree
        if self.verbose:
            logger.info("\nFinal tree structure:")
            for node in tree.root.get_subtree_nodes():
                logger.info(f"Node: {node.word}")
                if node.parent:
                    logger.info(f" Parent: {node.parent.word}")
                logger.info(f" Children: {[child[0].word for child in node.children]}")

        return tree


    def _create_node_features(self, node: Node) -> np.ndarray:
        from ..utils.feature_utils import FeatureExtractor
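
A small worked example of the head-index convention handled in _convert_to_tree above: values[6] carries 1-based head positions, 0 marks the root, and head_idx - 1 recovers the parent's 0-based list position. The toy sentence and relation labels here are invented.

# Sketch: how 1-based head indices map to parent/child edges (toy data)
words = ["The", "cat", "sat"]
heads = [2, 3, 0]                 # "The" -> "cat", "cat" -> "sat", "sat" is the root
rels  = ["det", "nsubj", "root"]

for idx, (head_idx, rel) in enumerate(zip(heads, rels)):
    if head_idx == 0:
        print(f"root: {words[idx]}")
    else:
        # head_idx - 1 converts the 1-based head index into a 0-based list position
        print(f"{words[head_idx - 1]} --{rel}--> {words[idx]}")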
1 change: 1 addition & 0 deletions TMN_DataGen/tree/__init__.py
@@ -1,3 +1,4 @@
#__init__.py
from ..utils.logging_config import logger
from .node import Node
from .dependency_tree import DependencyTree
8 changes: 6 additions & 2 deletions TMN_DataGen/utils/__init__.py
@@ -1,2 +1,6 @@
#utils/__init__.py
from .feature_utils import FeatureExtractor
# TMN_DataGen/TMN_DataGen/utils/__init__.py
from .logging_config import logger
from .feature_utils import FeatureExtractor
from .viz_utils import print_tree_text, visualize_tree_graphviz, format_tree_pair

__all__ = ['logger', 'FeatureExtractor', 'print_tree_text', 'visualize_tree_graphviz', 'format_tree_pair']
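
A short sketch of the visualization helpers re-exported here. Their signatures are inferred from call sites in this commit (print_tree_text(tree) in base_parser, format_tree_pair(tree1, tree2, label) in dataset_generator), and constructing DiaParserTreeParser assumes a default diaparser model is available locally.

# Sketch: exercising the viz utilities (signatures inferred from call sites above)
from TMN_DataGen import DiaParserTreeParser, print_tree_text, format_tree_pair

parser = DiaParserTreeParser({"verbose": False})
tree1 = parser.parse_single("The cat sat on the mat.")
tree2 = parser.parse_single("A cat is on a mat.")

print(print_tree_text(tree1))                     # plain-text rendering of one tree
print(format_tree_pair(tree1, tree2, "entails"))  # side-by-side pair with its label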