diff --git a/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py b/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
index 2d4ca293c4..8b64776582 100644
--- a/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
+++ b/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
@@ -86,15 +86,18 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
         device = default_device()
 
         if kwargs.get("train_path"):
-            texts_batch, positions_batch, labels_batch, counts, label_decoder = utils.load_dataset(kwargs.get("train_path"), get_counts=self.weighted_loss)
+            text_batches, position_batches, label_batches, counts, label_decoder = utils.load_dataset(kwargs.get("train_path"), get_counts=self.weighted_loss)
             self.output_dim = len(label_decoder)
             logging.info(f"Using label decoder : {label_decoder}")
 
+        # TODO: fix this to make it not disregard last batch, and instead pad it or some other idea
+        text_batches, position_batches, label_batches = text_batches[:-1], position_batches[:-1], label_batches[:-1]
+
         # Move data to device
-        labels_batch = torch.tensor(labels_batch, device=device)
-        positions_batch = torch.tensor(positions_batch, device=device)
+        labels_batch = torch.stack(labels_batch).to(device)
+        positions_batch = torch.stack(positions_batch).to(device)
 
-        assert len(texts_batch) == len(positions_batch) == len(labels_batch), f"Input batch sizes did not match ({len(texts_batch)}, {len(positions_batch)}, {len(labels_batch)})."
+        assert len(text_batches) == len(position_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(position_batches)}, {len(label_batches)})."
 
         self.model = LemmaClassifierWithTransformer(output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder)
         self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
@@ -114,21 +117,18 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
         best_model, best_f1 = None, float("-inf")
         for epoch in range(num_epochs):
             # go over entire dataset with each epoch
-            for texts, position, label in tqdm(zip(texts_batch, positions_batch, labels_batch), total=len(texts_batch)):
-                if position < 0 or position > len(texts) - 1: # validate position index
-                    raise ValueError(f"Found position {position} in text: {texts}, which is not possible.")
-
+            for sentences, positions, labels in tqdm(zip(text_batches, position_batches, label_batches), total=len(text_batches)):
+
                 self.optimizer.zero_grad()
-                output = self.model(position, texts)
+                outputs = self.model(positions, sentences)
 
                 # Compute loss, which is different if using CE or BCEWithLogitsLoss
                 if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
-                    target_vec = [1, 0] if label == 0 else [0, 1]
-                    target = torch.tensor(target_vec, dtype=torch.float32, device=device)
+                    targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)
                 else: # CELoss accepts target as just raw label
-                    target = torch.tensor(label, dtype=torch.long, device=device)
+                    targets = labels
 
-                loss = self.criterion(output, target)
+                loss = self.criterion(outputs, targets)
                 loss.backward()
                 self.optimizer.step()
 
diff --git a/stanza/models/lemma_classifier/transformer_baseline/model.py b/stanza/models/lemma_classifier/transformer_baseline/model.py
index 61cdf93781..a9295efbc1 100644
--- a/stanza/models/lemma_classifier/transformer_baseline/model.py
+++ b/stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -6,7 +6,7 @@
 from transformers import AutoTokenizer, AutoModel
 from typing import Mapping, List, Tuple, Any
 
-
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
 from stanza.models.lemma_classifier.base_model import LemmaClassifier
 from stanza.models.lemma_classifier.constants import ModelType
 
@@ -44,7 +44,7 @@ def __init__(self, output_dim: int, transformer_name: str, label_decoder: Mappin
         )
         self.label_decoder = label_decoder
 
-    def forward(self, pos_index: int, text: List[str]):
+    def forward(self, idx_positions: List[int], sentences: List[List[str]]):
         """
         Args:
 
@@ -54,25 +54,24 @@ def forward(self, pos_index: int, text: List[str]):
         Returns the logits of the MLP
         """
 
-        # Get the transformer embeddings
-        input_ids = self.tokenizer.convert_tokens_to_ids(text)
+        # Get the transformer embedding IDs for each token in each sentence
+        input_ids = [torch.tensor(self.tokenizer.convert_tokens_to_ids(sent)) for sent in sentences]
+        lengths = [len(sent) for sent in input_ids]
+        input_ids = pad_sequence(input_ids, batch_first=True)
+
 
-        # Convert tokens to IDs and put them into a tensor
-        input_ids_tensor = torch.tensor([input_ids], device=next(self.parameters()).device) # move data to device as well
+        packed_input = pack_padded_sequence(input_ids, lengths, batch_first=True)
 
         # Forward pass through Transformer
-        with torch.no_grad():
-            outputs = self.transformer(input_ids_tensor)
+        outputs = self.transformer(input_ids=packed_input)
 
         # Get embeddings for all tokens
         last_hidden_state = outputs.last_hidden_state
-        token_embeddings = last_hidden_state[0]
-        pos_index = torch.tensor(pos_index, device=next(self.parameters()).device)
-        # Get target embedding
-        target_pos_embedding = token_embeddings[pos_index]
+        unpacked_outputs = pad_packed_sequence(last_hidden_state, batch_first=True)
+        embeddings = unpacked_outputs[torch.arange(unpacked_outputs.size(0)), idx_positions]
 
         # pass to the MLP
-        output = self.mlp(embeddings)
+        output = self.mlp(embeddings)
         return output
 
     def model_type(self):
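
Note on the batched forward pass: pack_padded_sequence / pad_packed_sequence are RNN utilities, and a Hugging Face AutoModel forward expects a plain (batch, seq_len) id tensor rather than a PackedSequence, so padding is usually handled with an attention_mask instead. The following is a minimal sketch of that pattern under the same setup as the patch (pre-tokenized sentences plus a target index per sentence); the checkpoint name and the example batch are placeholders, not part of the patch.

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModel

# Placeholder checkpoint; the patch takes the real name from transformer_name.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
transformer = AutoModel.from_pretrained("bert-base-uncased")

# Assumed example batch: pre-tokenized sentences and the index of the target token in each.
sentences = [["the", "red", "barn", "collapsed"], ["she", "reads"]]
idx_positions = torch.tensor([2, 1])

# Convert tokens to ids (mirroring the patch's convert_tokens_to_ids usage)
# and pad to a rectangular (batch, seq_len) tensor.
input_ids = [torch.tensor(tokenizer.convert_tokens_to_ids(sent)) for sent in sentences]
lengths = torch.tensor([len(ids) for ids in input_ids])
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)

# Attention mask: 1 for real tokens, 0 for padding, so the model ignores pad positions.
attention_mask = (torch.arange(input_ids.size(1))[None, :] < lengths[:, None]).long()

outputs = transformer(input_ids=input_ids, attention_mask=attention_mask)
last_hidden_state = outputs.last_hidden_state  # (batch, seq_len, hidden_dim)

# Pick out the embedding of the target token in each sentence.
embeddings = last_hidden_state[torch.arange(input_ids.size(0)), idx_positions]

The patch also drops the original torch.no_grad() wrapper around the forward pass, which leaves the transformer weights trainable; whether that is intended depends on whether the transformer is meant to be fine-tuned here.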
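
Similarly, in the trainer hunk, the list comprehension that builds targets for BCEWithLogitsLoss is a one-hot encoding of the gold labels. A small sketch of the same result via torch.nn.functional.one_hot, assuming labels is a 1-D tensor of class indices and two classes as in the patch:

import torch
import torch.nn.functional as F

labels = torch.tensor([0, 1, 1, 0])  # assumed batch of gold class indices

# One row per example, probability 1 on the true class and 0 elsewhere,
# which is the target layout BCEWithLogitsLoss expects.
targets = F.one_hot(labels, num_classes=2).float()
# tensor([[1., 0.],
#         [0., 1.],
#         [0., 1.],
#         [1., 0.]])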