Mixing inputs that has/doesn't have upos, xpos, feats #1306

Closed
wants to merge 7 commits
152 changes: 98 additions & 54 deletions stanza/models/pos/data.py
@@ -5,6 +5,7 @@
from collections import namedtuple

from torch.utils.data import DataLoader as DL
from torch.utils.data import Dataset as DS
from torch.nn.utils.rnn import pad_sequence

from stanza.models.common.bert_embedding import filter_data
@@ -16,9 +17,96 @@

logger = logging.getLogger('stanza')

DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain text")
DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain has_upos has_xpos has_feats text")
DataBatch = namedtuple("DataBatch", "words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text idx")

def merge_datasets(datasets):
"""Merge multiple datasets"""

return _ShadowDataset(*datasets)

class _ShadowDataset(DS):
def __init__(self, *datasets):
self.datasets = datasets

# precompute the cumulative starting offsets of the datasets
self.__cumulate_lens = []
self.__len = 0
for i in self.datasets:
self.__cumulate_lens.append(self.__len)
self.__len += len(i)


def to_loader(self, **kwargs):
"""Converts self to a DataLoader """

return DL(self, collate_fn=self.__collate_fn, **kwargs)

def __indx2loader(self, index):
"""Search through the loader lengths to get the id to the right dataset"""

# we iterate through cumulative lengths in *REVERSE* bec
for indx, i in reversed(list(enumerate(self.__cumulate_lens))):
if index >= i:
return indx, index-i

def __getitem__(self, key):
"""Get a single key for whether or not upos/xpos etc. is avaliable"""

dataset_num, indx = self.__indx2loader(key)
return self.datasets[dataset_num][indx], key

def __len__(self):
return self.__len

@staticmethod
def __collate_fn(data):
"""Function used by DataLoader to pack data"""
(data, idx) = zip(*data)
(words, wordchars, upos, xpos, ufeats, pretrained,
has_upos, has_xpos, has_feats, text) = zip(*data)

# collate_fn is given a list of length batch size
batch_size = len(data)

# sort sentences by lens for easy RNN operations
lens = [torch.sum(x != PAD_ID) for x in words]
(words, wordchars, upos, xpos, ufeats, pretrained,
has_upos, has_xpos, has_feats, text), orig_idx = sort_all((words, wordchars, upos, xpos, ufeats, pretrained,
has_upos, has_xpos, has_feats, text), lens)
lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN

# combine all words into one large list, and sort for easy charRNN ops
wordchars = [w for sent in wordchars for w in sent]
word_lens = [len(x) for x in wordchars]
(wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN

# We now pad everything
words = pad_sequence(words, True, PAD_ID)
upos = pad_sequence(upos, True, PAD_ID)
xpos = pad_sequence(xpos, True, PAD_ID)
ufeats = pad_sequence(ufeats, True, PAD_ID)
pretrained = pad_sequence(pretrained, True, PAD_ID)
wordchars = get_long_tensor(wordchars, len(word_lens))

# and get boolean mask tensors for upos, xpos, feats
upos_mask = torch.tensor(has_upos)
xpos_mask = torch.tensor(has_xpos)
feats_mask = torch.tensor(has_feats)

# mask out the elements for which upos/xpos/feats isn't available
upos[~upos_mask] = PAD_ID
xpos[~xpos_mask] = PAD_ID
ufeats[~feats_mask] = PAD_ID

# and finally create masks for the padding indices
words_mask = torch.eq(words, PAD_ID)
wordchars_mask = torch.eq(wordchars, PAD_ID)

return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)

class Dataset:
def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs):
self.args = args
@@ -87,6 +175,9 @@ def preprocess(self, data, vocab, pretrain_vocab, args):
pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])]
if pretrain_vocab is not None
else [[PAD_ID] * len(sent)]),
has_upos = self.has_upos,
has_xpos = self.has_xpos,
has_feats = self.has_feats,
text = [w[0] for w in sent]
)
processed.append(processed_sent)
@@ -166,9 +257,9 @@ def __getitem__(self, key):
# TODO: only store single lists per data entry?
words = torch.tensor(sample.word[0])
# convert the rest to tensors
upos = torch.tensor(sample.upos[0]) if self.has_upos else None
xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None
ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None
upos = torch.tensor(sample.upos[0])
xpos = torch.tensor(sample.xpos[0])
ufeats = torch.tensor(sample.feats[0])
pretrained = torch.tensor(sample.pretrain[0])

# and deal with char & raw_text
@@ -205,7 +296,8 @@ def __getitem__(self, key):
# get each character from the input sentence
# chars = [w for sent in char for w in sent]

return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key
return DataSample(words, char, upos, xpos, ufeats, pretrained,
self.has_upos, self.has_xpos, self.has_feats, raw_text)

def __iter__(self):
for i in range(self.__len__()):
@@ -214,55 +306,7 @@ def __iter__(self):
def to_loader(self, **kwargs):
"""Converts self to a DataLoader """

return DL(self,
collate_fn=Dataset.__collate_fn,
**kwargs)

@staticmethod
def __collate_fn(data):
"""Function used by DataLoader to pack data"""
(data, idx) = zip(*data)
(words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data)

# collate_fn is given a list of length batch size
batch_size = len(data)

# sort sentences by lens for easy RNN operations
lens = [torch.sum(x != PAD_ID) for x in words]
(words, wordchars, upos, xpos,
ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos,
ufeats, pretrained, text), lens)
lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN

# combine all words into one large list, and sort for easy charRNN ops
wordchars = [w for sent in wordchars for w in sent]
word_lens = [len(x) for x in wordchars]
(wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN

# We now pad everything
words = pad_sequence(words, True, PAD_ID)
if None not in upos:
upos = pad_sequence(upos, True, PAD_ID)
else:
upos = None
if None not in xpos:
xpos = pad_sequence(xpos, True, PAD_ID)
else:
xpos = None
if None not in ufeats:
ufeats = pad_sequence(ufeats, True, PAD_ID)
else:
ufeats = None
pretrained = pad_sequence(pretrained, True, PAD_ID)
wordchars = get_long_tensor(wordchars, len(word_lens))

# and finally create masks for the padding indices
words_mask = torch.eq(words, PAD_ID)
wordchars_mask = torch.eq(wordchars, PAD_ID)

return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)
return _ShadowDataset(self).to_loader(**kwargs)
Member Author:
btw: this shouldn't break any previous APIs because the old .to_loader() still works, it just makes a shadow dataset on your behalf with the one Dataset
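
For reference, a minimal sketch of the two call sites under this change, based on the API shown in this diff (train_doc, train_docs, args, pretrain and vocab are placeholders, not code from this PR):

from stanza.models.pos.data import Dataset, merge_datasets

# old single-dataset call site, unchanged: to_loader() now wraps the
# Dataset in a _ShadowDataset on your behalf
train_data = Dataset(train_doc, args, pretrain, vocab=vocab, evaluation=False)
loader = train_data.to_loader(batch_size=args["batch_size"], shuffle=True)

# new multi-dataset call site: merge first, then build one loader over all of them
merged = merge_datasets([Dataset(doc, args, pretrain, vocab=vocab, evaluation=False)
                         for doc in train_docs])
loader = merged.to_loader(batch_size=args["batch_size"], shuffle=True)

Either way the loader yields DataBatch namedtuples, so downstream training code does not need to care which path produced it.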

Collaborator:
This hasn't been publicly released yet, so we should be free to change it however we like

Collaborator:
Still, it's a pretty intuitive solution: the one Dataset version is just the N Datasets version reduced to 1 dataset
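
The reduction is visible in the index arithmetic as well: with a single dataset the only cumulative offset is 0, so the global index passes straight through. A self-contained sketch of that mapping, independent of the Stanza classes and purely illustrative:

def index_to_dataset(index, dataset_lens):
    """Map a global index to (dataset_idx, local_idx) given per-dataset lengths."""
    offsets, total = [], 0
    for n in dataset_lens:
        offsets.append(total)
        total += n
    # scan offsets in reverse: the last dataset starting at or before index owns it
    for ds_idx, start in reversed(list(enumerate(offsets))):
        if index >= start:
            return ds_idx, index - start
    raise IndexError(index)

assert index_to_dataset(0, [2, 1]) == (0, 0)   # first sample of the first dataset
assert index_to_dataset(2, [2, 1]) == (1, 0)   # first sample of the second dataset
assert index_to_dataset(4, [5]) == (0, 4)      # single dataset: identity mapping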


@staticmethod
def load_doc(doc):
14 changes: 7 additions & 7 deletions stanza/models/pos/model.py
@@ -12,7 +12,7 @@
from stanza.models.common.foundation_cache import load_bert, load_charlm
from stanza.models.common.hlstm import HighwayLSTM
from stanza.models.common.dropout import WordDropout
from stanza.models.common.vocab import CompositeVocab
from stanza.models.common.vocab import CompositeVocab, PAD_ID
from stanza.models.common.char_model import CharacterModel

logger = logging.getLogger('stanza')
@@ -218,11 +218,11 @@ def pad(x):

preds = [pad(upos_pred).max(2)[1]]

loss = 0.0
if upos is not None:
upos = pack(upos).data
loss = self.crit(upos_pred.view(-1, upos_pred.size(-1)), upos.view(-1))
else:
loss = 0.0
if not torch.all(upos.eq(PAD_ID)):
loss = self.crit(upos_pred.view(-1, upos_pred.size(-1)), upos.view(-1))

if self.share_hid:
xpos_hid = upos_hid
@@ -245,21 +245,21 @@
xpos_preds = []
for i in range(len(self.vocab['xpos'])):
xpos_pred = clffunc(self.xpos_clf[i], xpos_hid)
if xpos is not None:
if xpos is not None and not torch.all(xpos[:, i].eq(PAD_ID)):
loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos[:, i].view(-1))
xpos_preds.append(pad(xpos_pred).max(2, keepdim=True)[1])
preds.append(torch.cat(xpos_preds, 2))
else:
xpos_pred = clffunc(self.xpos_clf, xpos_hid)
if xpos is not None:
if xpos is not None and not torch.all(xpos.eq(PAD_ID)):
loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos.view(-1))
preds.append(pad(xpos_pred).max(2)[1])

ufeats_preds = []
if ufeats is not None: ufeats = pack(ufeats).data
for i in range(len(self.vocab['feats'])):
ufeats_pred = clffunc(self.ufeats_clf[i], ufeats_hid)
if ufeats is not None:
if ufeats is not None and not torch.all(ufeats[:, i].eq(PAD_ID)):
loss += self.crit(ufeats_pred.view(-1, ufeats_pred.size(-1)), ufeats[:, i].view(-1))
ufeats_preds.append(pad(ufeats_pred).max(2, keepdim=True)[1])
preds.append(torch.cat(ufeats_preds, 2))
16 changes: 8 additions & 8 deletions stanza/models/tagger.py
@@ -19,7 +19,7 @@
from torch import nn, optim

import stanza.models.pos.data as data
from stanza.models.pos.data import Dataset
from stanza.models.pos.data import Dataset, merge_datasets
from stanza.models.pos.trainer import Trainer
from stanza.models.pos import scorer
from stanza.models.common import utils
@@ -188,10 +188,11 @@ def load_training_data(args, pretrain):
# therefore, we create separate datasets and loaders for each input training file,
# which ensures the system is able to see batches with both upos available
# and upos unavailable, depending on the availability in each file.

vocab = Dataset.init_vocab(train_docs, args)
train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False)
for i in train_docs]
# here we make sure the model will learn to output _ for empty columns
Collaborator:
I still think this block is necessary, unless there's some other way in which this is being calculated that I have missed. The idea is: if dataset X has the column and dataset Y does not, then we want X to have the has bit set to True and Y to have it set to False. But if neither X nor Y has the column, they both need to be marked as has so that the model will learn to produce blank features, xpos, etc.
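
For readability, here is the same rule restated as a standalone sketch (the helper name and generic column argument are illustrative, not part of this diff):

def resolve_has_flag(datasets, column):
    """column is one of 'upos', 'xpos', 'feats'.  If at least one dataset has the
    column, leave the flags alone so it is only masked where it is missing;
    if no dataset has it, mark them all as having it so the model still trains
    on the (all-blank) column and learns to output _ for it."""
    attr = "has_" + column
    if not any(getattr(d, attr) for d in datasets):
        for d in datasets:
            setattr(d, attr, True)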


# if *any* dataset has data for the upos, xpos, or feature column,
# we consider that data enough to train the model on that column
# otherwise, we want to train the model to always output blanks
@@ -204,9 +205,9 @@ def load_training_data(args, pretrain):
if not any(td.has_feats for td in train_data):
for td in train_data:
td.has_feats = True

# calculate the batches
train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
for i in train_data]
train_batches = merge_datasets(train_data).to_loader(batch_size=args["batch_size"], shuffle=True)
return vocab, train_data, train_batches

def train(args):
@@ -235,7 +236,8 @@ def train(args):
vocab, train_data, train_batches = load_training_data(args, pretrain)

dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab,
evaluation=True, sort_during_eval=True)
dev_batch = dev_data.to_loader(batch_size=args["batch_size"])

eval_type = get_eval_type(dev_data)
@@ -289,9 +291,7 @@ def train(args):
# such as if XPOS or UPOS are missing from a training file,
# as we shuffle all of those batches together
# the downside being that it loses the efficiency benefit of the pytorch dataloader
all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
random.shuffle(all_train_batches)
for i, batch in enumerate(all_train_batches):
for i, batch in enumerate(iter(train_batches)):
start_time = time.time()
global_step += 1
loss = trainer.update(batch, eval=False) # update step
38 changes: 37 additions & 1 deletion stanza/tests/pos/test_data.py
@@ -5,9 +5,11 @@
import os
import pytest

import torch

from stanza.models.common.doc import *
from stanza.models import tagger
from stanza.models.pos.data import Dataset
from stanza.models.pos.data import Dataset, merge_datasets
from stanza.utils.conll import CoNLL

from stanza.tests.pos.test_tagger import TRAIN_DATA, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_FEATS
@@ -68,6 +70,40 @@ def test_no_feats():
assert data.has_xpos
assert not data.has_feats

def test_merge_some_xpos():
# batch size 2 for 3 total elements so that it sometimes randomly
# puts the sentence with no xpos in a block by itself and
# sometimes randomly puts it in a block with 1 other element
args = tagger.parse_args(args=['--batch_size', '2'])

train_docs = [CoNLL.conll2doc(input_str=TRAIN_DATA),
CoNLL.conll2doc(input_str=TRAIN_DATA_NO_XPOS)]

# TODO: maybe refactor the reading code in the main body of the tagger for easier testing
vocab = Dataset.init_vocab(train_docs, args)
train_data = [Dataset(i, args, None, vocab=vocab, evaluation=False) for i in train_docs]
lens = list(len(x) for x in train_data)
assert lens == [2, 1]
merged = merge_datasets(train_data)
assert len(merged) == 3
train_batches = merged.to_loader(batch_size=args["batch_size"], shuffle=True)
it_first = 0
for _ in range(200):
for batch_idx, batch in enumerate(iter(train_batches)):
if batch.text[-1][0] == 'It':
if batch_idx == 0:
it_first += 1
# in a batch of size 2, the other item has to be
# one of the with-xpos items
assert any(batch.xpos[0])
# this should always be false to represent that this
# item in the batch has been masked
assert not any(batch.xpos[-1])

# check that the sentence w/o xpos is sometimes but not always in the first batch
assert it_first > 5
assert it_first < 195

def test_no_augment():
"""
Test that with no punct removing augmentation, the doc always has punct at the end