Refactor test
AngledLuffa committed Jan 11, 2024
1 parent 551a854 commit bfbf963
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions stanza/tests/lemma_classifier/test_data_preparation.py
@@ -183,19 +183,40 @@
6 soldiers soldier NOUN NNS Number=Plur 0 root 0:root _
"""

-def test_convert_one_sentence(tmp_path):
+def write_test_dataset(tmp_path, texts, datasets):
    ud_path = tmp_path / "ud"
    input_path = ud_path / "UD_English-EWT"
    output_path = tmp_path / "data" / "lemma_classifier"

    os.makedirs(input_path, exist_ok=True)
-    sample_file = input_path / "en_ewt-ud-train.conllu"
-    with open(sample_file, "w", encoding="utf-8") as fout:
-        fout.write(EWT_ONE_SENTENCE)
+
+    for text, dataset in zip(texts, datasets):
+        sample_file = input_path / ("en_ewt-ud-%s.conllu" % dataset)
+        with open(sample_file, "w", encoding="utf-8") as fout:
+            fout.write(text)

    paths = {"UDBASE": ud_path,
             "LEMMA_CLASSIFIER_DATA_DIR": output_path}

+    return paths
+
+def write_english_test_dataset(tmp_path):
+    texts = (EWT_TRAIN_SENTENCES, EWT_DEV_SENTENCES, EWT_TEST_SENTENCES)
+    datasets = prepare_lemma_classifier.SECTIONS
+    return write_test_dataset(tmp_path, texts, datasets)
+
+def convert_english_dataset(tmp_path):
+    paths = write_english_test_dataset(tmp_path)
+    converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have")
+    assert len(converted_files) == 3
+
+    return converted_files
+
+def test_convert_one_sentence(tmp_path):
+    texts = [EWT_ONE_SENTENCE]
+    datasets = ["train"]
+    paths = write_test_dataset(tmp_path, texts, datasets)
+
    converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have", ["train"])
    assert len(converted_files) == 1
@@ -207,25 +228,7 @@ def test_convert_one_sentence(tmp_path):
    assert upos == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN']

def test_convert_dataset(tmp_path):
-    ud_path = tmp_path / "ud"
-    input_path = ud_path / "UD_English-EWT"
-    output_path = tmp_path / "data" / "lemma_classifier"
-
-    os.makedirs(input_path, exist_ok=True)
-
-    texts = (EWT_TRAIN_SENTENCES, EWT_DEV_SENTENCES, EWT_TEST_SENTENCES)
-    datasets = prepare_lemma_classifier.SECTIONS
-
-    for text, dataset in zip(texts, datasets):
-        sample_file = input_path / ("en_ewt-ud-%s.conllu" % dataset)
-        with open(sample_file, "w", encoding="utf-8") as fout:
-            fout.write(text)
-
-    paths = {"UDBASE": ud_path,
-             "LEMMA_CLASSIFIER_DATA_DIR": output_path}
-
-    converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have")
-    assert len(converted_files) == 3
+    converted_files = convert_english_dataset(tmp_path)

    text_batches, idx_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(converted_files[0], get_counts=True, batch_size=10)

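As a usage note, here is a minimal sketch of how the refactored helpers could serve a future test in this file. The test name test_convert_dev_section and the choice of the "dev" section are hypothetical and not part of this commit; write_test_dataset, EWT_DEV_SENTENCES, and the process_treebank call shape are taken from the diff above.

# Hypothetical sketch (not in this commit): reusing the refactored helpers
# to convert a single non-train section without repeating the setup code.
def test_convert_dev_section(tmp_path):
    # write_test_dataset pairs each text with a UD section name and returns
    # the paths dict that process_treebank expects
    paths = write_test_dataset(tmp_path, [EWT_DEV_SENTENCES], ["dev"])
    # limit process_treebank to the "dev" section, mirroring how
    # test_convert_one_sentence passes ["train"]
    converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have", ["dev"])
    assert len(converted_files) == 1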