From a3c6b4bda822bb61ca84f9c3754f9f78733dcf88 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Fri, 30 Aug 2024 08:11:16 -0700
Subject: [PATCH] Add a flag for turning off dropout, as in early dropout

Add a test that early dropout turns off all the dropouts in a model
---
 stanza/models/constituency/parser_training.py |  5 ++++
 stanza/models/constituency_parser.py          |  8 +++++++
 stanza/tests/constituency/test_trainer.py     | 24 +++++++++++++++++++
 3 files changed, 37 insertions(+)

diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py
index b133045a17..1317fc073d 100644
--- a/stanza/models/constituency/parser_training.py
+++ b/stanza/models/constituency/parser_training.py
@@ -449,6 +449,11 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
                     if watch_regex.search(n):
                         wandb.log({n: torch.linalg.norm(p)})
 
+        if args['early_dropout'] > 0 and trainer.epochs_trained >= args['early_dropout']:
+            trainer.model.word_dropout.p = 0
+            trainer.model.predict_dropout.p = 0
+            trainer.model.lstm_input_dropout.p = 0
+
         # recreate the optimizer and alter the model as needed if we hit a new multistage split
         if args['multistage'] and trainer.epochs_trained in multistage_splits:
             # we may be loading a save model from an earlier epoch if the scores stopped increasing
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index b5b8b79804..8efdbfbea0 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -562,6 +562,14 @@ def build_argparse():
     parser.add_argument('--loss', default='cross', help='cross, large_margin, or focal. Focal requires `pip install focal_loss_torch`')
     parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')
 
+    # turn off dropout for word_dropout, predict_dropout, and lstm_input_dropout
+    # this mechanism doesn't actually turn off lstm_layer_dropout (yet)
+    # but that is set to a default of 0 anyway
+    # this is reusing the idea presented in
+    # https://arxiv.org/pdf/2303.01500v2
+    # "Dropout Reduces Underfitting"
+    # Zhuang Liu, Zhiqiu Xu, Joseph Jin, Zhiqiang Shen, Trevor Darrell
+    parser.add_argument('--early_dropout', default=-1, type=int, help='When to turn off dropout')
     # When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations:
     # 0.0: 0.9085
     # 0.2: 0.9165
diff --git a/stanza/tests/constituency/test_trainer.py b/stanza/tests/constituency/test_trainer.py
index f2afc49bd4..8b8cc13352 100644
--- a/stanza/tests/constituency/test_trainer.py
+++ b/stanza/tests/constituency/test_trainer.py
@@ -5,6 +5,7 @@
 import pytest
 
 import torch
+from torch import nn
 from torch import optim
 
 from stanza import Pipeline
@@ -253,6 +254,29 @@ def test_train(self, wordvec_pretrain_file):
         with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
             self.run_train_test(wordvec_pretrain_file, tmpdirname)
 
+    def test_early_dropout(self, wordvec_pretrain_file):
+        """
+        Train for a few iterations on the fake data, then check that early_dropout turned off every dropout module
+        """
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            args = ['--early_dropout', '3']
+            _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
+            model = model.model
+            dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
+            assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
+            for name, module in dropouts:
+                assert module.p == 0.0, "Dropout module %s was not set to 0 with early_dropout" % name
+
+        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
+            # test that when early_dropout is turned off, the dropouts are not zeroed
+            args = ['--early_dropout', '-1']
+            _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
+            model = model.model
+            dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
+            assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
+            if all(module.p == 0.0 for _, module in dropouts):
+                raise AssertionError("All dropouts were 0 after training even though early_dropout was set to -1")
+
     def test_train_silver(self, wordvec_pretrain_file):
         """
         Test the whole thing for a few iterations on the fake data