Skip to content

Commit

Permalink
Use the sentence F1 as a metric for choosing the best model to keep
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Nov 7, 2023
1 parent 3fc5b95 commit b1eca27
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions stanza/models/coref/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def save_weights(self, save_path=None):
if save_path is None:
save_path = os.path.join(self.config.data_dir,
f"{self.config.section}"
f"_(e{self.epochs_trained}_{time}).pt")
f"_e{self.epochs_trained}_{time}.pt")
savedict = {name: module.state_dict() for name, module in to_save}
savedict["epochs_trained"] = self.epochs_trained # type: ignore
savedict["config"] = self.config
Expand All @@ -327,6 +327,7 @@ def train(self):
docs_ids = list(range(len(docs)))
avg_spans = sum(len(doc["head2span"]) for doc in docs) / len(docs)

best_f1 = None
for epoch in range(self.epochs_trained, self.config.train_epochs):
self.training = True
running_c_loss = 0.0
Expand Down Expand Up @@ -369,8 +370,19 @@ def train(self):
)

self.epochs_trained += 1
scores = self.evaluate()
if best_f1 is None or scores[1] > best_f1:
if best_f1 is None:
logger.info("Saving new best model: F1 %.4f", scores[1])
else:
logger.info("Saving new best model: F1 %.4f > %.4f", scores[1], best_f1)
best_f1 = scores[1]
# TODO: choose a different default save dir
save_path = os.path.join(self.config.data_dir,
f"{self.config.section}.pt")
self.save_weights(save_path)
# TODO: make save_each an option here
self.save_weights()
self.evaluate()

# ========================================================= Private methods

Expand Down

0 comments on commit b1eca27

Please sign in to comment.