diff --git a/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py b/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
index 8b64776582..2ed354593e 100644
--- a/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
+++ b/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
@@ -94,8 +94,8 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
             text_batches, position_batches, label_batches = text_batches[:-1], position_batches[:-1], label_batches[:-1]
 
         # Move data to device
-        labels_batch = torch.stack(labels_batch).to(device)
-        positions_batch = torch.stack(positions_batch).to(device)
+        label_batches = torch.stack(label_batches).to(device)
+        position_batches = torch.stack(position_batches).to(device)
 
         assert len(text_batches) == len(position_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(position_batches)}, {len(label_batches)})."
 
diff --git a/stanza/models/lemma_classifier/transformer_baseline/model.py b/stanza/models/lemma_classifier/transformer_baseline/model.py
index a9295efbc1..c4adff7b8e 100644
--- a/stanza/models/lemma_classifier/transformer_baseline/model.py
+++ b/stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -30,7 +30,7 @@ def __init__(self, output_dim: int, transformer_name: str, label_decoder: Mappin
 
         # Choose transformer
         self.transformer_name = transformer_name
-        self.tokenizer = AutoTokenizer.from_pretrained(transformer_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True)
         self.transformer = AutoModel.from_pretrained(transformer_name)
 
         config = self.transformer.config
@@ -55,21 +55,17 @@ def forward(self, idx_positions: List[int], sentences: List[List[str]]):
         """
 
         # Get the transformer embedding IDs for each token in each sentence
-        input_ids = [torch.tesnsor(self.tokenizer.convert_tokens_to_ids(sent)) for sent in sentences]
-        lengths = [len(sent) for sent in input_ids]
-        input_ids = pad_sequence(input_ids, batch_first=True)
-
-
-        packed_input = pack_padded_sequence(input_ids, lengths, batch_first=True)
+        device = next(self.transformer.parameters()).device
+        tokenized_inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt", is_split_into_words=True)
+        tokenized_inputs = {key: val.to(device) for key, val in tokenized_inputs.items()}
+
         # Forward pass through Transformer
-        outputs = self.transformer(input_ids=packed_input)
+        outputs = self.transformer(**tokenized_inputs)
 
         # Get embeddings for all tokens
         last_hidden_state = outputs.last_hidden_state
-        unpacked_outputs = pad_packed_sequence(last_hidden_state, batch_first=True)
-        embeddings = unpacked_outputs[torch.arange(unpacked_outputs.size(0)), idx_positions]
-
+        embeddings = last_hidden_state[torch.arange(last_hidden_state.size(0)), idx_positions]
 
         # pass to the MLP
         output = self.mlp(embeddings)
         return output
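
Note on the baseline_trainer.py hunk: the rename appears to fix a NameError, since `labels_batch` and `positions_batch` are never defined; the lists built earlier in train() are called `label_batches` and `position_batches`. The underlying pattern is stacking a list of equal-shaped tensors into one batch tensor and moving it to the training device. A minimal standalone sketch of that pattern (the shapes, values, and device choice are illustrative assumptions, not taken from the trainer):

    # Sketch of the batching pattern in the trainer hunk: torch.stack turns a
    # list of equal-shaped tensors into one (num_batches, batch_size) tensor,
    # and .to(device) moves it to the same device as the model.
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Toy data, assumed for illustration: 3 batches of 4 examples each
    label_batches = [torch.tensor([0, 1, 1, 0]) for _ in range(3)]
    position_batches = [torch.tensor([2, 0, 5, 1]) for _ in range(3)]

    label_batches = torch.stack(label_batches).to(device)        # shape (3, 4)
    position_batches = torch.stack(position_batches).to(device)  # shape (3, 4)

    assert len(label_batches) == len(position_batches)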
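
Note on the model.py hunks: the new forward() replaces the manual convert_tokens_to_ids / pack_padded_sequence handling with the tokenizer's batched API (pack_padded_sequence is an RNN utility; transformers handle padding via the attention mask instead), and `add_prefix_space=True` is required when feeding pre-tokenized input to byte-level BPE tokenizers such as RoBERTa's. A minimal standalone sketch of the same pattern, assuming a "roberta-base" checkpoint and toy sentences (both assumptions for illustration, not values from the diff):

    # Sketch of the forward() pattern: tokenize pre-split sentences with
    # is_split_into_words=True, run the encoder, then gather one hidden
    # state per sentence at the requested position.
    import torch
    from transformers import AutoModel, AutoTokenizer

    transformer_name = "roberta-base"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True)
    transformer = AutoModel.from_pretrained(transformer_name)

    sentences = [["He", "walks", "the", "dog"], ["She", "reads"]]  # toy input
    idx_positions = torch.tensor([1, 1])  # one target position per sentence

    tokenized = tokenizer(sentences, padding=True, truncation=True,
                          return_tensors="pt", is_split_into_words=True)
    with torch.no_grad():
        outputs = transformer(**tokenized)

    last_hidden_state = outputs.last_hidden_state  # (batch, seq_len, hidden)
    # Advanced indexing: row i of the batch at position idx_positions[i]
    embeddings = last_hidden_state[torch.arange(last_hidden_state.size(0)), idx_positions]
    print(embeddings.shape)  # torch.Size([2, 768])

One caveat worth keeping in mind: the gathered positions index rows of the padded subword sequence, which begins with a special token and may split one word into several pieces. With a fast tokenizer, `tokenized.word_ids(i)` returns the word index of each subword position for sentence i, which can be used to map word-level positions to subword-level ones if that is needed.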