From 0ddd7056cce3299989851d91a0458f0d0504f046 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 28 Nov 2024 21:02:06 -0800 Subject: [PATCH] Convert the MWT training to use a pytorch dataloader with shuffling In theory this should also provide some cpu/gpu parallelism at test time, although we haven't done anything to ensure it is using multiprocessing Fix the max_steps count by counting batches, not samples --- stanza/models/mwt/data.py | 85 ++++++++++++------- stanza/models/mwt_expander.py | 10 ++- stanza/pipeline/mwt_processor.py | 2 +- stanza/tests/mwt/test_character_classifier.py | 2 +- 4 files changed, 61 insertions(+), 38 deletions(-) diff --git a/stanza/models/mwt/data.py b/stanza/models/mwt/data.py index 1c67eea94a..5c6f73f855 100644 --- a/stanza/models/mwt/data.py +++ b/stanza/models/mwt/data.py @@ -1,10 +1,12 @@ import random import numpy as np import os -from collections import Counter +from collections import Counter, namedtuple import logging import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader as DL import stanza.models.common.seq2seq_constant as constant from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all @@ -14,6 +16,9 @@ logger = logging.getLogger('stanza') +DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text") +DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx") + # enforce that the MWT splitter knows about a couple different alternate apostrophes # including covering some potential " typos # setting the augmentation to a very low value should be enough to teach it @@ -22,7 +27,6 @@ # 0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07 APOS = ('"', "'", 'ʼ', 'ˊ', '՚', 'ߴ', '’', ''') -# TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS class DataLoader: def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False): self.batch_size = batch_size @@ -56,12 +60,9 @@ def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_u indices = list(range(len(data))) random.shuffle(indices) data = [data[i] for i in indices] - self.num_examples = len(data) - # chunk into batches - data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)] self.data = data - logger.debug("{} batches created.".format(len(data))) + self.num_examples = len(data) def init_vocab(self, data): assert self.evaluation == False # for eval vocab must exist @@ -77,17 +78,14 @@ def maybe_augment_apos(self, datum): break return datum - - def process(self, data): - processed = [] - for d in data: - if not self.evaluation and self.augment_apos > 0: - d = self.maybe_augment_apos(d) - src = list(d[0]) - src = [constant.SOS] + src + [constant.EOS] - tgt_in, tgt_out = self.prepare_target(self.vocab, d) - src = self.vocab.map(src) - processed += [[src, tgt_in, tgt_out, d[0]]] + def process(self, sample): + if not self.evaluation and self.augment_apos > 0: + sample = self.maybe_augment_apos(sample) + src = list(sample[0]) + src = [constant.SOS] + src + [constant.EOS] + tgt_in, tgt_out = self.prepare_target(self.vocab, sample) + src = self.vocab.map(src) + processed = [src, tgt_in, tgt_out, sample[0]] return processed def prepare_target(self, vocab, datum): @@ -108,31 +106,54 @@ def __getitem__(self, key): raise TypeError if key < 0 or key >= len(self.data): raise IndexError - batch = self.data[key] - batch = self.process(batch) - batch_size = len(batch) - batch = list(zip(*batch)) - assert len(batch) == 4 - - # sort all fields by lens for easy RNN operations - lens = [len(x) for x in batch[0]] - batch, orig_idx = sort_all(batch, lens) + sample = self.data[key] + sample = self.process(sample) + assert len(sample) == 4 + + src = torch.tensor(sample[0]) + tgt_in = torch.tensor(sample[1]) + tgt_out = torch.tensor(sample[2]) + orig_text = sample[3] + result = DataSample(src, tgt_in, tgt_out, orig_text), key + return result + + @staticmethod + def __collate_fn(data): + (data, idx) = zip(*data) + (src, tgt_in, tgt_out, orig_text) = zip(*data) + + # collate_fn is given a list of length batch size + batch_size = len(data) + + # need to sort by length of src to properly handle + # the batching in the model itself + lens = [len(x) for x in src] + (src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens) + lens = [len(x) for x in src] # convert to tensors - src = batch[0] - src = get_long_tensor(src, batch_size) + src = pad_sequence(src, True, constant.PAD_ID) src_mask = torch.eq(src, constant.PAD_ID) - tgt_in = get_long_tensor(batch[1], batch_size) - tgt_out = get_long_tensor(batch[2], batch_size) - orig_text = batch[3] + tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID) + tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID) assert tgt_in.size(1) == tgt_out.size(1), \ "Target input and output sequence sizes do not match." - return (src, src_mask, tgt_in, tgt_out, orig_text, orig_idx) + return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx) def __iter__(self): for i in range(self.__len__()): yield self.__getitem__(i) + def to_loader(self): + """Converts self to a DataLoader """ + + batch_size = self.batch_size + shuffle = not self.evaluation + return DL(self, + collate_fn=self.__collate_fn, + batch_size=batch_size, + shuffle=shuffle) + def load_doc(self, doc, evaluation=False): data = doc.get_mwt_expansions(evaluation) if evaluation: data = [[e] for e in data] diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py index 87eadbb8a0..29a67ae336 100644 --- a/stanza/models/mwt_expander.py +++ b/stanza/models/mwt_expander.py @@ -16,6 +16,7 @@ from datetime import datetime import argparse import logging +import math import numpy as np import random import torch @@ -184,7 +185,8 @@ def train(args): # train a seq2seq model logger.info("Training seq2seq-based MWT expander...") global_step = 0 - max_steps = len(train_batch) * args['num_epoch'] + steps_per_epoch = math.ceil(len(train_batch) / args['batch_size']) + max_steps = steps_per_epoch * args['num_epoch'] dev_score_history = [] best_dev_preds = [] current_lr = args['lr'] @@ -201,7 +203,7 @@ def train(args): # start training for epoch in range(1, args['num_epoch']+1): train_loss = 0 - for i, batch in enumerate(train_batch): + for i, batch in enumerate(train_batch.to_loader()): start_time = time.time() global_step += 1 loss = trainer.update(batch, eval=False) # update step @@ -218,7 +220,7 @@ def train(args): # eval on dev logger.info("Evaluating on dev set...") dev_preds = [] - for i, batch in enumerate(dev_batch): + for i, batch in enumerate(dev_batch.to_loader()): preds = trainer.predict(batch) dev_preds += preds if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False): @@ -296,7 +298,7 @@ def evaluate(args): else: logger.info("Running the seq2seq model...") preds = [] - for i, b in enumerate(batch): + for i, b in enumerate(batch.to_loader()): preds += trainer.predict(b) if loaded_args.get('ensemble_dict', False): diff --git a/stanza/pipeline/mwt_processor.py b/stanza/pipeline/mwt_processor.py index 50b83bfd81..6aaf1b3112 100644 --- a/stanza/pipeline/mwt_processor.py +++ b/stanza/pipeline/mwt_processor.py @@ -37,7 +37,7 @@ def process(self, document): else: with torch.no_grad(): preds = [] - for i, b in enumerate(batch): + for i, b in enumerate(batch.to_loader()): preds += self.trainer.predict(b, never_decode_unk=True, vocab=batch.vocab) if self.config.get('ensemble_dict', False): diff --git a/stanza/tests/mwt/test_character_classifier.py b/stanza/tests/mwt/test_character_classifier.py index 2ae0cc31b3..1d8a699a20 100644 --- a/stanza/tests/mwt/test_character_classifier.py +++ b/stanza/tests/mwt/test_character_classifier.py @@ -81,7 +81,7 @@ def test_train(tmp_path): doc = CoNLL.conll2doc(input_str=ENG_DEV) dataloader = DataLoader(doc, 10, model.args, vocab=model.vocab, evaluation=True, expand_unk_vocab=True) preds = [] - for i, batch in enumerate(dataloader): + for i, batch in enumerate(dataloader.to_loader()): assert i == 0 # there should only be one batch preds += model.predict(batch, never_decode_unk=True, vocab=dataloader.vocab) assert len(preds) == 1