Add a test of the shuffling & whether it keeps the has_xpos flag together with the data items
AngledLuffa committed Nov 7, 2023
1 parent 35a91bf commit 31b15c1
Showing 1 changed file with 37 additions and 1 deletion.
stanza/tests/pos/test_data.py: 37 additions, 1 deletion
@@ -5,9 +5,11 @@
import os
import pytest

import torch

from stanza.models.common.doc import *
from stanza.models import tagger
from stanza.models.pos.data import Dataset
from stanza.models.pos.data import Dataset, merge_datasets
from stanza.utils.conll import CoNLL

from stanza.tests.pos.test_tagger import TRAIN_DATA, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_FEATS
@@ -68,6 +70,40 @@ def test_no_feats():
    assert data.has_xpos
    assert not data.has_feats

def test_merge_some_xpos():
    # batch size 2 for 3 total elements, so that the shuffle sometimes
    # puts the sentence with no xpos in a block by itself and
    # sometimes puts it in a block with 1 other element
    args = tagger.parse_args(args=['--batch_size', '2'])

    train_docs = [CoNLL.conll2doc(input_str=TRAIN_DATA),
                  CoNLL.conll2doc(input_str=TRAIN_DATA_NO_XPOS)]

    # TODO: maybe refactor the reading code in the main body of the tagger for easier testing
    vocab = Dataset.init_vocab(train_docs, args)
    train_data = [Dataset(i, args, None, vocab=vocab, evaluation=False) for i in train_docs]
    lens = list(len(x) for x in train_data)
    assert lens == [2, 1]
    merged = merge_datasets(train_data)
    assert len(merged) == 3
    train_batches = merged.to_loader(batch_size=args["batch_size"], shuffle=True)
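    # note: shuffle=True means each pass over train_batches below draws a
    # fresh random batch order, which is what lets the no-xpos sentence
    # land in different batches across the 200 iterations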
    it_first = 0
    for _ in range(200):
        for batch_idx, batch in enumerate(iter(train_batches)):
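            # the sentence whose first word is 'It' is the one from
            # TRAIN_DATA_NO_XPOS (see the assertion comment below)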
            if batch.text[-1][0] == 'It':
                if batch_idx == 0:
                    it_first += 1
                # TODO: I would expect this to always be False, unless
                # I have misinterpreted the process for the masking,
                # but there are times it comes back True instead.  I
                # think that is an effect of the has_xpos not being
                # sorted with everything else by length
                print(torch.any(batch.xpos[-1]))

    # check that the sentence w/o xpos is sometimes but not always in the first batch
    assert it_first > 5
    assert it_first < 195

def test_no_augment():
    """
    Test that with no punct removing augmentation, the doc always has punct at the end

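To run just the new test from a stanza checkout (a standard pytest invocation, not part of this commit, and assuming the stanza test resources are set up):

    python -m pytest stanza/tests/pos/test_data.py -k test_merge_some_xpos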