Skip to content

Commit

Permalink
updates to make diaparser work correctly and as expected
Browse files Browse the repository at this point in the history
  • Loading branch information
jlunder00 committed Nov 25, 2024
1 parent 1d4b9a5 commit f594ef2
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 17 deletions.
71 changes: 70 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ python -m spacy download en_core_web_sm
### DiaParser Models
The default electra-base model will be downloaded automatically, but other options are available.

## Configuration
### Model Configuration

Models can be specified in your configuration file:
```yaml
Expand All @@ -45,3 +45,72 @@ parser:
enabled: true
model_name: "en_core_web_trf" # or any installed SpaCy model
```
## Text Preprocessing and Tokenization
TMN_DataGen provides configurable text preprocessing and tokenization options:
### Preprocessing Levels
The preprocessing pipeline has 4 levels of strictness:
- **Level 0 (None)**: No preprocessing
- **Level 1 (Basic)**: Whitespace normalization and punctuation handling
- **Level 2 (Medium)**: Unicode normalization and non-ASCII removal
- **Level 3 (Strict)**: Case normalization and accent removal
### Tokenization Options
Two tokenization approaches are available:
1. **Regex Tokenizer**: Simple rule-based tokenization using word boundaries
2. **Stanza Tokenizer**: Neural tokenizer with better handling of complex cases
### Configuration
Example configuration in yaml:
```yaml
preprocessing:
strictness_level: 2 # 0-3
tokenizer: "regex" # "regex" or "stanza"
language: "en"
preserve_case: false
remove_punctuation: true
normalize_unicode: true
remove_numbers: false
max_token_length: 50
min_token_length: 1
```
### Example Usage
```python
from TMN_DataGen import DiaParserTreeParser
from omegaconf import OmegaConf

# Load config with strict preprocessing
config = OmegaConf.load('configs/preprocessing_strict.yaml')
parser = DiaParserTreeParser(config)

# Parse with preprocessing
sentence = "The café is nice!"
tree = parser.parse_single(sentence)
# Result: Normalized and tokenized sentence processed into dependency tree
```

for languages that require specialized tokenization, the Stanza tokenizer is recommended:
```yaml
preprocessing:
strictness_level: 1
tokenizer: "stanza"
language: "en"
```
Note: Stanza tokenizer requires additional dependencies. Install with:;
```bash
pip install TMN_DataGen[stanza]
```


17 changes: 17 additions & 0 deletions TMN_DataGen/parsers/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from ..tree.dependency_tree import DependencyTree
from ..utils.viz_utils import print_tree_text
from ..utils.logging_config import get_logger
from ..utils.text_preprocessing import BasePreprocessor
from ..utils.tokenizers import RegexTokenizer, StanzaTokenizer
from omegaconf import DictConfig
import torch
from tqdm import tqdm
Expand All @@ -30,6 +32,15 @@ def __init__(self, config: Optional[DictConfig] = None):
name=self.__class__.__name__,
verbose=self.verbose
)

# Initialize preprocessor
self.preprocessor = BasePreprocessor(self.config)

# Initialize Tokenizer
if self.config.preprocessing.tokenizer == "stanza":
self.tokenizer = StanzaTokenizer(self.config)
else:
self.tokenizer = RegexTokenizer(self.config)

self.initialized = True

Expand All @@ -42,6 +53,12 @@ def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
def parse_single(self, sentence: str) -> DependencyTree:
"""Parse a single sentence into a dependency tree"""
pass

def preprocess_and_tokenize(self, text: str) -> List[str]:
"""Preprocess text and tokenize into words"""
clean_text = self.preprocessor.preprocess(text)
tokens = self.tokenizer.tokenize(clean_text)
return tokens

