Make a dataset shuffler which can shuffle per batch without reading all of the batches into memory first. Saves memory on some of the excessively large datasets, such as DE_HDT
AngledLuffa committed Nov 9, 2023
1 parent 27983ae commit 748b6e6
Showing 3 changed files with 94 additions and 12 deletions.
43 changes: 43 additions & 0 deletions stanza/models/pos/data.py
@@ -280,4 +280,47 @@ def resolve_none(data):
                    data[sent_idx][tok_idx][feat_idx] = '_'
    return data

class ShuffledDataset:
    """A wrapper around one or more datasets which shuffles the data in batch_size chunks

    This means that if multiple datasets are passed in, the batches
    from each dataset are shuffled together, with each batch being
    entirely members of the same dataset.

    The main use case of this is that in the tagger, there are cases
    where batches from different datasets will have different
    properties, such as having or not having UPOS tags.  We found that
    it is actually somewhat tricky to make the model's loss function
    (in model.py) properly represent batches which mix items with and
    without a property, whereas keeping each batch homogeneous makes
    it a lot easier to process.

    The mechanism for the shuffling is that the iterator first makes a
    list long enough to represent each batch from each dataset,
    tracking the index of the dataset it is coming from, then shuffles
    that list.  Another alternative would be to use a weighted
    randomization approach, but this is very simple and the memory
    requirements are not too onerous.

    Note that the batch indices are wasteful in the case of only one
    underlying dataset, which is actually the most common use case,
    but the overhead is small enough that it probably isn't worth
    special-casing the one-dataset version.
    """
    def __init__(self, datasets, batch_size):
        self.batch_size = batch_size
        self.datasets = datasets
        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]

    def __iter__(self):
        iterators = [iter(x) for x in self.loaders]
        # build one list entry per batch, labeled with the index of
        # the loader that batch will come from, then shuffle the labels
        lengths = [len(x) for x in self.loaders]
        indices = [[x] * y for x, y in enumerate(lengths)]
        indices = [idx for inner in indices for idx in inner]
        random.shuffle(indices)

        for idx in indices:
            yield next(iterators[idx])

    def __len__(self):
        # the number of batches yielded by __iter__,
        # summed over all of the underlying loaders
        return sum(len(x) for x in self.loaders)
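For reference, a minimal usage sketch (not part of the commit; `train1`, `train2`, and `trainer` are hypothetical stand-ins for Dataset objects and a Trainer built as in tagger.py below):

    # each yielded batch comes entirely from one of the two datasets,
    # but the order of batches across datasets is shuffled
    shuffled = ShuffledDataset([train1, train2], batch_size=32)
    for batch in shuffled:
        loss = trainer.update(batch, eval=False)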
14 changes: 3 additions & 11 deletions stanza/models/tagger.py
@@ -19,7 +19,7 @@
from torch import nn, optim

import stanza.models.pos.data as data
-from stanza.models.pos.data import Dataset
+from stanza.models.pos.data import Dataset, ShuffledDataset
from stanza.models.pos.trainer import Trainer
from stanza.models.pos import scorer
from stanza.models.common import utils
@@ -205,8 +205,7 @@ def load_training_data(args, pretrain):
    for td in train_data:
        td.has_feats = True
    # calculate the batches
-    train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
-                     for i in train_data]
+    train_batches = ShuffledDataset(train_data, args["batch_size"])
    return vocab, train_data, train_batches

def train(args):
@@ -284,14 +283,7 @@ def train(args):
        trainer.model.log_norms()
    while True:
        do_break = False
-        # we now merge all train batches together into one giant list
-        # this allows us to mix batches which have or don't have individual training columns,
-        # such as if XPOS or UPOS are missing from a training file,
-        # as we shuffle all of those batches together
-        # the downside being that it loses the efficiency benefit of the pytorch dataloader
-        all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
-        random.shuffle(all_train_batches)
-        for i, batch in enumerate(all_train_batches):
+        for i, batch in enumerate(train_batches):
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
49 changes: 48 additions & 1 deletion stanza/tests/pos/test_data.py
@@ -7,7 +7,7 @@

from stanza.models.common.doc import *
from stanza.models import tagger
-from stanza.models.pos.data import Dataset
+from stanza.models.pos.data import Dataset, ShuffledDataset
from stanza.utils.conll import CoNLL

from stanza.tests.pos.test_tagger import TRAIN_DATA, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_FEATS
@@ -127,3 +127,50 @@ def test_sometimes_augment():
    assert count_without > 5


NO_XPOS_TEMPLATE = """
# text = Noxpos {indexp}
# sent_id = {index}
1 Noxpos noxpos NOUN _ Number=Sing 0 root _ start_char=0|end_char=8|ner=O
2 {indexp} {indexp} NUM _ NumForm=Digit|NumType=Card 1 dep _ start_char=9|end_char=10|ner=S-CARDINAL
""".strip()

YES_XPOS_TEMPLATE = """
# text = Yesxpos {indexp}
# sent_id = {index}
1 Yesxpos yesxpos NOUN NN Number=Sing 0 root _ start_char=0|end_char=8|ner=O
2 {indexp} {indexp} NUM CD NumForm=Digit|NumType=Card 1 dep _ start_char=9|end_char=10|ner=S-CARDINAL
""".strip()

def test_shuffle(tmp_path):
    args = tagger.parse_args(args=["--batch_size", "10", "--shorthand", "en_test", "--augment_nopunct", "0.0"])

    # 100 looked nice but was actually a 1/1000000 chance of the test failing
    # so let's crank it up to 1000 and make it 1/10^58
    no_xpos = [NO_XPOS_TEMPLATE.format(index=idx, indexp=idx+1) for idx in range(1000)]
    no_doc = CoNLL.conll2doc(input_str="\n\n".join(no_xpos))
    no_data = Dataset(no_doc, args, None)

    yes_xpos = [YES_XPOS_TEMPLATE.format(index=idx, indexp=idx+101) for idx in range(1000)]
    yes_doc = CoNLL.conll2doc(input_str="\n\n".join(yes_xpos))
    yes_data = Dataset(yes_doc, args, None)

    shuffled = ShuffledDataset([no_data, yes_data], 10)

    assert sum(1 for _ in shuffled) == 200

    num_with = 0
    num_without = 0
    for batch in shuffled:
        if batch.xpos is not None:
            num_with += 1
        else:
            num_without += 1
        # at the halfway point of the iteration, there should be
        # more than one batch in each category
        # for example, if we had forgotten to shuffle, this assertion would fail
        if num_with + num_without == 100:
            assert num_with > 1
            assert num_without > 1

    assert num_with == 100
    assert num_without == 100
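As a sanity check on the failure-odds comment above, here is a quick back-of-the-envelope computation (not part of the commit; it assumes the halfway assertion fails exactly when the first half of the shuffled batch list contains at most one batch of either type, modeled as a hypergeometric draw):

    from math import comb

    def halfway_failure_probability(per_type):
        # per_type batches of each kind, 2*per_type total; inspect the first half
        total, draws = 2 * per_type, per_type
        denom = comb(total, draws)
        # P(at most 1 batch of a given type among the first `draws`)
        p_low = sum(comb(per_type, k) * comb(per_type, draws - k) for k in (0, 1)) / denom
        return 2 * p_low  # the two types are symmetric

    print(halfway_failure_probability(10))   # ~1e-3 under this model, 10 batches per type
    print(halfway_failure_probability(100))  # ~2e-55 with 100 batches per type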
