diff --git a/stanza/models/pos/data.py b/stanza/models/pos/data.py
index 567d2605e7..783cfdcf96 100644
--- a/stanza/models/pos/data.py
+++ b/stanza/models/pos/data.py
@@ -280,4 +280,46 @@ def resolve_none(data):
                     data[sent_idx][tok_idx][feat_idx] = '_'
     return data
 
+class ShuffledDataset:
+    """A wrapper around one or more datasets which shuffles the data in batch_size chunks.
+
+    This means that if multiple datasets are passed in, the batches
+    from each dataset are shuffled together, with each batch being
+    composed entirely of members of the same dataset.
+
+    The main use case of this is that in the tagger, there are cases
+    where batches from different datasets will have different
+    properties, such as having or not having UPOS tags.  We found
+    that it is actually somewhat tricky to make the model's loss
+    function (in model.py) properly represent batches with a mix of
+    w/ and w/o a given property, whereas keeping each batch whole
+    makes it a lot easier to process.
+
+    The mechanism for the shuffling is that the iterator first builds
+    a list with one entry per batch from each dataset, tracking the
+    index of the dataset each batch comes from, then shuffles that
+    list.  Another alternative would be a weighted randomization
+    approach, but this is very simple and the memory requirements
+    are not too onerous.
+
+    Note that the batch indices are wasteful in the case of only one
+    underlying dataset, which is actually the most common use case,
+    but the overhead is small enough that it probably isn't worth
+    special casing the one dataset version.
+    """
+    def __init__(self, datasets, batch_size):
+        self.batch_size = batch_size
+        self.datasets = datasets
+        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]
+
+    def __iter__(self):
+        iterators = [iter(x) for x in self.loaders]
+        lengths = [len(x) for x in self.loaders]
+        # one entry per batch, recording which loader the batch comes from
+        indices = [[x] * y for x, y in enumerate(lengths)]
+        indices = [idx for inner in indices for idx in inner]
+        random.shuffle(indices)
+
+        for idx in indices:
+            yield next(iterators[idx])
+
+    def __len__(self):
+        return sum(len(x) for x in self.datasets)
diff --git a/stanza/models/tagger.py b/stanza/models/tagger.py
index 7212538edd..a315154293 100644
--- a/stanza/models/tagger.py
+++ b/stanza/models/tagger.py
@@ -19,7 +19,7 @@
 from torch import nn, optim
 
 import stanza.models.pos.data as data
-from stanza.models.pos.data import Dataset
+from stanza.models.pos.data import Dataset, ShuffledDataset
 from stanza.models.pos.trainer import Trainer
 from stanza.models.pos import scorer
 from stanza.models.common import utils
@@ -205,8 +205,7 @@ def load_training_data(args, pretrain):
         for td in train_data:
             td.has_feats = True
     # calculate the batches
-    train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
-                     for i in train_data]
+    train_batches = ShuffledDataset(train_data, args["batch_size"])
     return vocab, train_data, train_batches
 
 def train(args):
@@ -284,14 +283,7 @@ def train(args):
         trainer.model.log_norms()
     while True:
         do_break = False
-        # we now merge all train batches together into one giant list
-        # this allows us to mix batches which have or don't have individual training columns,
-        # such as if XPOS or UPOS are missing from a training file,
-        # as we shuffle all of those batches together
-        # the downside being that it loses the efficiency benefit of the pytorch dataloader
-        all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
-        random.shuffle(all_train_batches)
-        for i, batch in enumerate(all_train_batches):
+        for i, batch in enumerate(train_batches):
             start_time = time.time()
             global_step += 1
             loss = trainer.update(batch, eval=False) # update step
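
# Illustration only, not part of the patch above: a minimal sketch of how the new
# ShuffledDataset interleaves whole batches from several datasets.  It assumes the
# patch is applied and that wrapped objects only need a to_loader(batch_size=...,
# shuffle=...) method, as stanza.models.pos.data.Dataset provides; the ToyDataset
# below is a hypothetical stand-in used just to make the interleaving visible.
from stanza.models.pos.data import ShuffledDataset

class ToyDataset:
    def __init__(self, name, num_batches):
        self.name = name
        self.num_batches = num_batches

    def to_loader(self, batch_size, shuffle=True):
        # The real Dataset returns a torch DataLoader of batches; a plain list of
        # labeled "batches" is enough to demonstrate the shuffling behavior.
        return [(self.name, i) for i in range(self.num_batches)]

loader = ShuffledDataset([ToyDataset("with_upos", 3), ToyDataset("without_upos", 2)],
                         batch_size=32)
# Every item is a complete batch from exactly one dataset, but the order of batches
# across datasets is randomized, e.g.:
#   [('with_upos', 0), ('without_upos', 0), ('with_upos', 1), ('with_upos', 2), ('without_upos', 1)]
print(list(loader))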