
Commit

Enable editing LR per layer in Transformer model and update batch processing to avoid cutting off last batch
SecroLoL committed Jan 7, 2024
1 parent 8fea749 commit 08f5ce1
Showing 2 changed files with 27 additions and 8 deletions.
1 change: 0 additions & 1 deletion stanza/models/lemma_classifier/evaluate_models.py
@@ -154,7 +154,6 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr
# load in eval data
text_batches, index_batches, label_batches, _, label_decoder = utils.load_dataset(eval_path, label_decoder=model.label_decoder)


# TODO fix this in the future
text_batches, index_batches, label_batches = text_batches[: -1], index_batches[: -1], label_batches[: -1]

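The `[: -1]` slice above simply drops the final batch, which can be smaller than the others. A minimal sketch of one alternative the TODO hints at (keep the ragged final batch rather than pad or discard it), using a hypothetical `chunk` helper that is not part of this commit:

from typing import Iterator, List, TypeVar

T = TypeVar("T")

def chunk(items: List[T], batch_size: int, drop_last: bool = False) -> Iterator[List[T]]:
    """Yield consecutive batches; the final batch may be smaller unless drop_last=True."""
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # roughly what the [: -1] slice above accomplishes
        yield batch

# list(chunk(list(range(10)), 4)) -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
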
@@ -66,6 +66,25 @@ def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
logging.info(f"Using weights {weights} for weighted loss.")
self.criterion = nn.BCEWithLogitsLoss(weight=weights)

def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> torch.optim:
"""
Sets learning rates for each layer of the model.
Currently, the model has the transformer layer and the MLP layer, so these are tweakable.
Returns (torch.optim): An Adam optimizer with the learning rates adjusted per layer.
"""
transformer_params, mlp_params = [], []
for name, param in self.model.named_parameters():
if 'transformer' in name:
transformer_params.append(param)
elif 'mlp' in name:
mlp_params.append(param)
optimizer = optim.Adam([
{"params": transformer_params, "lr": transformer_lr},
{"params": mlp_params, "lr": mlp_lr}
])
return optimizer

def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, **kwargs):

"""
@@ -90,17 +109,17 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
self.output_dim = len(label_decoder)
logging.info(f"Using label decoder : {label_decoder}")

# TODO: fix this to make it not disregard last batch, and instead pad it or some other idea
text_batches, position_batches, label_batches = text_batches[:-1], position_batches[:-1], label_batches[:-1]
# # TODO: fix this to make it not disregard last batch, and instead pad it or some other idea
# text_batches, position_batches, label_batches = text_batches[:-1], position_batches[:-1], label_batches[:-1]

# Move data to device
label_batches = torch.stack(label_batches).to(device)
position_batches = torch.stack(position_batches).to(device)
# # Move data to device
# label_batches = torch.stack(label_batches).to(device)
# position_batches = torch.stack(position_batches).to(device)

assert len(text_batches) == len(position_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(position_batches)}, {len(label_batches)})."

self.model = LemmaClassifierWithTransformer(output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
self.optimizer = self.set_layer_learning_rates(transformer_lr=self.lr/2, mlp_lr=self.lr) # Adam optimizer

self.model.to(device)
self.model.transformer.to(device)
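
One plausible reason the `torch.stack(...).to(device)` lines above are commented out rather than kept: once the smaller final batch is retained, the batches are no longer uniform in size and `torch.stack` would fail. A self-contained sketch of the failure mode and a per-batch alternative (the shapes here are made up):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ragged batches: the final batch is smaller, so stacking up front no longer works
label_batches = [torch.zeros(16, dtype=torch.long),
                 torch.zeros(16, dtype=torch.long),
                 torch.zeros(7, dtype=torch.long)]

# torch.stack(label_batches)  # RuntimeError: stack expects each tensor to be equal size

for labels in label_batches:
    labels = labels.to(device)  # move one batch at a time instead of pre-stacking
    # ... forward pass and loss for this batch ...
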
@@ -118,7 +137,8 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
for epoch in range(num_epochs):
# go over entire dataset with each epoch
for sentences, positions, labels in tqdm(zip(text_batches, position_batches, label_batches), total=len(text_batches)):

assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

self.optimizer.zero_grad()
outputs = self.model(positions, sentences)


