Skip to content

Commit

Permalink
fixes for tests and configs
Browse files Browse the repository at this point in the history
  • Loading branch information
jlunder00 committed Nov 25, 2024
1 parent 2d505cc commit 6a84bbf
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 89 deletions.
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
recursive-include TMN_DataGen/configs *.yaml

File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 3 additions & 1 deletion TMN_DataGen/dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .tree import DependencyTree
from .utils.viz_utils import format_tree_pair
from .utils.logging_config import logger
from importlib.resources import files

class DatasetGenerator:
def __init__(self):
Expand All @@ -27,7 +28,8 @@ def _load_configs(
override_pkg_config: Optional[Union[str, Dict]] = None
) -> Dict:
"""Load and merge configurations"""
config_dir = Path(__file__).parent / 'configs'
config_dir = Path(files('TMN_DataGen').joinpath('configs'))
# config_dir = Path(__file__).parent / 'configs'

# Load default configs
with open(config_dir / 'default_package_config.yaml') as f:
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@

setup(
name="TMN_DataGen",
version='0.4.2',
version='0.4.5',
description="Tree Matching Network Data Generator",
author="toast",
packages=find_packages(),
install_requires=requirements,
extras_require={
'stanza': ['stanza>=1.4.0'],
'stanza': ['stanza>=1.2.3'],
'all': [
'stanza>=1.2.3',
'regex>=2022.1.18',
'unicodedata2>=15.0.0'
]
},
include_package_data=True,
zip_safe=False,
)
39 changes: 22 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,36 @@
import pytest
import yaml
from pathlib import Path
from omegaconf import OmegaConf

# @pytest.fixture
# def base_config():
# """Load default configs for testing"""
# config_dir = Path(__file__).parent.parent / 'TMN_DataGen' / 'configs'
#
# with open(config_dir / 'default_parser_config.yaml') as f:
# config = yaml.safe_load(f)
#
# with open(config_dir / 'default_preprocessing_config.yaml') as f:
# config.update(yaml.safe_load(f))
#
# config['verbose'] = 'debug' # Use debug for tests
# return config
@pytest.fixture
def default_config():
"""Load all default configs for testing"""
config_dir = Path(__file__).parent.parent / 'TMN_DataGen' / 'configs'

# Load package config
with open(config_dir / 'default_package_config.yaml') as f:
pkg_config = yaml.safe_load(f)

# Load and merge parser and preprocessing configs
with open(config_dir / 'default_parser_config.yaml') as f:
config = yaml.safe_load(f)

with open(config_dir / 'default_preprocessing_config.yaml') as f:
preproc = yaml.safe_load(f)
config.update(preproc)

return OmegaConf.create(config), pkg_config

@pytest.fixture
@pytest.fixture
def sample_data():
"""Sample text data for testing"""
return {
'sentence_pairs': [
('The cat chases the mouse.', 'The mouse is being chased by the cat.'),
('The dog barks.', 'The cat meows.'),
('Birds fly in the sky.', 'The birds are on the ground.')
('The dog barks.', 'The cat meows.')
],
'labels': ['entails', 'neutral', 'contradicts']
'labels': ['entails', 'neutral']
}

@pytest.fixture
Expand Down
41 changes: 19 additions & 22 deletions tests/test_dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from TMN_DataGen import DatasetGenerator
import pickle

def test_dataset_generation(sample_data, tmp_path):
def test_dataset_generation(sample_data, default_config, tmp_path):
"""Test basic dataset generation workflow"""
config, pkg_config = default_config

# Initialize generator
generator = DatasetGenerator()
output_path = tmp_path / "test_dataset.pkl"
Expand All @@ -13,37 +15,32 @@ def test_dataset_generation(sample_data, tmp_path):
generator.generate_dataset(
sentence_pairs=sample_data['sentence_pairs'],
labels=sample_data['labels'],
output_path=str(output_path)
output_path=str(output_path),
parser_config=config,
verbosity='normal'
)

# Verify output
assert output_path.exists()
with open(output_path, 'rb') as f:
dataset = pickle.load(f)
assert 'graph_pairs' in dataset
assert 'labels' in dataset
assert len(dataset['graph_pairs']) == len(sample_data['labels'])

