Enable torch.no_grad() when computing the forward pass
SecroLoL committed Jan 6, 2024
1 parent b468519 commit 8fea749
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -60,7 +60,8 @@ def forward(self, idx_positions: List[int], sentences: List[List[str]]):
         tokenized_inputs = {key: val.to(device) for key, val in tokenized_inputs.items()}

         # Forward pass through Transformer
-        outputs = self.transformer(**tokenized_inputs)
+        with torch.no_grad():
+            outputs = self.transformer(**tokenized_inputs)

         # Get embeddings for all tokens
         last_hidden_state = outputs.last_hidden_state
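A minimal sketch of the effect of this change (not the Stanza code itself): a small `nn.Linear` stands in for the frozen transformer encoder. Wrapping the forward pass in `torch.no_grad()` means no autograd graph is built, so the returned embeddings carry no gradient history and intermediate activations are not cached, reducing memory use when the transformer is only a feature extractor.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for self.transformer in the patched forward() method.
encoder = nn.Linear(4, 4)
inputs = torch.randn(2, 4)

# With no_grad(), the output is detached from autograd: it can be consumed
# as fixed features without backpropagating into the encoder's weights.
with torch.no_grad():
    frozen_out = encoder(inputs)

# Without it, the same call tracks gradients through the encoder.
plain_out = encoder(inputs)

print(frozen_out.requires_grad)  # False
print(plain_out.requires_grad)   # True
```

This is why the commit wraps only the transformer call: any classifier layers applied to `last_hidden_state` afterwards still train normally, while the pretrained encoder's weights receive no gradient updates.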
