Mixing inputs that have/don't have upos, xpos, feats #1306
Changes from all commits
9343df0
2e0193e
025a832
a052581
37d05bb
0076181
80b6a6c
```diff
@@ -19,7 +19,7 @@
 from torch import nn, optim

 import stanza.models.pos.data as data
-from stanza.models.pos.data import Dataset
+from stanza.models.pos.data import Dataset, merge_datasets
 from stanza.models.pos.trainer import Trainer
 from stanza.models.pos import scorer
 from stanza.models.common import utils
```
```diff
@@ -188,10 +188,11 @@ def load_training_data(args, pretrain):
     # therefore, we create separate datasets and loaders for each input training file,
     # which will ensure the system is able to see batches with both upos available
     # and upos unavailable depending on what the availability in the file is.

     vocab = Dataset.init_vocab(train_docs, args)
     train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False)
                   for i in train_docs]
-    # here we make sure the model will learn to output _ for empty columns
+    # if *any* dataset has data for the upos, xpos, or feature column,
+    # we consider that data enough to train the model on that column
+    # otherwise, we want to train the model to always output blanks
```

Review comment (on the removed line above): i still think this block is necessary, unless there's some other way in which this is being calculated which i have missed. the idea is: if dataset X has, and dataset Y does not have, then we want X to have the …
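For concreteness, the mixed-input situation this PR targets looks roughly like the hypothetical CoNLL-U fragments below: one training file fills in UPOS/XPOS/FEATS, the other leaves those columns as `_`, and both feed the same tagger.

```
# file_a.conllu - made-up fragment with UPOS, XPOS and FEATS filled in
1	The	the	DET	DT	Definite=Def|PronType=Art	2	det	_	_
2	dogs	dog	NOUN	NNS	Number=Plur	0	root	_	_

# file_b.conllu - made-up fragment with those columns left blank
1	The	the	_	_	_	2	det	_	_
2	dogs	dog	_	_	_	0	root	_	_
```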
```diff
@@ -204,9 +205,9 @@ def load_training_data(args, pretrain):
     if not any(td.has_feats for td in train_data):
         for td in train_data:
             td.has_feats = True

     # calculate the batches
-    train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
-                     for i in train_data]
+    train_batches = merge_datasets(train_data).to_loader(batch_size=args["batch_size"], shuffle=True)
     return vocab, train_data, train_batches

 def train(args):
```
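The hunk above only shows the tail of that logic, the has_feats case. A rough sketch of the overall policy described in the new comment, assuming (this is an assumption, not code from the PR) that has_upos and has_xpos are handled the same way as has_feats:

```python
# Hedged sketch, not the stanza implementation: has_upos / has_xpos are
# assumed to mirror the has_feats handling visible in the diff.
def set_column_availability(train_data):
    for flag in ("has_upos", "has_xpos", "has_feats"):
        if any(getattr(td, flag) for td in train_data):
            # at least one file has real values for this column, so the
            # flags stay as-is and the model trains on whatever is there
            continue
        # no file has the column: mark it "available" everywhere so the
        # model trains on the all-blank values and learns to emit _
        for td in train_data:
            setattr(td, flag, True)
```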
```diff
@@ -235,7 +236,8 @@ def train(args):
     vocab, train_data, train_batches = load_training_data(args, pretrain)

     dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
-    dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
+    dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab,
+                       evaluation=True, sort_during_eval=True)
     dev_batch = dev_data.to_loader(batch_size=args["batch_size"])

     eval_type = get_eval_type(dev_data)
```
```diff
@@ -289,9 +291,7 @@ def train(args):
         # such as if XPOS or UPOS are missing from a training file,
         # as we shuffle all of those batches together
         # the downside being that it loses the efficiency benefit of the pytorch dataloader
-        all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
-        random.shuffle(all_train_batches)
-        for i, batch in enumerate(all_train_batches):
+        for i, batch in enumerate(iter(train_batches)):
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
```
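The practical effect on the epoch loop: the manual flatten-and-shuffle step goes away because the merged loader already mixes batches from every file. A toy, self-contained illustration of the shape change (the lists below are stand-ins for real loaders, not stanza objects):

```python
import random

# stand-ins: each "loader" yields batches that come from a single training file
loader_a = [("file_a", 0), ("file_a", 1)]
loader_b = [("file_b", 0)]

# before this PR: flatten every per-file loader, then shuffle the batches by hand
all_train_batches = [batch for loader in (loader_a, loader_b) for batch in iter(loader)]
random.shuffle(all_train_batches)
for i, batch in enumerate(all_train_batches):
    print("old-style step", i, batch)

# after this PR: merge_datasets(...).to_loader(...) hands back one loader that
# already mixes batches from all files, so the epoch just iterates it;
# merged_loader below is only a stand-in for that single loader
merged_loader = loader_a + loader_b
random.shuffle(merged_loader)
for i, batch in enumerate(iter(merged_loader)):
    print("new-style step", i, batch)
```

Either way, each batch still comes from a single source file, so the upos/xpos/feats availability stays consistent within a batch.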
Review comment: btw, this shouldn't break any previous APIs, because the old .to_loader() still works; it just makes a shadow dataset on your behalf with the one Dataset.
Review comment: This hasn't been publicly released yet, so we should be free to change it however we like.
Review comment: Still, it's a pretty intuitive solution: the one Dataset version is just the N Datasets version reduced to 1 dataset.
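As a hypothetical usage sketch of that equivalence, using only the names that appear in the diff (the doc/args/pretrain/vocab variables are placeholders for objects built earlier in tagger.py, not defined here):

```python
from stanza.models.pos.data import Dataset, merge_datasets

# placeholders: doc_a, doc_b, args, pretrain and vocab stand for the objects
# load_training_data() builds above
data_a = Dataset(doc_a, args, pretrain, vocab=vocab, evaluation=False)
data_b = Dataset(doc_b, args, pretrain, vocab=vocab, evaluation=False)

# N-datasets path: one loader whose batches are drawn from both files
train_batches = merge_datasets([data_a, data_b]).to_loader(
    batch_size=args["batch_size"], shuffle=True)

# 1-dataset path: the old call still works and, per the comments above, is
# effectively the merged path reduced to a single dataset
dev_batch = data_a.to_loader(batch_size=args["batch_size"])
```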