From 492c6503b001b778db42abb2bad72f59377ba19c Mon Sep 17 00:00:00 2001
From: SecroLoL
Date: Thu, 4 Jan 2024 23:38:12 -0800
Subject: [PATCH] Edit to make batch processing work and get rid of debug
 statements

---
 .../lemma_classifier/evaluate_models.py       | 20 +++++++++----------
 stanza/models/lemma_classifier/model.py       | 16 +++++++--------
 stanza/models/lemma_classifier/train_model.py |  9 ++++-----
 stanza/models/lemma_classifier/utils.py       |  2 +-
 4 files changed, 21 insertions(+), 26 deletions(-)

diff --git a/stanza/models/lemma_classifier/evaluate_models.py b/stanza/models/lemma_classifier/evaluate_models.py
index 70038b7443..c2fca1809a 100644
--- a/stanza/models/lemma_classifier/evaluate_models.py
+++ b/stanza/models/lemma_classifier/evaluate_models.py
@@ -107,7 +107,7 @@ def evaluate_sequences(gold_tag_sequences: List[List[Any]], pred_tag_sequences:
     return multi_class_result, confusion, weighted_f1
 
 
-def model_predict(model: nn.Module, position_indices: torch.tensor[int], sentences: List[List[str]]) -> torch.tensor[int]:
+def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]]) -> torch.Tensor:
     """
     A LemmaClassifierLSTM or LemmaClassifierWithTransformer is used to predict on a single text example, given the position index of the target token.
 
@@ -121,9 +121,7 @@ def model_predict(model: nn.Module, position_indices: torch.tensor[int], sentenc
     """
     with torch.no_grad():
         logits = model(position_indices, sentences)  # should be size (batch_size, output_size)
-        logging.info(f"Logits shape: {logits.shape} (should be size (batch_size, output_size))")
         predicted_class = torch.argmax(logits, dim=1)  # should be size (batch_size, 1)
-        logging.info(f"Predicted class shape: {predicted_class.shape}, (should be size (batch_size, 1))")
 
     return predicted_class
 
@@ -155,9 +153,13 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr
 
     # load in eval data
     text_batches, index_batches, label_batches, _, label_decoder = utils.load_dataset(eval_path, label_decoder=model.label_decoder)
+
+
+    # TODO fix this in the future
+    text_batches, index_batches, label_batches = text_batches[:-1], index_batches[:-1], label_batches[:-1]
 
-    index_batches = torch.tensor(index_batches, device=device)
-    label_batches = torch.tensor(label_batches, device=device)
+    index_batches = torch.stack(index_batches).to(device)
+    label_batches = torch.stack(label_batches).to(device)
 
     logging.info(f"Evaluating on evaluation file {eval_path}")
 
@@ -168,21 +170,17 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr
     for sentences, pos_indices, labels in tqdm(zip(text_batches, index_batches, label_batches), "Evaluating examples from data file", total=len(text_batches)):
         pred = model_predict(model, pos_indices, sentences)  # Pred should be size (batch_size, )
         correct_preds = pred == labels
-        logging.info(f"Correct preds shape: {correct_preds.shape} (should be size (batch_size, 1))")
         correct += torch.sum(correct_preds)
         total += len(correct_preds)
-        pred_tags += pred.tolist()
+        pred_tags += [pred.tolist()]
 
     logging.info("Finished evaluating on dataset. Computing scores...")
 
     accuracy = correct / total
-    logging.info(f"Gold Tags: {gold_tags}")
-    logging.info(f"Pred Tags: {pred_tags}")
-
     mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, verbose=verbose)  # add brackets around batches of gold and pred tags because each batch is an element within the sequences in this helper
 
     if verbose:
-        logging.info(f"Accuracy: {accuracy} ({correct}/{len(label_batches)})")
+        logging.info(f"Accuracy: {accuracy} ({correct}/{total})")
 
     return mc_results, confusion, accuracy, weighted_f1
 
diff --git a/stanza/models/lemma_classifier/model.py b/stanza/models/lemma_classifier/model.py
index e96b4dc03e..cb828ba6f3 100644
--- a/stanza/models/lemma_classifier/model.py
+++ b/stanza/models/lemma_classifier/model.py
@@ -95,22 +95,23 @@ def forward(self, pos_indices: List[int], sentences: List[List[str]]):
             sentence_token_ids = [self.vocab_map.get(word.lower(), UNK_ID) for word in words]
             sentence_token_ids = torch.tensor(sentence_token_ids, device=next(self.parameters()).device)
             token_ids.append(sentence_token_ids)
-
-        embedded = self.embedding(torch.tensor(token_ids))
+
+        token_ids = pad_sequence(token_ids, batch_first=True)
+        embedded = self.embedding(token_ids)
 
         if self.use_charlm:
             char_reps_forward = self.charmodel_forward.build_char_representation(sentences)  # takes [[str]]
             char_reps_backward = self.charmodel_backward.build_char_representation(sentences)
-            embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 1)
+            char_reps_forward = pad_sequence(char_reps_forward, batch_first=True)
+            char_reps_backward = pad_sequence(char_reps_backward, batch_first=True)
+
+            embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 2)
 
-        print(f"Embedding shape: {embedded.shape}. Should be size (batch_size, T, input_size)")  # Should be size (batch_size, T, input_size)
         padded_sequences = pad_sequence(embedded, batch_first=True)
         lengths = torch.tensor([len(seq) for seq in embedded])
         packed_sequences = pack_padded_sequence(padded_sequences, lengths, batch_first=True)
-
-        print(f"Packed Sequences shape: {packed_sequences.shape}. Should be size (batch_size, input_size)")  # should be size (batch_size, input_size)
 
         lstm_out, (hidden, _) = self.lstm(packed_sequences)
@@ -118,11 +119,8 @@ def forward(self, pos_indices: List[int], sentences: List[List[str]]):
         unpacked_lstm_outputs, _ = pad_packed_sequence(lstm_out, batch_first=True)
         lstm_out = unpacked_lstm_outputs[torch.arange(unpacked_lstm_outputs.size(0)), pos_indices]
 
-        print(f"LSTM OUT Shape: {lstm_out.shape}. Should be size (batch_size, input_size)")  # Should be size (batch_size, input_size)
-
         # MLP forward pass
         output = self.mlp(lstm_out)
-        print(f"Output shape: {output.shape}. Should be size (batch_size, output_size)")  # should be size (batch_size, output_size)
         return output
 
     def model_type(self):
diff --git a/stanza/models/lemma_classifier/train_model.py b/stanza/models/lemma_classifier/train_model.py
index a3f243fbe9..5fad39ef42 100644
--- a/stanza/models/lemma_classifier/train_model.py
+++ b/stanza/models/lemma_classifier/train_model.py
@@ -124,8 +124,9 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
         logging.info(f"Loaded dataset successfully from {train_path}")
         logging.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")
 
-        idx_batches, label_batches = torch.tensor(idx_batches, device=device), torch.tensor(label_batches, device=device)
-        logging.info(f"idx batches size: {idx_batches.shape}. label_batches shape {label_batches.shape}")
+        text_batches, idx_batches, label_batches = text_batches[:-1], idx_batches[:-1], label_batches[:-1]  # TODO come up with a fix for this
+
+        idx_batches, label_batches = torch.stack(idx_batches).to(device), torch.stack(label_batches).to(device)
 
         self.model = LemmaClassifierLSTM(self.vocab_size, self.embedding_dim, self.hidden_dim, self.output_dim, self.vocab_map, self.embeddings, label_decoder,
                                          charlm=self.use_charlm, charlm_forward_file=self.forward_charlm_file, charlm_backward_file=self.backward_charlm_file)
@@ -157,14 +158,12 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
                 # Compute loss, which is different if using CE or BCEWithLogitsLoss
                 if self.weighted_loss:  # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
                     # TODO: three classes?
-                    targets = torch.tensor([torch.tensor([1, 0] if label == 0 else [0, 1]) for label in labels], dtype=torch.float32, device=device)
+                    targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)  # should be shape (batch_size, 2)
                 else:  # CELoss accepts target as just raw label
                     targets = labels
 
-                logging.info(f"targets shape {targets.shape}. Should be shape (batch_size, ) or (batch_size, output_dim)")  # should be shape (batch_size, ) or (batch_size, 2)
-
                 loss = self.criterion(output, targets)
                 loss.backward()
diff --git a/stanza/models/lemma_classifier/utils.py b/stanza/models/lemma_classifier/utils.py
index 1471eb0d0e..ccba31119c 100644
--- a/stanza/models/lemma_classifier/utils.py
+++ b/stanza/models/lemma_classifier/utils.py
@@ -12,7 +12,7 @@ def load_doc_from_conll_file(path: str):
     return stanza.utils.conll.CoNLL.conll2doc(path)
 
 
-def load_dataset(data_path: str, batch_size=DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None) -> Tuple[List[List[str]], List[torch.Tensor[int]], List[torch.Tensor[int]], Mapping[int, int], Mapping[str, int]]:
+def load_dataset(data_path: str, batch_size=DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None) -> Tuple[List[List[str]], List[torch.Tensor], List[torch.Tensor], Mapping[int, int], Mapping[str, int]]:
     """
     Loads a data file into data batches for tokenized text sentences, token indices, and true labels for each sentence.
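
Reviewer note (not part of the patch): torch.stack only accepts a list of
tensors that all have the same shape, which is why both evaluate_models.py
and train_model.py drop the final batch with [:-1] as a stopgap (see the
TODOs) -- the last batch is usually smaller than batch_size, so stacking it
with the full batches would fail. A minimal sketch of the failure mode, with
made-up batch contents:

    import torch

    # Two full batches of size 4 and one ragged final batch of size 2.
    batches = [torch.tensor([1, 2, 3, 4]),
               torch.tensor([5, 6, 7, 8]),
               torch.tensor([9, 10])]

    try:
        torch.stack(batches)             # shapes differ -> RuntimeError
    except RuntimeError as err:
        print(err)

    stacked = torch.stack(batches[:-1])  # the patch's stopgap: drop the ragged batch
    print(stacked.shape)                 # torch.Size([2, 4])

A longer-term fix would pad or process the final batch separately rather
than discarding those examples.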
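Reviewer note (not part of the patch): the rewritten forward pass in
model.py relies on the standard pad -> pack -> unpack pattern so a single
LSTM call can handle variable-length sentences. The following
self-contained sketch shows that pattern with toy sizes; all names,
vocabulary ids, and dimensions here are invented for illustration:

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

    # Three "sentences" of different lengths as token-id tensors,
    # plus the position of the target token in each sentence.
    token_ids = [torch.tensor([4, 2, 7]), torch.tensor([1, 5]), torch.tensor([3, 6, 2, 8])]
    pos_indices = torch.tensor([1, 0, 2])

    embedding = nn.Embedding(num_embeddings=10, embedding_dim=8)
    lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    # Pad to a rectangular (batch, T) tensor, then embed -> (batch, T, emb).
    padded = pad_sequence(token_ids, batch_first=True)
    embedded = embedding(padded)

    # Pack with the true lengths so the LSTM skips the padding positions.
    lengths = torch.tensor([len(seq) for seq in token_ids])
    packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
    lstm_out, _ = lstm(packed)
    unpacked, _ = pad_packed_sequence(lstm_out, batch_first=True)  # (batch, T, hidden)

    # Gather the hidden state at each sentence's target-token position.
    target_states = unpacked[torch.arange(unpacked.size(0)), pos_indices]
    print(target_states.shape)  # torch.Size([3, 16])

One detail worth noting: pack_padded_sequence defaults to
enforce_sorted=True, which requires lengths in decreasing order; the sketch
passes enforce_sorted=False to stay robust to arbitrary batch order.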