def parse_all(self, sentences: List[str], show_progress: bool = True) -> List[DependencyTree]:
"""Parse all sentences with batching and progress bar."""
Expand Down
14 changes: 9 additions & 5 deletions TMN_DataGen/parsers/diaparser_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def ensure_list(val) -> List[str]:
"""Convert various input formats to list"""
if isinstance(val, str):
return val.split()
elif isinstance(val, tuple):
# DiaParser sometimes returns tuples
return ensure_list(val[0])
# elif isinstance(val, tuple):
# # DiaParser sometimes returns tuples
# return ensure_list(val[0])
return list(val)

try:
Expand Down Expand Up @@ -79,11 +79,15 @@ def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
self.logger.debug(f"Parsing batch of {len(sentences)} sentences")

for sentence in sentences:

# Preprocess and tokenize first
tokens = self.preprocess_and_tokenize(sentence)

if self.verbose == 'debug':
self.logger.debug(f"\nParsing sentence: {sentence}")

# Get DiaParser output
dataset = self.model.predict([sentence])
dataset = self.model.predict([tokens])
token_data = self._process_prediction(dataset)

# Step 1: Create all nodes
Expand Down Expand Up @@ -141,7 +145,7 @@ def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:

if self.verbose in ('normal', 'debug'):
from ..utils.viz_utils import print_tree_text
self.logger.info("\nParsed tree structure:")
self.logger.info("\nDiapaarser parsed tree structure:")
self.logger.info(print_tree_text(tree, self.config))

trees.append(tree)
Expand Down
22 changes: 20 additions & 2 deletions TMN_DataGen/parsers/multi_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,27 @@ def __init__(self, config: Optional[DictConfig] = None):
self.initialized = True

def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
if self.verbose == 'info' or self.verbose == 'debug':
self.logger.info("Begin processing with multi parser")

# First do preprocessing once
processed_sentences = []
for sentence in sentences:
tokens = self.preprocess_and_tokenize(sentence)
processed_text = ' '.join(tokens)
processed_sentences.append(processed_text)

if self.verbose == 'debug':
self.logger.debug(f"Preprocessed '{sentence}' to '{processed_text}'")

# Get parses from all enabled parsers
parser_results = {}
for name, parser in self.parsers.items():
parser_results[name] = parser.parse_batch(sentences)
parser_results[name] = parser.parse_batch(processed_sentences)

# Combine results into final trees
combined_trees = []
for i in range(len(sentences)):
for i in range(len(processed_sentences)):
# Start with the base tree structure from preferred parser
base_parser = self.feature_sources['tree_structure']
if base_parser not in parser_results:
Expand All @@ -58,6 +71,11 @@ def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
)

combined_trees.append(base_tree)

if self.verbose == 'debug':
self.logger.debug(f"\nProcessed sentence {i+1}/{len(processed_sentences)}")
self.logger.debug(f"Combined features from {len(parser_results)} parsers")


return combined_trees

Expand Down
45 changes: 41 additions & 4 deletions TMN_DataGen/parsers/spacy_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,41 @@ def __init__(self, config: Optional[DictConfig] = None):
self.model = spacy.load(model_name)

def parse_batch(self, sentences: List[str]) -> List[DependencyTree]:
docs = self.model.pipe(sentences)
return [self._convert_to_tree(doc) for doc in docs]
if self.verbose == 'info' or self.verbose == 'debug':
self.logger.info("Begin Spacy batch processing")
processed_sentences = []
for sentence in sentences:
if self.verbose == 'info' or self.verbose == 'debug':
self.logger.info(f"Processing {sentence} with Spacy")
# Preprocess first
tokens = self.preprocess_and_tokenize(sentence)
# Join tokens for spaCy - it expects a text string
processed_text = ' '.join(tokens)
processed_sentences.append(processed_text)

if self.verbose == 'debug':
self.logger.debug(f"Preprocessed '{sentence}' to '{processed_text}'")

