diff --git a/stanza/models/coref/config.py b/stanza/models/coref/config.py index 5ac365b940..d5af644650 100644 --- a/stanza/models/coref/config.py +++ b/stanza/models/coref/config.py @@ -47,3 +47,5 @@ class Config: # pylint: disable=too-many-instance-attributes, too-few-public-me tokenizer_kwargs: Dict[str, dict] conll_log_dir: str + + save_each_checkpoint: bool diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py index f29eb33c66..74454d2869 100644 --- a/stanza/models/coref/model.py +++ b/stanza/models/coref/model.py @@ -400,8 +400,12 @@ def train(self): save_path = os.path.join(self.config.data_dir, f"{self.config.section}.pt") self.save_weights(save_path, save_optimizers=False) - # TODO: make save_each an option here - self.save_weights() + if self.config.save_each_checkpoint: + self.save_weights() + else: + checkpoint_path = os.path.join(self.config.data_dir, + f"{self.config.section}.checkpoint.pt") + self.save_weights(checkpoint_path) # ========================================================= Private methods