data mixtures
Jemoka committed Nov 6, 2023
1 parent d9fc52d commit 35a91bf
Showing 2 changed files with 107 additions and 74 deletions.
152 changes: 98 additions & 54 deletions stanza/models/pos/data.py
@@ -5,6 +5,7 @@
from collections import namedtuple

from torch.utils.data import DataLoader as DL
from torch.utils.data import Dataset as DS
from torch.nn.utils.rnn import pad_sequence

from stanza.models.common.bert_embedding import filter_data
@@ -16,9 +17,96 @@

logger = logging.getLogger('stanza')

DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain text")
DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain has_upos has_xpos has_feats text")
DataBatch = namedtuple("DataBatch", "words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text idx")

def merge_datasets(datasets):
"""Merge multiple datasets"""

return _ShadowDataset(*datasets)

class _ShadowDataset(DS):
def __init__(self, *datasets):
self.datasets = datasets

# precache the cumulative starting offsets of the datasets
self.__cumulate_lens = []
self.__len = 0
for i in self.datasets:
self.__cumulate_lens.append(self.__len)
self.__len += len(i)


def to_loader(self, **kwargs):
"""Converts self to a DataLoader """

return DL(self, collate_fn=self.__collate_fn, **kwargs)

def __indx2loader(self, index):
"""Search through the loader lengths to get the id to the right dataset"""

# we iterate through the cumulative lengths in *REVERSE* because the last
# cumulative start that is <= index belongs to the dataset containing that index
for indx, i in reversed(list(enumerate(self.__cumulate_lens))):
if index >= i:
return indx, index-i
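# e.g. (hypothetical sizes): for datasets of lengths 100, 50, and 75,
# __cumulate_lens == [0, 100, 150]; a global index of 130 first satisfies
# index >= 100 when scanning in reverse, so it maps to dataset 1,
# local index 130 - 100 = 30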

def __getitem__(self, key):
"""Get a single key for whether or not upos/xpos etc. is avaliable"""

dataset_num, indx = self.__indx2loader(key)
return self.datasets[dataset_num][indx], key

def __len__(self):
return self.__len

@staticmethod
def __collate_fn(data):
"""Function used by DataLoader to pack data"""
(data, idx) = zip(*data)
(words, wordchars, upos, xpos, ufeats, pretrained,
has_upos, has_xpos, has_feats, text) = zip(*data)

# collate_fn is given a list of length batch size
batch_size = len(data)

# sort sentences by lens for easy RNN operations
lens = [torch.sum(x != PAD_ID) for x in words]
# the has_* flags are sorted together with everything else so that the
# availability masks below stay aligned with the sorted sentences
(words, wordchars, upos, xpos, ufeats, pretrained,
has_upos, has_xpos, has_feats, text), orig_idx = sort_all((words, wordchars, upos, xpos, ufeats, pretrained,
has_upos, has_xpos, has_feats, text), lens)
lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN

# combine all words into one large list, and sort for easy charRNN ops
wordchars = [w for sent in wordchars for w in sent]
word_lens = [len(x) for x in wordchars]
(wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN

# We now pad everything
words = pad_sequence(words, True, PAD_ID)
upos = pad_sequence(upos, True, PAD_ID)
xpos = pad_sequence(xpos, True, PAD_ID)
ufeats = pad_sequence(ufeats, True, PAD_ID)
pretrained = pad_sequence(pretrained, True, PAD_ID)
wordchars = get_long_tensor(wordchars, len(word_lens))

# and get boolean mask tensors for upos, xpos, feats
upos_mask = torch.tensor(has_upos)
xpos_mask = torch.tensor(has_xpos)
feats_mask = torch.tensor(has_feats)

# mask out the elements for which upos/xpos/feats isn't available
upos[~upos_mask] = PAD_ID
xpos[~xpos_mask] = PAD_ID
ufeats[~feats_mask] = PAD_ID
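# note: each mask has shape [batch] (one boolean per sentence), so the
# indexing above selects whole rows and overwrites every position of a
# masked-out sentence with PAD_ID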

# and finally create masks for the padding indices
words_mask = torch.eq(words, PAD_ID)
wordchars_mask = torch.eq(wordchars, PAD_ID)

return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)

class Dataset:
def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs):
self.args = args
@@ -87,6 +175,9 @@ def preprocess(self, data, vocab, pretrain_vocab, args):
pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])]
if pretrain_vocab is not None
else [[PAD_ID] * len(sent)]),
has_upos = self.has_upos,
has_xpos = self.has_xpos,
has_feats = self.has_feats,
text = [w[0] for w in sent]
)
processed.append(processed_sent)
@@ -166,9 +257,9 @@ def __getitem__(self, key):
# TODO: only store single lists per data entry?
words = torch.tensor(sample.word[0])
# convert the rest to tensors
upos = torch.tensor(sample.upos[0]) if self.has_upos else None
xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None
ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None
upos = torch.tensor(sample.upos[0])
xpos = torch.tensor(sample.xpos[0])
ufeats = torch.tensor(sample.feats[0])
pretrained = torch.tensor(sample.pretrain[0])

# and deal with char & raw_text
@@ -205,7 +296,8 @@ def __getitem__(self, key):
# get each character from the input sentence
# chars = [w for sent in char for w in sent]

return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key
return DataSample(words, char, upos, xpos, ufeats, pretrained,
self.has_upos, self.has_xpos, self.has_feats, raw_text)

def __iter__(self):
for i in range(self.__len__()):
@@ -214,55 +306,7 @@ def __iter__(self):
def to_loader(self, **kwargs):
"""Converts self to a DataLoader """

return DL(self,
collate_fn=Dataset.__collate_fn,
**kwargs)

@staticmethod
def __collate_fn(data):
"""Function used by DataLoader to pack data"""
(data, idx) = zip(*data)
(words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data)

# collate_fn is given a list of length batch size
batch_size = len(data)

# sort sentences by lens for easy RNN operations
lens = [torch.sum(x != PAD_ID) for x in words]
(words, wordchars, upos, xpos,
ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos,
ufeats, pretrained, text), lens)
lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN

# combine all words into one large list, and sort for easy charRNN ops
wordchars = [w for sent in wordchars for w in sent]
word_lens = [len(x) for x in wordchars]
(wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN

# We now pad everything
words = pad_sequence(words, True, PAD_ID)
if None not in upos:
upos = pad_sequence(upos, True, PAD_ID)
else:
upos = None
if None not in xpos:
xpos = pad_sequence(xpos, True, PAD_ID)
else:
xpos = None
if None not in ufeats:
ufeats = pad_sequence(ufeats, True, PAD_ID)
else:
ufeats = None
pretrained = pad_sequence(pretrained, True, PAD_ID)
wordchars = get_long_tensor(wordchars, len(word_lens))

# and finally create masks for the padding indices
words_mask = torch.eq(words, PAD_ID)
wordchars_mask = torch.eq(wordchars, PAD_ID)

return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)
return _ShadowDataset(self).to_loader(**kwargs)

@staticmethod
def load_doc(doc):
29 changes: 9 additions & 20 deletions stanza/models/tagger.py
@@ -19,7 +19,7 @@
from torch import nn, optim

import stanza.models.pos.data as data
from stanza.models.pos.data import Dataset
from stanza.models.pos.data import Dataset, merge_datasets
from stanza.models.pos.trainer import Trainer
from stanza.models.pos import scorer
from stanza.models.common import utils
@@ -210,27 +210,18 @@ def train(args):
# therefore, we create separate datasets and loaders for each input training file,
# which ensures the system is able to see batches with both upos available
# and upos unavailable, depending on what is available in each file.

vocab = Dataset.init_vocab(train_docs, args)
train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False)
for i in train_docs]
# here we make sure the model will learn to output _ for empty columns
# if *any* dataset has data for the upos, xpos, or feature column,
# we consider that data enough to train the model on that column
# otherwise, we want to train the model to always output blanks
if not any(td.has_upos for td in train_data):
for td in train_data:
td.has_upos = True
if not any(td.has_xpos for td in train_data):
for td in train_data:
td.has_xpos = True
if not any(td.has_feats for td in train_data):
for td in train_data:
td.has_feats = True

# calculate the batches
train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
for i in train_data]
train_batches = merge_datasets(train_data).to_loader(batch_size=args["batch_size"],
shuffle=True)

dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab,
evaluation=True, sort_during_eval=True)
dev_batch = dev_data.to_loader(batch_size=args["batch_size"])

eval_type = get_eval_type(dev_data)
@@ -284,9 +275,7 @@ def train(args):
# such as if XPOS or UPOS are missing from a training file,
# as we shuffle all of those batches together
# the downside being that it loses the efficiency benefit of the pytorch dataloader
all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
random.shuffle(all_train_batches)
for i, batch in enumerate(all_train_batches):
for i, batch in enumerate(iter(train_batches)):
start_time = time.time()
global_step += 1
loss = trainer.update(batch, eval=False) # update step
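
Taken together, a minimal usage sketch of the pieces introduced here (train_docs, args, and pretrain are assumed to be built earlier in the training script, and the batch size is illustrative):

from stanza.models.pos.data import Dataset, merge_datasets

vocab = Dataset.init_vocab(train_docs, args)
train_data = [Dataset(doc, args, pretrain, vocab=vocab, evaluation=False)
              for doc in train_docs]

# a single loader that shuffles sentences from every training file together;
# samples from a file lacking UPOS, XPOS, or feats carry has_* = False, so
# those columns are masked to PAD_ID at collate time
train_batches = merge_datasets(train_data).to_loader(batch_size=32, shuffle=True)

for batch in train_batches:
    ...  # each batch is a DataBatch namedtuple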
