diff --git a/stanza/models/pos/data.py b/stanza/models/pos/data.py
index 567d2605e7..f10aa46bcb 100644
--- a/stanza/models/pos/data.py
+++ b/stanza/models/pos/data.py
@@ -5,6 +5,7 @@
 from collections import namedtuple
 
 from torch.utils.data import DataLoader as DL
+from torch.utils.data import Dataset as DS
 from torch.nn.utils.rnn import pad_sequence
 
 from stanza.models.common.bert_embedding import filter_data
@@ -16,9 +17,96 @@
 
 logger = logging.getLogger('stanza')
 
-DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain text")
+DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain has_upos has_xpos has_feats text")
 DataBatch = namedtuple("DataBatch", "words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text idx")
 
+def merge_datasets(datasets):
+    """Merge multiple Dataset objects into a single indexable dataset"""
+
+    return _ShadowDataset(*datasets)
+
+class _ShadowDataset(DS):
+    def __init__(self, *datasets):
+        self.datasets = datasets
+
+        # precache the lengths of the datasets, cumulated
+        self.__cumulate_lens = []
+        self.__len = 0
+        for i in self.datasets:
+            self.__cumulate_lens.append(self.__len)
+            self.__len += len(i)
+
+
+    def to_loader(self, **kwargs):
+        """Converts self to a DataLoader"""
+
+        return DL(self, collate_fn=self.__collate_fn, **kwargs)
+
+    def __indx2loader(self, index):
+        """Search through the cumulated lengths to get the id of the right dataset"""
+
+        # we iterate through the cumulative lengths in *REVERSE* to find the largest start offset <= index, i.e. the dataset that actually contains this index
+        for indx, i in reversed(list(enumerate(self.__cumulate_lens))):
+            if index >= i:
+                return indx, index-i
+
+    def __getitem__(self, key):
+        """Return the sample at a global index, together with that index"""
+
+        dataset_num, indx = self.__indx2loader(key)
+        return self.datasets[dataset_num][indx], key
+
+    def __len__(self):
+        return self.__len
+
+    @staticmethod
+    def __collate_fn(data):
+        """Function used by DataLoader to pack data"""
+        (data, idx) = zip(*data)
+        (words, wordchars, upos, xpos, ufeats, pretrained,
+         has_upos, has_xpos, has_feats, text) = zip(*data)
+
+        # collate_fn is given a list of length batch size
+        batch_size = len(data)
+
+        # sort sentences by lens for easy RNN operations; sort the has_* flags too so they stay aligned with the sorted rows
+        lens = [torch.sum(x != PAD_ID) for x in words]
+        (words, wordchars, upos, xpos, ufeats, pretrained,
+         has_upos, has_xpos, has_feats, text), orig_idx = sort_all((words, wordchars, upos, xpos, ufeats, pretrained,
+                                                                    has_upos, has_xpos, has_feats, text), lens)
+        lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN
+
+        # combine all words into one large list, and sort for easy charRNN ops
+        wordchars = [w for sent in wordchars for w in sent]
+        word_lens = [len(x) for x in wordchars]
+        (wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
+        word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN
+
+        # We now pad everything
+        words = pad_sequence(words, True, PAD_ID)
+        upos = pad_sequence(upos, True, PAD_ID)
+        xpos = pad_sequence(xpos, True, PAD_ID)
+        ufeats = pad_sequence(ufeats, True, PAD_ID)
+        pretrained = pad_sequence(pretrained, True, PAD_ID)
+        wordchars = get_long_tensor(wordchars, len(word_lens))
+
+        # and get boolean mask tensors for upos, xpos, feats
+        upos_mask = torch.tensor(has_upos)
+        xpos_mask = torch.tensor(has_xpos)
+        feats_mask = torch.tensor(has_feats)
+
+        # mask out the elements for which upos/xpos/feats isn't available
+        upos[~upos_mask] = PAD_ID
+        xpos[~xpos_mask] = PAD_ID
+        ufeats[~feats_mask] = PAD_ID
+
+        # and finally create masks for the padding indices
+        words_mask = torch.eq(words, PAD_ID)
+        wordchars_mask = torch.eq(wordchars, PAD_ID)
+
+        return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
+                         pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)
+
 class Dataset:
     def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs):
         self.args = args
@@ -87,6 +175,9 @@ def preprocess(self, data, vocab, pretrain_vocab, args):
                 pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])]
                             if pretrain_vocab is not None
                             else [[PAD_ID] * len(sent)]),
+                has_upos = self.has_upos,
+                has_xpos = self.has_xpos,
+                has_feats = self.has_feats,
                 text = [w[0] for w in sent]
             )
             processed.append(processed_sent)
@@ -166,9 +257,9 @@ def __getitem__(self, key):
         # TODO: only store single lists per data entry?
         words = torch.tensor(sample.word[0])
         # convert the rest to tensors
-        upos = torch.tensor(sample.upos[0]) if self.has_upos else None
-        xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None
-        ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None
+        upos = torch.tensor(sample.upos[0])
+        xpos = torch.tensor(sample.xpos[0])
+        ufeats = torch.tensor(sample.feats[0])
         pretrained = torch.tensor(sample.pretrain[0])
 
         # and deal with char & raw_text
@@ -205,7 +296,8 @@ def __getitem__(self, key):
         # get each character from the input sentnece
         # chars = [w for sent in char for w in sent]
 
-        return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key
+        return DataSample(words, char, upos, xpos, ufeats, pretrained,
+                          self.has_upos, self.has_xpos, self.has_feats, raw_text)
 
     def __iter__(self):
         for i in range(self.__len__()):
@@ -214,55 +306,7 @@ def __iter__(self):
 
     def to_loader(self, **kwargs):
         """Converts self to a DataLoader"""
-        return DL(self,
-                  collate_fn=Dataset.__collate_fn,
-                  **kwargs)
-
-    @staticmethod
-    def __collate_fn(data):
-        """Function used by DataLoader to pack data"""
-        (data, idx) = zip(*data)
-        (words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data)
-
-        # collate_fn is given a list of length batch size
-        batch_size = len(data)
-
-        # sort sentences by lens for easy RNN operations
-        lens = [torch.sum(x != PAD_ID) for x in words]
-        (words, wordchars, upos, xpos,
-         ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos,
-                                                         ufeats, pretrained, text), lens)
-        lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN
-
-        # combine all words into one large list, and sort for easy charRNN ops
-        wordchars = [w for sent in wordchars for w in sent]
-        word_lens = [len(x) for x in wordchars]
-        (wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
-        word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN
-
-        # We now pad everything
-        words = pad_sequence(words, True, PAD_ID)
-        if None not in upos:
-            upos = pad_sequence(upos, True, PAD_ID)
-        else:
-            upos = None
-        if None not in xpos:
-            xpos = pad_sequence(xpos, True, PAD_ID)
-        else:
-            xpos = None
-        if None not in ufeats:
-            ufeats = pad_sequence(ufeats, True, PAD_ID)
-        else:
-            ufeats = None
-        pretrained = pad_sequence(pretrained, True, PAD_ID)
-        wordchars = get_long_tensor(wordchars, len(word_lens))
-
-        # and finally create masks for the padding indices
-        words_mask = torch.eq(words, PAD_ID)
-        wordchars_mask = torch.eq(wordchars, PAD_ID)
-
-        return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
-                         pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)
+        return _ShadowDataset(self).to_loader(**kwargs)
 
     @staticmethod
     def load_doc(doc):
diff --git a/stanza/models/tagger.py b/stanza/models/tagger.py
index 9ede1b19c3..f698f26d3e 100644
--- a/stanza/models/tagger.py
+++ b/stanza/models/tagger.py
@@ -19,7 +19,7 @@
 from torch import nn, optim
 
 import stanza.models.pos.data as data
-from stanza.models.pos.data import Dataset
+from stanza.models.pos.data import Dataset, merge_datasets
 from stanza.models.pos.trainer import Trainer
 from stanza.models.pos import scorer
 from stanza.models.common import utils
@@ -210,27 +210,18 @@ def train(args):
     # therefore, we create seperate datasets and loaders for each input training file,
     # which will ensure the system be able to see batches with both upos available
     # and upos unavailable depending on what the availability in the file is.
+
     vocab = Dataset.init_vocab(train_docs, args)
     train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False)
                   for i in train_docs]
-    # here we make sure the model will learn to output _ for empty columns
-    # if *any* dataset has data for the upos, xpos, or feature column,
-    # we consider that data enough to train the model on that column
-    # otherwise, we want to train the model to always output blanks
-    if not any(td.has_upos for td in train_data):
-        for td in train_data:
-            td.has_upos = True
-    if not any(td.has_xpos for td in train_data):
-        for td in train_data:
-            td.has_xpos = True
-    if not any(td.has_feats for td in train_data):
-        for td in train_data:
-            td.has_feats = True
+
     # calculate the batches
-    train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
-                     for i in train_data]
+    train_batches = merge_datasets(train_data).to_loader(batch_size=args["batch_size"],
+                                                         shuffle=True)
+
     dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
-    dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
+    dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab,
+                       evaluation=True, sort_during_eval=True)
     dev_batch = dev_data.to_loader(batch_size=args["batch_size"])
 
     eval_type = get_eval_type(dev_data)
@@ -284,9 +275,7 @@ def train(args):
         # such as if XPOS or UPOS are missing from a training file,
         # as we shuffle all of those batches together
         # the downside being that it loses the efficiency benefit of the pytorch dataloader
-        all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
-        random.shuffle(all_train_batches)
-        for i, batch in enumerate(all_train_batches):
+        for i, batch in enumerate(iter(train_batches)):
             start_time = time.time()
             global_step += 1
             loss = trainer.update(batch, eval=False) # update step
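Note on the indexing scheme in the patch above: `_ShadowDataset` treats its member datasets as one concatenated dataset by recording each dataset's cumulative start offset and, in `__indx2loader`, scanning those offsets in reverse to find the dataset that contains a given global index. The sketch below only illustrates that offset bookkeeping; it uses plain Python lists in place of stanza `Dataset` objects, and the helper names (`build_offsets`, `locate`) are illustrative, not part of the patch.

```python
# Minimal sketch of the cumulative-offset lookup used by _ShadowDataset.
# Plain lists stand in for stanza Dataset objects; build_offsets/locate are
# hypothetical names used only for illustration.

def build_offsets(datasets):
    """Return the global start index of each dataset plus the total length."""
    offsets, total = [], 0
    for ds in datasets:
        offsets.append(total)
        total += len(ds)
    return offsets, total

def locate(offsets, index):
    """Map a global index to (dataset_id, local_index).

    Scanning the offsets in reverse picks the largest start offset <= index,
    i.e. the dataset that actually contains the index.
    """
    for ds_id, start in reversed(list(enumerate(offsets))):
        if index >= start:
            return ds_id, index - start
    raise IndexError(index)

if __name__ == "__main__":
    datasets = [["a0", "a1", "a2"], ["b0", "b1"]]
    offsets, total = build_offsets(datasets)   # offsets == [0, 3], total == 5
    assert [locate(offsets, i) for i in range(total)] == [
        (0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
```

With the real classes, this same pattern is what lets `merge_datasets(train_data).to_loader(batch_size=..., shuffle=True)` serve shuffled batches that mix sentences from different training files, which is why each `DataSample` now carries its own `has_upos`/`has_xpos`/`has_feats` flags and the collate function masks the corresponding rows per sample.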