Commit
add preprocessing and tokenization
jlunder00 committed Nov 25, 2024
1 parent 0d668d2 commit 1d4b9a5
Showing 12 changed files with 274 additions and 5 deletions.
5 changes: 4 additions & 1 deletion TMN_DataGen/__init__.py
@@ -18,10 +18,13 @@
from .parsers.multi_parser import MultiParser
from .utils.feature_utils import FeatureExtractor
from .utils.viz_utils import print_tree_text, visualize_tree_graphviz, format_tree_pair
from .utils.text_preprocessing import BasePreprocessor
from .utils.tokenizers import BaseTokenizer, RegexTokenizer, StanzaTokenizer

__all__ = [
'Node', 'DependencyTree',
'DiaParserTreeParser', 'SpacyTreeParser', 'MultiParser',
'DatasetGenerator', 'FeatureExtractor',
'print_tree_text', 'visualize_tree_graphviz', 'format_tree_pair'
'print_tree_text', 'visualize_tree_graphviz', 'format_tree_pair',
'BasePreprocessor', 'BaseTokenizer', 'RegexTokenizer', 'StanzaTokenizer'
]
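
With these exports the new preprocessing utilities become importable from the package root. A minimal sketch, not part of this commit:

# Illustrative only: the classes exported above
from TMN_DataGen import BasePreprocessor, RegexTokenizer, StanzaTokenizer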
10 changes: 8 additions & 2 deletions TMN_DataGen/utils/__init__.py
@@ -1,6 +1,12 @@
# TMN_DataGen/TMN_DataGen/utils/__init__.py
from .logging_config import logger
from .feature_utils import FeatureExtractor
from .feature_utils import FeatureExtractor
from .viz_utils import print_tree_text, visualize_tree_graphviz, format_tree_pair
from .text_preprocessing import BasePreprocessor
from .tokenizers import BaseTokenizer, RegexTokenizer, StanzaTokenizer

__all__ = ['logger', 'FeatureExtractor', 'print_tree_text', 'visualize_tree_graphviz', 'format_tree_pair']
__all__ = [
'logger', 'FeatureExtractor', 'print_tree_text', 'visualize_tree_graphviz',
'format_tree_pair', 'BasePreprocessor', 'BaseTokenizer', 'RegexTokenizer',
'StanzaTokenizer'
]
1 change: 0 additions & 1 deletion TMN_DataGen/utils/logging_config.py
@@ -1,5 +1,4 @@
# TMN_DataGen/TMN_DataGen/utils/logging_config.py
# TMN_DataGen/TMN_DataGen/utils/logging_config.py
import logging
import sys
from typing import Optional
63 changes: 63 additions & 0 deletions TMN_DataGen/utils/text_preprocessing.py
@@ -0,0 +1,63 @@
# TMN_DataGen/TMN_DataGen/utils/text_preprocessing.py
from abc import ABC, abstractmethod
import unicodedata
import re
from typing import List

class BasePreprocessor:
    """Base text preprocessing with configurable strictness"""

    def __init__(self, config):
        self.config = config
        self.strictness = config.preprocessing.strictness_level

    def preprocess(self, text: str) -> str:
        """Apply preprocessing based on strictness level"""
        if self.strictness == 0:
            return text

        # Basic (level 1)
        if self.strictness >= 1:
            text = self._basic_cleanup(text)

        # Medium (level 2)
        if self.strictness >= 2:
            text = self._medium_cleanup(text)

        # Strict (level 3)
        if self.strictness >= 3:
            text = self._strict_cleanup(text)

        return text

    def _basic_cleanup(self, text: str) -> str:
        """Basic normalization"""
        # Normalize whitespace
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()

        if self.config.preprocessing.remove_punctuation:
            text = re.sub(r'[^\w\s]', ' ', text)
            # Re-collapse the whitespace introduced by punctuation removal
            text = re.sub(r'\s+', ' ', text).strip()

        return text

    def _medium_cleanup(self, text: str) -> str:
        """Medium level cleanup"""
        if self.config.preprocessing.normalize_unicode:
            # Normalize unicode characters
            text = unicodedata.normalize('NFKD', text)
            # Remove non-ASCII
            text = re.sub(r'[^\x00-\x7F]+', '', text)
            # Clean up stray whitespace left by removed characters
            text = re.sub(r'\s+', ' ', text).strip()

        return text

    def _strict_cleanup(self, text: str) -> str:
        """Strict cleanup"""
        # Remove accents
        text = ''.join(c for c in unicodedata.normalize('NFD', text)
                       if unicodedata.category(c) != 'Mn')

        if not self.config.preprocessing.preserve_case:
            text = text.lower()

        return text
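
A usage sketch for BasePreprocessor, not part of the diff; the config keys mirror the fixture in tests/test_preprocessing.py and the expected output follows the level-2 case exercised there:

# Illustrative usage sketch (assumes an OmegaConf config as in the tests)
from omegaconf import OmegaConf
from TMN_DataGen.utils.text_preprocessing import BasePreprocessor

cfg = OmegaConf.create({'preprocessing': {
    'strictness_level': 2, 'remove_punctuation': True,
    'normalize_unicode': True, 'preserve_case': False}})
print(BasePreprocessor(cfg).preprocess("Hello, World! こんにちは"))
# level 2 -> "Hello World" (punctuation stripped, non-ASCII removed)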
42 changes: 42 additions & 0 deletions TMN_DataGen/utils/tokenizers.py
@@ -0,0 +1,42 @@
# TMN_DataGen/TMN_DataGen/utils/tokenizers.py
from abc import ABC, abstractmethod
import re
from typing import List

try:
    import stanza
except ImportError:  # stanza is optional; only needed for StanzaTokenizer
    stanza = None

class BaseTokenizer(ABC):
    @abstractmethod
    def tokenize(self, text: str) -> List[str]:
        pass

class RegexTokenizer(BaseTokenizer):
    def __init__(self, config):
        self.config = config
        self.min_len = config.preprocessing.min_token_length
        self.max_len = config.preprocessing.max_token_length

    def tokenize(self, text: str) -> List[str]:
        # Simple word boundary tokenization
        tokens = re.findall(r'\b\w+\b', text)
        # Apply length filters
        tokens = [t for t in tokens
                  if self.min_len <= len(t) <= self.max_len]
        return tokens

class StanzaTokenizer(BaseTokenizer):
    def __init__(self, config):
        self.config = config
        if stanza is None:
            raise ImportError("StanzaTokenizer requires the optional 'stanza' package")
        try:
            self.nlp = stanza.Pipeline(
                lang=config.preprocessing.language,
                processors='tokenize',
                use_gpu=True
            )
        except Exception as e:
            raise ValueError(f"Failed to load Stanza: {e}")

    def tokenize(self, text: str) -> List[str]:
        doc = self.nlp(text)
        tokens = [word.text for sent in doc.sentences
                  for word in sent.words]
        return tokens
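
The choice between the two tokenizers is driven by config.preprocessing.tokenizer ("regex" or "stanza"). The commit does not show where that choice is made, so the helper below is hypothetical and only illustrates the intended dispatch:

# Hypothetical dispatch helper, illustrative only
from omegaconf import OmegaConf
from TMN_DataGen.utils.tokenizers import RegexTokenizer, StanzaTokenizer

def build_tokenizer(config):
    if config.preprocessing.tokenizer == 'stanza':
        return StanzaTokenizer(config)
    return RegexTokenizer(config)

cfg = OmegaConf.create({'preprocessing': {
    'tokenizer': 'regex', 'language': 'en',
    'min_token_length': 1, 'max_token_length': 50}})
print(build_tokenizer(cfg).tokenize("The quick brown fox"))
# -> ['The', 'quick', 'brown', 'fox']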
13 changes: 13 additions & 0 deletions configs/default_config.yaml
@@ -17,4 +17,17 @@ processing:
  max_sequence_length: 128
  lowercase: true

preprocessing:
  strictness_level: 1  # 0-3
  tokenizer: "regex"  # "regex" or "stanza"
  language: "en"
  preserve_case: false
  # Regex preprocessing options
  remove_punctuation: true
  normalize_unicode: true
  remove_numbers: false
  # Tokenizer specific options
  max_token_length: 50
  min_token_length: 1

verbose: "debug" # Options: null, "normal", "debug"
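
Loaded with OmegaConf, this block is exposed as config.preprocessing.*, which is exactly what BasePreprocessor and the tokenizers read. A minimal loading sketch, illustrative only:

# Illustrative only: build the preprocessing objects from the default config
from omegaconf import OmegaConf
from TMN_DataGen.utils.text_preprocessing import BasePreprocessor
from TMN_DataGen.utils.tokenizers import RegexTokenizer

config = OmegaConf.load('configs/default_config.yaml')
preprocessor = BasePreprocessor(config)
tokenizer = RegexTokenizer(config)
print(tokenizer.tokenize(preprocessor.preprocess("The quick, brown fox!")))
# with the defaults above -> ['The', 'quick', 'brown', 'fox']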
9 changes: 9 additions & 0 deletions configs/preprocessing_stanza.yaml
@@ -0,0 +1,9 @@
# configs/preprocessing_stanza.yaml
preprocessing:
  strictness_level: 1
  tokenizer: "stanza"
  language: "en"
  preserve_case: true
  remove_punctuation: false
  normalize_unicode: false
  remove_numbers: false
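
Selecting the stanza tokenizer assumes the English tokenization models are already present; stanza fetches them in a separate step. A one-time setup sketch, illustrative only:

# Illustrative one-time setup for the stanza tokenizer
import stanza
stanza.download('en')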
11 changes: 11 additions & 0 deletions configs/preprocessing_strict.yaml
@@ -0,0 +1,11 @@
# configs/preprocessing_strict.yaml
preprocessing:
  strictness_level: 3
  tokenizer: "regex"
  language: "en"
  preserve_case: false
  remove_punctuation: true
  normalize_unicode: true
  remove_numbers: true
  max_token_length: 50
  min_token_length: 2
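
At strictness level 3 every stage runs, so accents and case are stripped on top of the basic and medium cleanup. Mirroring the level-3 case in tests/test_preprocessing.py, illustrative only:

# Illustrative only: expected behaviour under the strict profile
from omegaconf import OmegaConf
from TMN_DataGen.utils.text_preprocessing import BasePreprocessor

cfg = OmegaConf.load('configs/preprocessing_strict.yaml')
print(BasePreprocessor(cfg).preprocess("Héllo, Wörld!"))  # -> "hello world"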
3 changes: 3 additions & 0 deletions requirements.txt
@@ -16,3 +16,6 @@ transformers>=4.30.0
tqdm>=4.65.0
pytest>=7.0.0
pytest-cov>=4.0.0
regex>=2022.1.18
unicodedata2>=15.0.0
stanza>=1.4.0; python_version >= "3.6" # Optional
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@

setup(
name="TMN_DataGen",
version='0.3.10',
version='0.4.0',
description="Tree Matching Network Data Generator",
author="toast",
packages=find_packages(),
48 changes: 48 additions & 0 deletions tests/test_parsers.py
@@ -102,6 +102,54 @@ def test_multi_parser(base_config):
        assert node.pos_tag != ""
        assert 'morph_features' in node.features

def test_parser_preprocessing(base_config):
    """Test preprocessing integration with parser"""
    config = OmegaConf.merge(base_config, {
        'parser': {
            'type': 'diaparser',
            'language': 'en'
        },
        'preprocessing': {
            'strictness_level': 2,
            'tokenizer': 'regex',
            'remove_punctuation': True
        }
    })

    parser = DiaParserTreeParser(config)

    # Test preprocessing pipeline
    sentence = "The quick, brown fox!"
    tokens = parser.preprocess_and_tokenize(sentence)
    assert tokens == ['The', 'quick', 'brown', 'fox']

    # Test full parsing with preprocessing
    tree = parser.parse_single(sentence)
    nodes = tree.root.get_subtree_nodes()
    words = [node.word for node in nodes]
    assert words == tokens

def test_parser_with_unicode(base_config):
    """Test parser handling of unicode text"""
    config = OmegaConf.merge(base_config, {
        'parser': {
            'type': 'diaparser',
            'language': 'en'
        },
        'preprocessing': {
            'strictness_level': 3,
            'normalize_unicode': True
        }
    })

    parser = DiaParserTreeParser(config)

    sentence = "The café is nice."
    tree = parser.parse_single(sentence)
    nodes = tree.root.get_subtree_nodes()
    words = [node.word for node in nodes]
    assert 'cafe' in words  # accent removed

class TestMultiParser:
    @pytest.fixture
    def multi_parser_config(self, base_config):
72 changes: 72 additions & 0 deletions tests/test_preprocessing.py
@@ -0,0 +1,72 @@
#tests/test_preprocessing.py
import pytest
from TMN_DataGen.utils.text_preprocessing import BasePreprocessor
from TMN_DataGen.utils.tokenizers import RegexTokenizer, StanzaTokenizer
from omegaconf import OmegaConf

class TestPreprocessing:
    @pytest.fixture
    def config(self):
        return OmegaConf.create({
            'preprocessing': {
                'strictness_level': 1,
                'tokenizer': 'regex',
                'language': 'en',
                'preserve_case': False,
                'remove_punctuation': True,
                'normalize_unicode': True,
                'remove_numbers': False,
                'max_token_length': 50,
                'min_token_length': 1
            }
        })

    def test_strictness_levels(self, config):
        # Test STRICTNESS_NONE
        config.preprocessing.strictness_level = 0
        preprocessor = BasePreprocessor(config)
        text = "Hello, World! こんにちは"
        assert preprocessor.preprocess(text) == text

        # Test STRICTNESS_BASIC
        config.preprocessing.strictness_level = 1
        preprocessor = BasePreprocessor(config)
        result = preprocessor.preprocess("Hello, World! ")
        assert result == "Hello World"

        # Test STRICTNESS_MEDIUM
        config.preprocessing.strictness_level = 2
        preprocessor = BasePreprocessor(config)
        result = preprocessor.preprocess("Hello, World! こんにちは")
        assert result == "Hello World"

        # Test STRICTNESS_STRICT
        config.preprocessing.strictness_level = 3
        preprocessor = BasePreprocessor(config)
        result = preprocessor.preprocess("Héllo, Wörld!")
        assert result == "hello world"

    def test_regex_tokenizer(self, config):
        config.preprocessing.tokenizer = 'regex'
        tokenizer = RegexTokenizer(config)

        # Test basic tokenization
        text = "The quick brown fox"
        tokens = tokenizer.tokenize(text)
        assert tokens == ['The', 'quick', 'brown', 'fox']

        # Test length filters
        config.preprocessing.min_token_length = 4
        tokenizer = RegexTokenizer(config)
        tokens = tokenizer.tokenize(text)
        assert tokens == ['quick', 'brown']

    def test_stanza_tokenizer(self, config):
        pytest.importorskip("stanza")  # skip cleanly if stanza is not installed
        config.preprocessing.tokenizer = 'stanza'
        tokenizer = StanzaTokenizer(config)

        text = "The quick brown fox."
        tokens = tokenizer.tokenize(text)
        assert len(tokens) == 5  # includes period
        assert tokens[:-1] == ['The', 'quick', 'brown', 'fox']
