Add batch processing for Transformer model and Trainer
SecroLoL committed Jan 6, 2024
1 parent 492c650 commit 464ef4f
Showing 2 changed files with 25 additions and 26 deletions.
@@ -86,15 +86,18 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
device = default_device()

if kwargs.get("train_path"):
-texts_batch, positions_batch, labels_batch, counts, label_decoder = utils.load_dataset(kwargs.get("train_path"), get_counts=self.weighted_loss)
+text_batches, position_batches, label_batches, counts, label_decoder = utils.load_dataset(kwargs.get("train_path"), get_counts=self.weighted_loss)
self.output_dim = len(label_decoder)
logging.info(f"Using label decoder : {label_decoder}")

+# TODO: fix this to make it not disregard last batch, and instead pad it or some other idea
+text_batches, position_batches, label_batches = text_batches[:-1], position_batches[:-1], label_batches[:-1]

# Move data to device
-labels_batch = torch.tensor(labels_batch, device=device)
-positions_batch = torch.tensor(positions_batch, device=device)
+label_batches = torch.stack(label_batches).to(device)
+position_batches = torch.stack(position_batches).to(device)

-assert len(texts_batch) == len(positions_batch) == len(labels_batch), f"Input batch sizes did not match ({len(texts_batch)}, {len(positions_batch)}, {len(labels_batch)})."
+assert len(text_batches) == len(position_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(position_batches)}, {len(label_batches)})."

self.model = LemmaClassifierWithTransformer(output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
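A possible way to handle the TODO in the hunk above, sketched here under the assumption that each batch is a list of token lists plus 1-D position and label tensors (the helper and its name are hypothetical, not part of this commit), is to pad the final, smaller batch instead of discarding it:

```python
import torch

# Hypothetical helper: pad the last batch by repeating its final example so that
# every batch has exactly batch_size entries and torch.stack sees uniform shapes.
def pad_last_batch(text_batches, position_batches, label_batches, batch_size):
    deficit = batch_size - len(label_batches[-1])
    if deficit > 0:
        text_batches[-1] = text_batches[-1] + [text_batches[-1][-1]] * deficit
        position_batches[-1] = torch.cat([position_batches[-1], position_batches[-1][-1:].repeat(deficit)])
        label_batches[-1] = torch.cat([label_batches[-1], label_batches[-1][-1:].repeat(deficit)])
    return text_batches, position_batches, label_batches
```

Repeating examples slightly overweights them in training; masking or simply allowing a smaller final batch would avoid that, at the cost of ragged shapes.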
@@ -114,21 +117,18 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
best_model, best_f1 = None, float("-inf")
for epoch in range(num_epochs):
# go over entire dataset with each epoch
-for texts, position, label in tqdm(zip(texts_batch, positions_batch, labels_batch), total=len(texts_batch)):
-if position < 0 or position > len(texts) - 1:  # validate position index
-raise ValueError(f"Found position {position} in text: {texts}, which is not possible.")

+for sentences, positions, labels in tqdm(zip(text_batches, position_batches, label_batches), total=len(text_batches)):

self.optimizer.zero_grad()
-output = self.model(position, texts)
+outputs = self.model(positions, sentences)

# Compute loss, which is different if using CE or BCEWithLogitsLoss
if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
-target_vec = [1, 0] if label == 0 else [0, 1]
-target = torch.tensor(target_vec, dtype=torch.float32, device=device)
+targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)
else: # CELoss accepts target as just raw label
-target = torch.tensor(label, dtype=torch.long, device=device)
+targets = labels

-loss = self.criterion(output, target)
+loss = self.criterion(outputs, targets)

loss.backward()
self.optimizer.step()
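The stacked one-hot targets built above for BCEWithLogitsLoss can equivalently be written with `torch.nn.functional.one_hot`; a minimal sketch, assuming `labels` is a 1-D LongTensor of class indices and the model has two output classes:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 1, 1, 0])                           # example batch of class indices
targets = F.one_hot(labels, num_classes=2).to(torch.float32)  # [[1,0],[0,1],[0,1],[1,0]], the shape BCEWithLogitsLoss expects
```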
25 changes: 12 additions & 13 deletions stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -6,7 +6,7 @@

from transformers import AutoTokenizer, AutoModel
from typing import Mapping, List, Tuple, Any

+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.constants import ModelType

@@ -44,7 +44,7 @@ def __init__(self, output_dim: int, transformer_name: str, label_decoder: Mappin
)
self.label_decoder = label_decoder

-def forward(self, pos_index: int, text: List[str]):
+def forward(self, idx_positions: List[int], sentences: List[List[str]]):
"""
Args:
@@ -54,25 +54,24 @@ def forward(self, pos_index: int, text: List[str]):
Returns the logits of the MLP
"""

-# Get the transformer embeddings
-input_ids = self.tokenizer.convert_tokens_to_ids(text)
+# Get the transformer embedding IDs for each token in each sentence
+input_ids = [torch.tensor(self.tokenizer.convert_tokens_to_ids(sent)) for sent in sentences]
+lengths = [len(sent) for sent in input_ids]
+input_ids = pad_sequence(input_ids, batch_first=True)


-# Convert tokens to IDs and put them into a tensor
-input_ids_tensor = torch.tensor([input_ids], device=next(self.parameters()).device)  # move data to device as well
+packed_input = pack_padded_sequence(input_ids, lengths, batch_first=True)
# Forward pass through Transformer
with torch.no_grad():
-outputs = self.transformer(input_ids_tensor)
+outputs = self.transformer(input_ids=packed_input)

# Get embeddings for all tokens
last_hidden_state = outputs.last_hidden_state
-token_embeddings = last_hidden_state[0]

-pos_index = torch.tensor(pos_index, device=next(self.parameters()).device)
-# Get target embedding
-target_pos_embedding = token_embeddings[pos_index]
+unpacked_outputs = pad_packed_sequence(last_hidden_state, batch_first=True)
+embeddings = unpacked_outputs[torch.arange(unpacked_outputs.size(0)), idx_positions]

# pass to the MLP
-output = self.mlp(target_pos_embedding)
+output = self.mlp(embeddings)
return output

def model_type(self):
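For reference, a minimal sketch of the same batched target-token lookup using a padded `input_ids` tensor plus an `attention_mask` (HuggingFace encoders take these tensors rather than a `PackedSequence`); the helper name is hypothetical and device handling is left to the caller:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical helper, not part of this commit: batch-embed pre-tokenized
# sentences and return the encoder embedding of one target token per sentence.
def embed_target_tokens(tokenizer, transformer, idx_positions, sentences):
    # Convert each pre-tokenized sentence to vocabulary IDs, then pad to a rectangular batch
    input_ids = [torch.tensor(tokenizer.convert_tokens_to_ids(sent)) for sent in sentences]
    lengths = torch.tensor([len(ids) for ids in input_ids])
    input_ids = pad_sequence(input_ids, batch_first=True)

    # 1 for real tokens, 0 for padding, so the encoder ignores the padded slots
    attention_mask = (torch.arange(input_ids.size(1)).unsqueeze(0) < lengths.unsqueeze(1)).long()

    with torch.no_grad():
        outputs = transformer(input_ids=input_ids, attention_mask=attention_mask)

    # Pick out the embedding of the target token in each sentence
    last_hidden_state = outputs.last_hidden_state              # (batch, seq_len, hidden)
    batch_idx = torch.arange(last_hidden_state.size(0))
    return last_hidden_state[batch_idx, torch.tensor(idx_positions)]
```

Called as `embed_target_tokens(self.tokenizer, self.transformer, idx_positions, sentences)`, this would return one embedding per sentence, ready to pass to `self.mlp`.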
