Fix issue with taking out the last batch: do not load index and label batches as tensors. Leave batching to forward pass

SecroLoL committed Jan 7, 2024
1 parent 08f5ce1 commit a02ea0b
Showing 3 changed files with 2 additions and 18 deletions.
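Across all three files, the removed code sliced off the final batch (`[:-1]`) and called `torch.stack` on the index and label batches; `torch.stack` requires every tensor to have the same shape, and the last batch is usually smaller, which is presumably why it was being discarded. A minimal sketch of that failure mode and of the list-of-tensors approach the commit moves to (batch sizes and variable names are illustrative, not taken from the stanza code):

import torch

# Illustrative sizes only: three full batches of 16 labels plus a smaller final batch of 7.
label_batches = [torch.zeros(16, dtype=torch.long) for _ in range(3)]
label_batches.append(torch.zeros(7, dtype=torch.long))

try:
    torch.stack(label_batches)  # requires every tensor to have the same shape
except RuntimeError as err:
    print(f"torch.stack rejects the ragged last batch: {err}")

# Keeping the batches as a plain Python list sidesteps the problem: the last
# batch is preserved and each batch can be moved to the device on its own.
device = "cpu"
label_batches = [batch.to(device) for batch in label_batches]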
6 changes: 0 additions & 6 deletions stanza/models/lemma_classifier/evaluate_models.py
@@ -153,12 +153,6 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr

  # load in eval data
  text_batches, index_batches, label_batches, _, label_decoder = utils.load_dataset(eval_path, label_decoder=model.label_decoder)

- # TODO fix this in the future
- text_batches, index_batches, label_batches = text_batches[: -1], index_batches[: -1], label_batches[: -1]
-
- index_batches = torch.stack(index_batches).to(device)
- label_batches = torch.stack(label_batches).to(device)
-
  logging.info(f"Evaluating on evaluation file {eval_path}")

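With the stacking removed, `index_batches` and `label_batches` remain Python lists of per-batch tensors, and batching is left to the forward pass as the commit message says. A hedged sketch of how an evaluation loop over the unstacked batches could look; the `model(indices, texts)` signature and the loop body are assumptions, and `model`, `device`, and the three batch lists refer to the names in the hunk above:

import torch

# Sketch only: evaluate batch by batch so the smaller final batch is scored
# like any other, moving each batch to the device as it is used.
model.eval()
with torch.no_grad():
    for texts, indices, labels in zip(text_batches, index_batches, label_batches):
        indices, labels = indices.to(device), labels.to(device)
        logits = model(indices, texts)        # assumed signature
        preds = torch.argmax(logits, dim=-1)  # per-batch predictions
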
4 changes: 0 additions & 4 deletions stanza/models/lemma_classifier/train_model.py
@@ -124,10 +124,6 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
  logging.info(f"Loaded dataset successfully from {train_path}")
  logging.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")

- text_batches, idx_batches, label_batches = text_batches[:-1], idx_batches[:-1], label_batches[:-1] # TODO come up with a fix for this
-
- idx_batches, label_batches = torch.stack(idx_batches).to(device), torch.stack(label_batches).to(device)
-
  self.model = LemmaClassifierLSTM(self.vocab_size, self.embedding_dim, self.hidden_dim, self.output_dim, self.vocab_map, self.embeddings, label_decoder,
                                   charlm=self.use_charlm, charlm_forward_file=self.forward_charlm_file, charlm_backward_file=self.backward_charlm_file)
  self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

10 changes: 2 additions & 8 deletions
@@ -108,18 +108,12 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
  text_batches, position_batches, label_batches, counts, label_decoder = utils.load_dataset(kwargs.get("train_path"), get_counts=self.weighted_loss)
  self.output_dim = len(label_decoder)
  logging.info(f"Using label decoder : {label_decoder}")

- # # TODO: fix this to make it not disregard last batch, and instead pad it or some other idea
- # text_batches, position_batches, label_batches = text_batches[:-1], position_batches[:-1], label_batches[:-1]
-
- # # Move data to device
- # label_batches = torch.stack(label_batches).to(device)
- # position_batches = torch.stack(position_batches).to(device)
-
  assert len(text_batches) == len(position_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(position_batches)}, {len(label_batches)})."

  self.model = LemmaClassifierWithTransformer(output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder)
- self.optimizer = self.set_layer_learning_rates(transformer_lr=self.lr/2, mlp_lr=self.lr) # Adam optimizer
+ # self.optimizer = self.set_layer_learning_rates(transformer_lr=self.lr/2, mlp_lr=self.lr) # Adam optimizer
+ self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

  self.model.to(device)
  self.model.transformer.to(device)
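This file also changes the optimizer: the call to `set_layer_learning_rates` is commented out in favour of a single `optim.Adam` over all parameters. The sketch below contrasts the two patterns using PyTorch parameter groups; the group construction and the `model.mlp` attribute are guesses at what `set_layer_learning_rates` presumably built, not code from the repository:

import torch.optim as optim

def build_optimizer(model, lr, per_layer_rates=False):
    """Sketch: single learning rate vs. PyTorch parameter groups."""
    if per_layer_rates:
        # Presumed shape of the replaced set_layer_learning_rates helper:
        # a lower rate for the transformer, the full rate for the MLP head.
        return optim.Adam([
            {"params": model.transformer.parameters(), "lr": lr / 2},
            {"params": model.mlp.parameters(), "lr": lr},  # attribute name assumed
        ])
    # What the commit uses instead: one rate for every parameter.
    return optim.Adam(model.parameters(), lr=lr)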
