Make a dataset shuffler which can shuffle per batch without reading all of the batches into memory first. Saves memory on some of the excessively large datasets, such as DE_HDT
AngledLuffa committed Nov 7, 2023
1 parent 82f7872 commit e4c2273
Showing 2 changed files with 19 additions and 11 deletions.
16 changes: 16 additions & 0 deletions stanza/models/pos/data.py
@@ -280,4 +280,20 @@ def resolve_none(data):
                    data[sent_idx][tok_idx][feat_idx] = '_'
    return data

class ShuffledDataset:
    def __init__(self, datasets, batch_size):
        self.batch_size = batch_size
        self.datasets = datasets
        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]

    def __iter__(self):
        iterators = [iter(x) for x in self.loaders]

Jemoka (Member) commented on Nov 7, 2023:

This feels like exactly the same loading per iteration that the previous loop does, I believe. I'm worried that doing this won't solve the OOM issues we saw in German.

AngledLuffa (Author, Collaborator) replied on Nov 7, 2023:

Ah, but this is just one integer per batch, as opposed to the entire batch.

        lengths = [len(x) for x in self.loaders]
        # one integer per batch: loader x contributes len(loader x) copies of x
        indices = [[x] * y for x, y in enumerate(lengths)]
        indices = [idx for inner in indices for idx in inner]
        random.shuffle(indices)

        for idx in indices:
            yield next(iterators[idx])

    def __len__(self):
        return sum(len(x) for x in self.datasets)
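To make the exchange above concrete, here is a minimal self-contained sketch of the same idea, with plain lists standing in for the real `to_loader` output (`loader_a` and `loader_b` are illustrative, not stanza API): the only thing materialized and shuffled up front is a flat list of integers, one per batch, while the batches themselves are pulled lazily from the per-loader iterators.

```python
import random

# toy "loaders" standing in for the output of to_loader(); each yields batches
loader_a = [["a1"], ["a2"], ["a3"]]   # 3 batches
loader_b = [["b1"], ["b2"]]           # 2 batches
loaders = [loader_a, loader_b]

iterators = [iter(x) for x in loaders]
lengths = [len(x) for x in loaders]   # [3, 2]
# one integer per batch, so the shuffled list is a permutation of [0, 0, 0, 1, 1]
indices = [idx for i, n in enumerate(lengths) for idx in [i] * n]
random.shuffle(indices)

# only the integer list was materialized up front; batches come out lazily
for idx in indices:
    print(next(iterators[idx]))
```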
14 changes: 3 additions & 11 deletions stanza/models/tagger.py
@@ -19,7 +19,7 @@
from torch import nn, optim

import stanza.models.pos.data as data
from stanza.models.pos.data import Dataset
from stanza.models.pos.data import Dataset, ShuffledDataset
from stanza.models.pos.trainer import Trainer
from stanza.models.pos import scorer
from stanza.models.common import utils
@@ -205,8 +205,7 @@ def load_training_data(args, pretrain):
    for td in train_data:
        td.has_feats = True
    # calculate the batches
    train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
                     for i in train_data]
    train_batches = ShuffledDataset(train_data, args["batch_size"])
    return vocab, train_data, train_batches

def train(args):
@@ -284,14 +283,7 @@ def train(args):
    trainer.model.log_norms()
    while True:
        do_break = False
        # we now merge all train batches together into one giant list
        # this allows us to mix batches which have or don't have individual training columns,
        # such as if XPOS or UPOS are missing from a training file,
        # as we shuffle all of those batches together
        # the downside being that it loses the efficiency benefit of the pytorch dataloader
        all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
        random.shuffle(all_train_batches)
        for i, batch in enumerate(all_train_batches):
        for i, batch in enumerate(train_batches):
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
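For contrast, here is a hedged, self-contained sketch of the change this hunk makes, with plain lists again standing in for the real DataLoaders (`ShuffledBatches` is an illustrative stand-in for the committed `ShuffledDataset`, not stanza API): the removed code materialized and shuffled every batch up front, while the replacement shuffles only integers and reshuffles on every epoch, since the `for` loop re-invokes `__iter__` each time around.

```python
import random

# the same toy loaders as above; the real code uses pytorch DataLoaders
loaders = [[["a1"], ["a2"], ["a3"]], [["b1"], ["b2"]]]

# old approach: every batch is materialized before training starts, so peak
# memory grows with the whole treebank (the OOM risk on DE_HDT)
all_train_batches = [x for loader in loaders for x in iter(loader)]
random.shuffle(all_train_batches)

# new approach: a re-iterable wrapper; every pass rebuilds the iterators and
# reshuffles the integer index list, pulling batches lazily as it goes
class ShuffledBatches:
    def __init__(self, loaders):
        self.loaders = loaders
    def __iter__(self):
        iterators = [iter(x) for x in self.loaders]
        indices = [i for i, loader in enumerate(self.loaders) for _ in loader]
        random.shuffle(indices)
        for idx in indices:
            yield next(iterators[idx])

train_batches = ShuffledBatches(loaders)
for epoch in range(2):                  # stands in for the `while True` loop
    for i, batch in enumerate(train_batches):
        pass                            # trainer.update(batch, eval=False)
```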
