Skip to content

Commit

Permalink
fixes and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jlunder00 committed Nov 23, 2024
1 parent a0481a2 commit 679ac8c
Show file tree
Hide file tree
Showing 24 changed files with 914 additions and 107 deletions.
3 changes: 3 additions & 0 deletions TMN_DataGen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#TMN_DataGen/TMN_DataGen/__init__.py
from .tree.node import Node
from .tree.dependency_tree import DependencyTree
from .parsers.diaparser_impl import DiaParserTreeParser
from .parsers.spacy_impl import SpacyTreeParser
from .dataset_generator import DatasetGenerator
from .parsers.multi_parser import MultiParser
from .utils.feature_utils import FeatureExtractor
10 changes: 6 additions & 4 deletions TMN_DataGen/dataset_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#dataset_generator.py
from typing import List, Tuple, Optional, Dict
from omegaconf import DictConfig
from .parsers import DiaParserTreeParser, SpacyTreeParser
from .parsers import DiaParserTreeParser, SpacyTreeParser, MultiParser
from .tree import DependencyTree
import torch
import numpy as np
Expand All @@ -12,11 +13,12 @@
class DatasetGenerator:
def __init__(self, config: Optional[DictConfig] = None):
self.config = config or {}
parser_type = self.config.get('parser_type', 'diaparser')
parser_type = self.config.get('parser', {}).get('type', 'diaparser')
self.parser = {
"diaparser": DiaParserTreeParser,
"spacy": SpacyTreeParser
}[parser_type](config)
"spacy": SpacyTreeParser,
"multi": MultiParser
}[parser_type]

