Skip to content

Commit

Permalink
Add a flag to not save every checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Nov 14, 2023
1 parent 0479f37 commit 4744b29
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 2 additions & 0 deletions stanza/models/coref/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ class Config: # pylint: disable=too-many-instance-attributes, too-few-public-me

tokenizer_kwargs: Dict[str, dict]
conll_log_dir: str

save_each_checkpoint: bool
8 changes: 6 additions & 2 deletions stanza/models/coref/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,12 @@ def train(self):
save_path = os.path.join(self.config.data_dir,
f"{self.config.section}.pt")
self.save_weights(save_path, save_optimizers=False)
# TODO: make save_each an option here
self.save_weights()
if self.config.save_each_checkpoint:
self.save_weights()
else:
checkpoint_path = os.path.join(self.config.data_dir,
f"{self.config.section}.checkpoint.pt")
self.save_weights(checkpoint_path)

# ========================================================= Private methods

Expand Down

0 comments on commit 4744b29

Please sign in to comment.