Skip to content

Commit

Permalink
Add logging lines for CE / BCE
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Dec 28, 2023
1 parent f29ec46 commit 91cb24f
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions stanza/models/lemma_classifier/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ def __init__(self, vocab_size: int, embedding_file: str, embedding_dim: int, hid
if loss_fn == "ce":
self.criterion = nn.CrossEntropyLoss()
self.weighted_loss = False
logging.debug("Using CE loss")
elif loss_fn == "weighted_bce":
self.criterion = nn.BCEWithLogitsLoss()
self.weighted_loss = True # used to add weights during train time.
logging.debug("Using Weighted BCE loss")
else:
raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

Expand Down

0 comments on commit 91cb24f

Please sign in to comment.