Make a dataset shuffler which can shuffle per batch without reading all of the batches into memory first. Saves memory on some of the excessively large datasets, such as DE_HDT
AngledLuffa committed Nov 8, 2023
1 parent b1e2991 commit 9721b52
Showing 2 changed files with 45 additions and 11 deletions.
42 changes: 42 additions & 0 deletions stanza/models/pos/data.py
@@ -280,4 +280,46 @@ def resolve_none(data):
                    data[sent_idx][tok_idx][feat_idx] = '_'
    return data

class ShuffledDataset:
    """A wrapper around one or more datasets which shuffles the data in batch_size chunks

    This means that if multiple datasets are passed in, the batches
    from each dataset are shuffled together, with each batch drawn
    entirely from a single dataset.

    The main use case is that in the tagger, batches from different
    datasets can have different properties, such as having or not
    having UPOS tags. We found that it is actually somewhat tricky to
    make the model's loss function (in model.py) properly handle
    batches which mix items with and without such a property, whereas
    keeping each batch within a single dataset makes it much easier
    to process.

    The mechanism for the shuffling is that the iterator first builds
    a list with one entry per batch from each dataset, each entry
    recording the index of the dataset the batch comes from, then
    shuffles that list. An alternative would be a weighted
    randomization approach, but this is very simple and the memory
    requirements are not too onerous.

    Note that the batch indices are wasteful in the case of only one
    underlying dataset, which is actually the most common use case,
    but the overhead is small enough that it probably isn't worth
    special casing the one dataset version.
    """
    def __init__(self, datasets, batch_size):
        self.batch_size = batch_size
        self.datasets = datasets
        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]

    def __iter__(self):
        iterators = [iter(x) for x in self.loaders]
        lengths = [len(x) for x in self.loaders]
        indices = [[x] * y for x, y in enumerate(lengths)]
        indices = [idx for inner in indices for idx in inner]
        random.shuffle(indices)

        for idx in indices:
            yield next(iterators[idx])

    def __len__(self):
        return sum(len(x) for x in self.datasets)
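
For illustration, here is a minimal, self-contained sketch of how the new class interleaves batches. FakeDataset is a hypothetical stand-in, not part of stanza: it only mimics the slice of the Dataset API that ShuffledDataset touches, namely a to_loader() method returning an iterable of batches. The import path is the one used by tagger.py below, and the sizes are arbitrary.

import random
from stanza.models.pos.data import ShuffledDataset

class FakeDataset:
    """Hypothetical stand-in: each 'batch' is just a list of labeled strings"""
    def __init__(self, name, num_items):
        self.items = ["%s_%d" % (name, i) for i in range(num_items)]

    def to_loader(self, batch_size, shuffle=True):
        items = list(self.items)
        if shuffle:
            random.shuffle(items)
        # chop the items into batch_size chunks, like a DataLoader would
        return [items[i:i+batch_size] for i in range(0, len(items), batch_size)]

    def __len__(self):
        return len(self.items)

shuffled = ShuffledDataset([FakeDataset("upos", 10), FakeDataset("noupos", 6)], batch_size=4)
for batch in shuffled:
    # each batch comes entirely from one of the fake datasets;
    # the order of batches across datasets is random
    print(batch)
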
14 changes: 3 additions & 11 deletions stanza/models/tagger.py
@@ -19,7 +19,7 @@
from torch import nn, optim

import stanza.models.pos.data as data
-from stanza.models.pos.data import Dataset
+from stanza.models.pos.data import Dataset, ShuffledDataset
from stanza.models.pos.trainer import Trainer
from stanza.models.pos import scorer
from stanza.models.common import utils
@@ -205,8 +205,7 @@ def load_training_data(args, pretrain):
        for td in train_data:
            td.has_feats = True
    # calculate the batches
-    train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
-                     for i in train_data]
+    train_batches = ShuffledDataset(train_data, args["batch_size"])
    return vocab, train_data, train_batches

def train(args):
@@ -284,14 +283,7 @@ def train(args):
    trainer.model.log_norms()
    while True:
        do_break = False
-        # we now merge all train batches together into one giant list
-        # this allows us to mix batches which have or don't have individual training columns,
-        # such as if XPOS or UPOS are missing from a training file,
-        # as we shuffle all of those batches together
-        # the downside being that it loses the efficiency benefit of the pytorch dataloader
-        all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
-        random.shuffle(all_train_batches)
-        for i, batch in enumerate(all_train_batches):
+        for i, batch in enumerate(train_batches):
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
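
The ShuffledDataset docstring above mentions a weighted-randomization alternative to materializing and shuffling the full index list. Purely as an illustration of that idea, and not code from this commit, here is a sketch that draws the next loader index weighted by how many batches each loader still has remaining:

import random

def weighted_order(remaining):
    """Yield loader indices one at a time, weighted by remaining batch counts"""
    remaining = list(remaining)
    while any(remaining):
        idx = random.choices(range(len(remaining)), weights=remaining, k=1)[0]
        remaining[idx] -= 1
        yield idx

print(list(weighted_order([3, 2])))   # e.g. [0, 1, 0, 0, 1]
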
