From 4744b29be241eb492b3991a139d7d36ef6040295 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 14 Nov 2023 00:13:24 -0800 Subject: [PATCH] Add a flag to not save every checkpoint --- stanza/models/coref/config.py | 2 ++ stanza/models/coref/model.py | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/stanza/models/coref/config.py b/stanza/models/coref/config.py index 5ac365b940..d5af644650 100644 --- a/stanza/models/coref/config.py +++ b/stanza/models/coref/config.py @@ -47,3 +47,5 @@ class Config: # pylint: disable=too-many-instance-attributes, too-few-public-me tokenizer_kwargs: Dict[str, dict] conll_log_dir: str + + save_each_checkpoint: bool diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py index f29eb33c66..74454d2869 100644 --- a/stanza/models/coref/model.py +++ b/stanza/models/coref/model.py @@ -400,8 +400,12 @@ def train(self): save_path = os.path.join(self.config.data_dir, f"{self.config.section}.pt") self.save_weights(save_path, save_optimizers=False) - # TODO: make save_each an option here - self.save_weights() + if self.config.save_each_checkpoint: + self.save_weights() + else: + checkpoint_path = os.path.join(self.config.data_dir, + f"{self.config.section}.checkpoint.pt") + self.save_weights(checkpoint_path) # ========================================================= Private methods