Update the batched transformer version to use the extract_bert_embeddings utility method
AngledLuffa committed Jan 13, 2024
1 parent 7ee2d76 commit f9c1192
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -7,6 +7,7 @@
 from transformers import AutoTokenizer, AutoModel
 from typing import Mapping, List, Tuple, Any
 from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
+from stanza.models.common.bert_embedding import extract_bert_embeddings
 from stanza.models.lemma_classifier.base_model import LemmaClassifier
 from stanza.models.lemma_classifier.constants import ModelType
 
@@ -53,20 +54,11 @@ def forward(self, idx_positions: List[int], sentences: List[List[str]]):
         Returns the logits of the MLP
         """
 
-        # Get the transformer embedding IDs for each token in each sentence
         device = next(self.transformer.parameters()).device
-        tokenized_inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt", is_split_into_words=True)
-        tokenized_inputs = {key: val.to(device) for key, val in tokenized_inputs.items()}
-
-        # Forward pass through Transformer
-        with torch.no_grad():
-            outputs = self.transformer(**tokenized_inputs)
-
-        # Get embeddings for all tokens
-        last_hidden_state = outputs.last_hidden_state
-
-        embeddings = last_hidden_state[torch.arange(last_hidden_state.size(0)), idx_positions]
-
+        bert_embeddings = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device,
+                                                  keep_endpoints=False, num_layers=1, detach=True)
+        embeddings = [emb[idx] for idx, emb in zip(idx_positions, bert_embeddings)]
+        embeddings = torch.stack(embeddings, dim=0)[:, :, 0]
         # pass to the MLP
         output = self.mlp(embeddings)
         return output
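
For context, here is a minimal sketch of what the new target-token lookup does, with dummy tensors standing in for real transformer output. The per-sentence shape of (num_words, hidden_dim, num_layers) is an assumption inferred from the num_layers=1 argument and the [:, :, 0] squeeze above, and all names below are illustrative rather than part of the commit:

import torch

# Stand-ins for the per-sentence output of extract_bert_embeddings when
# called with num_layers=1: one (num_words, hidden_dim, num_layers) tensor
# per sentence.  The shape is an assumption, not the documented API.
hidden_dim = 4
bert_embeddings = [
    torch.randn(3, hidden_dim, 1),  # a 3-word sentence
    torch.randn(5, hidden_dim, 1),  # a 5-word sentence
]
idx_positions = [1, 4]  # index of the target token in each sentence

# Select the target token's embedding from each sentence...
embeddings = [emb[idx] for idx, emb in zip(idx_positions, bert_embeddings)]
# ...then batch them and drop the singleton layer dimension
embeddings = torch.stack(embeddings, dim=0)[:, :, 0]
print(embeddings.shape)  # torch.Size([2, 4]) -> (batch_size, hidden_dim)

A likely benefit, though not stated in the commit message, is that extract_bert_embeddings returns word-aligned embeddings, so idx_positions can index whole words directly rather than indexing raw subword positions in last_hidden_state as the old code did.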
