Mixing inputs that have / don't have upos, xpos, feats #1306
Conversation
return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
                 pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)
return _ShadowDataset(self).to_loader(**kwargs)
btw: this shouldn't break any previous APIs, because the old .to_loader() still works; it just makes a shadow dataset on your behalf from the one Dataset
This hasn't been publicly released yet, so we should be free to change it however we like
Still, it's a pretty intuitive solution: the one Dataset version is just the N Datasets version reduced to 1 dataset
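As a rough illustration of that reduction, a shadow dataset can flatten N datasets into one while remembering each item's source dataset. This is a hypothetical sketch, not Stanza's actual implementation; the class and attribute names are made up:

```python
from torch.utils.data import DataLoader, Dataset

class ShadowDatasetSketch(Dataset):
    """Hypothetical sketch: present N datasets as a single dataset,
    tagging each item with the index of the dataset it came from."""
    def __init__(self, *datasets):
        self.datasets = datasets
        # flatten to (dataset_index, item_index) pairs
        self.index = [(d, i)
                      for d, ds in enumerate(datasets)
                      for i in range(len(ds))]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        d, i = self.index[idx]
        # the source index lets a collate function mask missing columns
        return self.datasets[d][i], d

    def to_loader(self, **kwargs):
        # the one-Dataset case is just N=1, so an old-style
        # single-dataset .to_loader() can delegate to this
        return DataLoader(self, **kwargs)
```

Under this sketch, the single-dataset `.to_loader()` reduces to wrapping itself in a shadow dataset of size one.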
If you look in pos/model.py, you can see the part where it is checking the
I can take that on, unless it's something you want to experiment with
happy to help experiment
stanza/models/pos/data.py
Outdated
# sort sentences by lens for easy RNN operations
lens = [torch.sum(x != PAD_ID) for x in words]
(words, wordchars, upos, xpos,
surely the has_whatever needs to be sorted here as well, or the items won't be aligned
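To keep such flags aligned, one shared permutation should reorder every per-sentence column at once. A minimal sketch with made-up variable names (the actual column set in Stanza differs):

```python
def sort_columns_by_len(lens, *columns):
    """Sort all parallel columns by descending length using one shared
    permutation, so per-sentence flags stay aligned with their sentences."""
    order = sorted(range(len(lens)), key=lambda i: lens[i], reverse=True)
    return order, [[col[i] for i in order] for col in columns]

lens = [2, 5, 3]
words = ["ab", "abcde", "abc"]
has_upos = [True, False, True]   # hypothetical per-sentence flag
order, (words, has_upos) = sort_columns_by_len(lens, words, has_upos)
# after sorting, has_upos[k] still describes words[k]
```

Sorting `lens` and `words` alone, without applying the same permutation to `has_upos`, would silently misattribute the flags.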
vocab = Dataset.init_vocab(train_docs, args)
train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False)
              for i in train_docs]
# here we make sure the model will learn to output _ for empty columns
I still think this block is necessary, unless there's some other way in which this is being calculated which I have missed. The idea is: if dataset X has the column and dataset Y does not, then we want X to have the has_ bit set and Y to have it set to False. But if both X and Y don't have the column, they both need to be marked as having it, specifically so that the model will learn blank features, xpos, etc.
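The rule being described could be sketched as a tiny helper (a hypothetical function, not one from the Stanza codebase):

```python
def resolve_has_column(dataset_has_column):
    """Given one bool per dataset saying whether it has a column
    (e.g. upos), decide the per-dataset has_ bits.

    If no dataset in the mix has the column, mark all of them as having
    it, so the model trains to predict the blank value "_" everywhere.
    Otherwise, only the datasets that genuinely have the column keep
    the bit set, and the rest are masked out of the loss."""
    if not any(dataset_has_column):
        return [True] * len(dataset_has_column)
    return list(dataset_has_column)
```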
Yes, but the cross entropy is set to ignore padding indices, so if we set those upos/xpos etc. as padding in that area, they will not contribute to the loss. Therefore, during batch time, we set the indices for which
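The ignore-padding behavior described here is standard for PyTorch's cross entropy. A small illustration, with the `PAD_ID` value and the `has_upos` flag assumed for the example rather than taken from Stanza:

```python
import torch
import torch.nn as nn

PAD_ID = 0  # assumed padding index, for illustration only

# ignore_index makes these target positions contribute nothing to the loss
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

torch.manual_seed(0)
logits = torch.randn(4, 10)                     # 4 tokens, 10 tag classes
gold = torch.tensor([3, 7, 2, 5])
has_upos = torch.tensor([True, True, False, False])

# rewrite gold tags to PAD_ID for tokens whose source dataset lacks the column
masked_gold = torch.where(has_upos, gold, torch.full_like(gold, PAD_ID))
loss = loss_fn(logits, masked_gold)
# with mean reduction, the loss averages over only the unmasked tokens
```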
Force-pushed from 091611c to 80b6a6c:
…ther with the data items
…nputs to make sure the batching doesn't fail in some weird way. Then, redo the calls to update() for the batches and check that the losses are the same for a batch of size one or a batch of size two
Closing in favor of e4c2273, which is a dataloader-level mix. Feel free to reopen if we want to pursue a dataset-level mix again.
Description
Some languages, like German, OOM when training with the new PyTorch Dataset scheme, as the overhead of loading multiple datasets into separate DataLoaders and then mixing them didn't work well. We did this because some entire input files wouldn't have upos/xpos/ufeats, and we don't want to calculate loss on those columns.
Instead, this PR elects to create a _ShadowDataset object between them, and masks out the loss (by turning the upos/xpos/etc. into padding tokens at batch time) for exactly those sentences which came from datasets that don't have upos/xpos/ufeats.
Unit test coverage
Passes all tests in stanza.tests.pos.test_tagger