docs = self.model.pipe(processed_sentences)
trees = [self._convert_to_tree(doc) for doc in docs]
if self.verbose == 'debug':
for sent, tree in zip(processed_sentences, trees):
self.logger.debug(f"\nSpacy processed sentence: {sent}")
self.logger.debug(f"Generated Spacy tree with {len(tree.root.get_subtree_nodes())} nodes")


def parse_single(self, sentence: str) -> DependencyTree:
doc = self.model(sentence)
if self.verbose == 'info' or self.verbose == 'debug':
self.logger.info("Begin Spacy single processing")
# Preprocess
if self.verbose == 'info' or self.verbose == 'debug':
self.logger.info(f"Processing {sentence} with Spacy")
tokens = self.preprocess_and_tokenize(sentence)
processed_text = ' '.join(tokens)

if self.verbose == 'debug':
self.logger.debug(f"Preprocessed '{sentence}' to '{processed_text}'")
doc = self.model(processed_text)
return self._convert_to_tree(doc)

def _convert_to_tree(self, doc: Any) -> DependencyTree:
Expand Down Expand Up @@ -46,5 +76,12 @@ def _convert_to_tree(self, doc: Any) -> DependencyTree:
else:
parent = nodes[token.head.i]
parent.add_child(nodes[token.i], token.dep_)

if not root:
raise ValueError("No root node found in parse")

return DependencyTree(root)
tree = DependencyTree(root, self.config)
if self.verbose == 'debug':
self.logger.debug("Tree structure created successfully")

return tree
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ graphviz

torch>=2.0.0
numpy>=1.20.0
diaparser>=1.0.0
diaparser>=1.1.3
spacy>=3.0.0
omegaconf>=2.3.0
transformers>=4.30.0
Expand All @@ -18,4 +18,4 @@ pytest>=7.0.0
pytest-cov>=4.0.0
regex>=2022.1.18
unicodedata2>=15.0.0
stanza>=1.4.0; python_version >= "3.6" # Optional
stanza>=1.2.3; python_version >= "3.6" # Optional
10 changes: 9 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,18 @@

setup(
name="TMN_DataGen",
version='0.4.0',
version='0.4.2',
description="Tree Matching Network Data Generator",
author="toast",
packages=find_packages(),
install_requires=requirements,
extras_require={
'stanza': ['stanza>=1.4.0'],
'all': [
'stanza>=1.2.3',
'regex>=2022.1.18',
'unicodedata2>=15.0.0'
]
},
zip_safe=False,
)
24 changes: 22 additions & 2 deletions tests/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,22 @@ def test_diaparser_detailed(base_config):
config = OmegaConf.merge(base_config, {
'parser': {
'type': 'diaparser',
'language': 'en',
'verbose': 'debug'
},
'preprocessing': {
'strictness_level': 2,
'tokenizer': 'regex',
'remove_punctuation': True,
'language': 'en',
'preserve_case': False,
'normalize_unicode': True,
'remove_numbers': False,
'max_token_length': 50,
'min_token_length': 1
}
})

parser = DiaParserTreeParser(config)

# Test simple sentence
Expand All @@ -35,7 +48,8 @@ def test_diaparser_detailed(base_config):
# Verify basic structure
assert len(nodes) == 5 # should have 5 words
assert tree.root.word == "chases" # main verb should be root
assert tree.root.pos_tag == "VERB"

# assert tree.root.pos_tag == "VERB" # Diaparser does not give pos tags or lemmas

# Find subject and object
subj = [n for n, t in tree.root.children if t == "nsubj"][0]
Expand Down Expand Up @@ -112,7 +126,13 @@ def test_parser_preprocessing(base_config):
'preprocessing': {
'strictness_level': 2,
'tokenizer': 'regex',
'remove_punctuation': True
'remove_punctuation': True,
'language': 'en',
'preserve_case': False,
'normalize_unicode': True,
'remove_numbers': False,
'max_token_length': 50,
'min_token_length': 1
}
})

Expand Down

0 comments on commit f594ef2

Please sign in to comment.