diff --git a/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py b/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py index 56e37ccb94..b63fc855dc 100644 --- a/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py +++ b/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py @@ -140,7 +140,7 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others. targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device) else: # CELoss accepts target as just raw label - targets = labels + targets = labels.to(device) loss = self.criterion(outputs, targets)