def test_config_override(sample_data, tmp_path):
def test_config_override(sample_data, default_config, tmp_path):
"""Test config override functionality"""
config, pkg_config = default_config

# Modify config
config.parser.feature_sources.update({
'tree_structure': 'diaparser',
'pos_tags': 'spacy'
})

generator = DatasetGenerator()
output_path = tmp_path / "test_dataset.pkl"

parser_config = {
'parser': {
'type': 'multi',
'feature_sources': {
'tree_structure': 'diaparser',
'pos_tags': 'spacy'
}
}
}

# Should work with custom config
generator.generate_dataset(
sentence_pairs=sample_data['sentence_pairs'],
labels=sample_data['labels'],
labels=sample_data['labels'],
output_path=str(output_path),
parser_config=parser_config,
parser_config=config,
verbosity='debug'
)


32 changes: 12 additions & 20 deletions tests/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,34 @@
import pytest
from TMN_DataGen.parsers import MultiParser, DiaParserTreeParser

def test_diaparser_basic():
def test_diaparser_basic(default_config):
"""Test basic DiaParser functionality"""
parser = DiaParserTreeParser()
config, _ = default_config
parser = DiaParserTreeParser(config)
sentence = "The cat chases the mouse."
tree = parser.parse_single(sentence)

# Verify tree structure
assert tree.root is not None
nodes = tree.root.get_subtree_nodes()
assert len(nodes) == 5

# Verify dependency labels
root_node = tree.root
assert root_node.word == 'chases'
assert root_node.dependency_to_parent is None

def test_multi_parser():
def test_multi_parser(default_config):
"""Test MultiParser feature combination"""
parser = MultiParser()
config, pkg_config = default_config
parser = MultiParser(config, pkg_config)
sentence = "The cat chases the mouse."
tree = parser.parse_single(sentence)

nodes = tree.root.get_subtree_nodes()

# Verify features from both parsers
for node in nodes:
# From spaCy
assert node.pos_tag is not None
assert node.lemma is not None

# From DiaParser
assert hasattr(node, 'dependency_to_parent')
assert node.pos_tag is not None # From spaCy
assert hasattr(node, 'dependency_to_parent') # From DiaParser

def test_preprocessing():
def test_preprocessing(default_config):
"""Test preprocessing pipeline"""
parser = MultiParser()
config, pkg_config = default_config
parser = MultiParser(config, pkg_config)
text = "Hello, World! "
processed = parser.preprocess_and_tokenize(text)
assert processed == ['hello', 'world']

38 changes: 11 additions & 27 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,25 @@
from TMN_DataGen.utils import RegexTokenizer, StanzaTokenizer
import pytest

def test_basic_preprocessing():
def test_basic_preprocessing(default_config):
"""Test basic preprocessing features"""
config = {
'preprocessing': {
'strictness_level': 1,
'tokenizer': 'regex',
'preserve_case': False
}
}

config, _ = default_config
preprocessor = BasePreprocessor(config)
assert preprocessor.preprocess("Hello, World! ") == "hello world"
def test_unicode_handling():

def test_unicode_handling(default_config):
"""Test unicode normalization"""
config = {
'preprocessing': {
'strictness_level': 2,
'normalize_unicode': True
}
}

config, _ = default_config
config.preprocessing.strictness_level = 2
config.preprocessing.normalize_unicode = True
preprocessor = BasePreprocessor(config)
assert preprocessor.preprocess("café") == "cafe"

@pytest.mark.skipif(not pytest.importorskip("stanza"), reason="stanza not installed")
def test_stanza_tokenizer():
@pytest.mark.skipif(not pytest.importorskip("stanza"), reason="stanza not installed")
def test_stanza_tokenizer(default_config):
"""Test Stanza tokenizer if available"""
config = {
'preprocessing': {
'tokenizer': 'stanza',
'language': 'en'
}
}

config, _ = default_config
config.preprocessing.tokenizer = 'stanza'
tokenizer = StanzaTokenizer(config)
tokens = tokenizer.tokenize("Hello, world!")
assert tokens == ['Hello', ',', 'world', '!']

0 comments on commit 6a84bbf

Please sign in to comment.