self.label_map = {
'entails': 1,
Expand Down
1 change: 1 addition & 0 deletions TMN_DataGen/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Definition of the Edge object, which is a connection between two nodes
"""
#this may not be necessary anymore

import sys, os, re, json, typing

Expand Down
3 changes: 3 additions & 0 deletions TMN_DataGen/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#parsers/__init__.py

from .base_parser import BaseTreeParser
from .diaparser_impl import DiaParserTreeParser
from .spacy_impl import SpacyTreeParser
from .multi_parser import MultiParser
2 changes: 2 additions & 0 deletions TMN_DataGen/parsers/base_parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#base_parser.py

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from ..tree.dependency_tree import DependencyTree
Expand Down
18 changes: 18 additions & 0 deletions TMN_DataGen/parsers/diaparser_impl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#diaparser_impl.py

from .base_parser import BaseTreeParser
from ..tree.node import Node
from ..tree.dependency_tree import DependencyTree
from diaparser.parsers import Parser
from typing import List, Any, Optional, Tuple
from omegaconf import DictConfig
import numpy as np

class DiaParserTreeParser(BaseTreeParser):
def __init__(self, config: Optional[DictConfig] = None):
Expand Down Expand Up @@ -104,3 +107,18 @@ def _convert_to_tree(self, sentence: str, parse_result: Any) -> DependencyTree:
parent.add_child(nodes[idx], dep_label)

return DependencyTree(root)

def _create_node_features(self, node: Node) -> np.ndarray:
    """Compute the feature vector for a single tree node.

    A fresh ``FeatureExtractor`` is built from ``self.config`` on each call.
    The ``'feature_extraction'`` entry of that config (an empty dict when the
    key is missing) is passed along as the second argument.

    Args:
        node: The node to featurize.

    Returns:
        Whatever ``.numpy()`` yields on the extractor's output.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid an import cycle -- TODO confirm).
    from ..utils.feature_utils import FeatureExtractor

    build_features = FeatureExtractor(self.config).create_node_features
    extraction_cfg = self.config.get('feature_extraction', {})
    return build_features(node, extraction_cfg).numpy()

def _create_edge_features(self, dependency_type: str) -> np.ndarray:
    """Compute the feature vector for a dependency label.

    A fresh ``FeatureExtractor`` is built from ``self.config`` on each call
    and asked to featurize ``dependency_type``.

    Args:
        dependency_type: The dependency relation label of the edge.

    Returns:
        Whatever ``.numpy()`` yields on the extractor's output.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid an import cycle -- TODO confirm).
    from ..utils.feature_utils import FeatureExtractor

    edge_features = FeatureExtractor(self.config).create_edge_features(dependency_type)
    return edge_features.numpy()
68 changes: 33 additions & 35 deletions TMN_DataGen/parsers/multi_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,50 +8,48 @@

class MultiParser(BaseTreeParser):
def __init__(self, config: Optional[DictConfig] = None):
self.parsers = {}
super().__init__(config)
if not hasattr(self, 'initialized'):
self.parsers = {}
self.config = config or {}

# Initialize requested parsers
parser_configs = self.config.get('parsers', {
'diaparser': {'enabled': True},
'spacy': {'enabled': False}
})

if parser_configs.get('diaparser', {}).get('enabled', True):
self.parsers['diaparser'] = DiaParserTreeParser(
parser_configs.get('diaparser', {})
)

if parser_configs.get('spacy', {}).get('enabled', False):
self.parsers['spacy'] = SpacyTreeParser(
parser_configs.get('spacy', {})
)

# Configure which parser to use for which features
self.feature_sources = self.config.get('feature_sources', {
'tree_structure': 'diaparser', # Use diaparser for basic tree structure
'pos_tags': 'spacy', # Use spacy for POS tags
'morph': 'spacy', # Use spacy for morphological features
'lemmas': 'spacy' # Use spacy for lemmatization
})

self.initialized = True


# Initialize requested parsers
parser_configs = self.config.get('parser', {}).get('parsers', {})

if parser_configs.get('diaparser', {}).get('enabled', True):
self.parsers['diaparser'] = DiaParserTreeParser(
self.config
)

if parser_configs.get('spacy', {}).get('enabled', True):
self.parsers['spacy'] = SpacyTreeParser(
self.config
)

# Configure feature sources
self.feature_sources = self.config.get('parser', {}).get('feature_sources', {
'tree_structure': 'diaparser',
'pos_tags': 'spacy',
'morph': 'spacy',
'lemmas': 'spacy'
})

self.initialized = True

def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
# Get parses from all enabled parsers
parser_results = {
name: parser.parse_batch(sentences)
for name, parser in self.parsers.items()
}
parser_results = {}
for name, parser in self.parsers.items():
parser_results[name] = parser.parse_batch(sentences)

# Combine results into final trees
combined_trees = []
for i in range(len(sentences)):
# Start with the base tree structure from preferred parser
base_parser = self.feature_sources['tree_structure']
if base_parser not in parser_results:
raise ValueError(f"Base parser {base_parser} not available")

base_tree = parser_results[base_parser][i]
base_tree.config = self.config # Propagate config

# Enhance with features from other parsers
self._enhance_tree(
Expand All @@ -62,7 +60,7 @@ def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
combined_trees.append(base_tree)

return combined_trees

def _enhance_tree(self, base_tree: DependencyTree,
parser_trees: Dict[str, DependencyTree]):
"""Enhance base tree with features from other parsers"""
Expand Down
2 changes: 2 additions & 0 deletions TMN_DataGen/parsers/spacy_impl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#spacy_impl.py

from .base_parser import BaseTreeParser
from ..tree.node import Node
from ..tree.dependency_tree import DependencyTree
Expand Down
1 change: 1 addition & 0 deletions TMN_DataGen/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
#__init__.py
from .node import Node
from .dependency_tree import DependencyTree
89 changes: 69 additions & 20 deletions TMN_DataGen/tree/dependency_tree.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,55 @@
#dependency_tree.py
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
from omegaconf import DictConfig

try:
from .node import Node
except ImportError:
from node import Node


class DependencyTree:
def __init__(self, root: Node, config: Optional[DictConfig] = None):
    """Create a dependency tree rooted at ``root``.

    Args:
        root: The root node of the tree.
        config: Optional configuration. Any falsy value (including ``None``)
            is replaced by an empty dict, so ``self.config.get(...)`` is
            always usable. Callers that pass only ``root`` keep working.
    """
    self.root = root
    self.config = config or {}

def modify_subtree(self, condition_fn, modification_fn):
    """Run ``modification_fn`` on every node for which ``condition_fn`` is truthy.

    Nodes are visited in the order produced by
    ``self.root.traverse_preorder()``. Each node is tested and, if it matches,
    modified before the traversal advances to the next node.

    Args:
        condition_fn: Callable taking a node; a truthy result selects the node.
        modification_fn: Callable taking a node; invoked for selected nodes.
    """
    # The generator is lazy, so test/modify calls stay interleaved per node.
    selected = (
        candidate
        for candidate in self.root.traverse_preorder()
        if condition_fn(candidate)
    )
    for candidate in selected:
        modification_fn(candidate)

# def swap_nodes(self, node1: Node, node2: Node):
# """Swap two nodes while preserving tree structure"""
# # Save parent relationships
# parent1 = node1.parent
# parent2 = node2.parent
# dep1 = node1.dependency_to_parent
# dep2 = node2.dependency_to_parent
#
# # Swap in parents' children lists
# if parent1:
# parent1.replace_child(node1, node2, dep1)
# if parent2:
# parent2.replace_child(node2, node1, dep2)
#
# # Handle case where one is parent of the other
# if node1 in [child for child, _ in node2.children]:
# idx = [child for child, _ in node2.children].index(node1)
# node2.children[idx] = (node1, node2.children[idx][1])
# elif node2 in [child for child, _ in node1.children]:
# idx = [child for child, _ in node1.children].index(node2)
# node1.children[idx] = (node2, node1.children[idx][1])
#
# # Swap children lists
# node1.children, node2.children = node2.children, node1.children
#
# # Update root if needed
# if self.root == node1:
# self.root = node2
# elif self.root == node2:
# self.root = node1

def swap_nodes(self, node1: Node, node2: Node):
"""Swap two nodes while preserving tree structure"""
# Save parent relationships
Expand All @@ -25,22 +58,31 @@ def swap_nodes(self, node1: Node, node2: Node):
dep1 = node1.dependency_to_parent
dep2 = node2.dependency_to_parent

# Swap in parents' children lists
# Save children lists
children1 = node1.children.copy()
children2 = node2.children.copy()

# Remove from current parents
if parent1:
parent1.replace_child(node1, node2, dep1)
parent1.remove_child(node1)
if parent2:
parent2.replace_child(node2, node1, dep2)

# Handle case where one is parent of the other
if node1 in [child for child, _ in node2.children]:
idx = [child for child, _ in node2.children].index(node1)
node2.children[idx] = (node1, node2.children[idx][1])
elif node2 in [child for child, _ in node1.children]:
idx = [child for child, _ in node1.children].index(node2)
node1.children[idx] = (node2, node1.children[idx][1])

# Swap children lists
node1.children, node2.children = node2.children, node1.children
parent2.remove_child(node2)

# Add to new parents
if parent1:
parent1.add_child(node2, dep1)
if parent2:
parent2.add_child(node1, dep2)

# Update children
node1.children = children2
node2.children = children1

# Update parent references in children
for child, _ in node1.children:
child.parent = node1
for child, _ in node2.children:
child.parent = node2

# Update root if needed
if self.root == node1:
Expand Down Expand Up @@ -79,12 +121,19 @@ def to_graph_data(self) -> Dict[str, np.ndarray]:
}

def _create_node_features(self, node: Node) -> np.ndarray:
    """Convert ``node`` to a feature vector.

    Builds a ``FeatureExtractor`` from ``self.config`` and passes it the node
    together with the ``'feature_extraction'`` entry of the config (an empty
    dict when the key is missing).

    The earlier placeholder ``raise NotImplementedError`` has been dropped:
    it ran unconditionally before the implementation and made it unreachable.

    Args:
        node: The node to featurize.

    Returns:
        Whatever ``.numpy()`` yields on the extractor's output.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid an import cycle -- TODO confirm).
    from ..utils.feature_utils import FeatureExtractor

    extractor = FeatureExtractor(self.config)
    features = extractor.create_node_features(
        node,
        self.config.get('feature_extraction', {})
    )
    return features.numpy()

def _create_edge_features(self, dependency_type: str) -> np.ndarray:
    """Convert a dependency label to a feature vector.

    Builds a ``FeatureExtractor`` from ``self.config`` and asks it to
    featurize ``dependency_type``.

    The earlier placeholder ``raise NotImplementedError`` has been dropped:
    it ran unconditionally before the implementation and made it unreachable.

    Args:
        dependency_type: The dependency relation label of the edge.

    Returns:
        Whatever ``.numpy()`` yields on the extractor's output.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid an import cycle -- TODO confirm).
    from ..utils.feature_utils import FeatureExtractor

    extractor = FeatureExtractor(self.config)
    features = extractor.create_edge_features(dependency_type)
    return features.numpy()

def to_dict(self) -> Dict:
"""Convert tree to dictionary representation"""
Expand Down
2 changes: 2 additions & 0 deletions TMN_DataGen/tree/node.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#node.py

from typing import List, Tuple, Optional, Dict, Any
import numpy as np

Expand Down
3 changes: 2 additions & 1 deletion TMN_DataGen/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .feature_utils import create_node_features, create_edge_features
#utils/__init__.py
from .feature_utils import FeatureExtractor
Loading

0 comments on commit 679ac8c

Please sign in to comment.