
Commit

Edit batch processing and debug
SecroLoL committed Jan 6, 2024
1 parent 464ef4f commit b468519
Showing 2 changed files with 9 additions and 13 deletions.
@@ -94,8 +94,8 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
 text_batches, position_batches, label_batches = text_batches[:-1], position_batches[:-1], label_batches[:-1]

 # Move data to device
-labels_batch = torch.stack(labels_batch).to(device)
-positions_batch = torch.stack(positions_batch).to(device)
+label_batches = torch.stack(label_batches).to(device)
+position_batches = torch.stack(position_batches).to(device)

 assert len(text_batches) == len(position_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(position_batches)}, {len(label_batches)})."
18 changes: 7 additions & 11 deletions stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -30,7 +30,7 @@ def __init__(self, output_dim: int, transformer_name: str, label_decoder: Mappin

 # Choose transformer
 self.transformer_name = transformer_name
-self.tokenizer = AutoTokenizer.from_pretrained(transformer_name)
+self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True)
 self.transformer = AutoModel.from_pretrained(transformer_name)
 config = self.transformer.config

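The two new tokenizer arguments matter for the pre-tokenized input used in `forward()` below: `add_prefix_space=True` is required by BPE tokenizers (RoBERTa-style) when words are passed pre-split, and `use_fast=True` selects the fast tokenizer, which exposes `word_ids()` for mapping subwords back to words. A sketch of the behavior, where `"roberta-base"` is just an assumed stand-in for `transformer_name`:

```python
from transformers import AutoTokenizer

# "roberta-base" is an assumption; any BPE-based checkpoint behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True, add_prefix_space=True)

# Pre-split words need add_prefix_space=True for BPE vocabularies,
# where "dog" and " dog" map to different subword IDs.
enc = tokenizer([["the", "dog", "barked"]], is_split_into_words=True)

# Fast tokenizers can map each subword back to its source word.
print(enc.word_ids(0))  # e.g. [None, 0, 1, 2, None] — special tokens map to None
```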
@@ -55,21 +55,17 @@ def forward(self, idx_positions: List[int], sentences: List[List[str]]):
"""

# Get the transformer embedding IDs for each token in each sentence
input_ids = [torch.tesnsor(self.tokenizer.convert_tokens_to_ids(sent)) for sent in sentences]
lengths = [len(sent) for sent in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True)


packed_input = pack_padded_sequence(input_ids, lengths, batch_first=True)
device = next(self.transformer.parameters()).device
tokenized_inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt", is_split_into_words=True)
tokenized_inputs = {key: val.to(device) for key, val in tokenized_inputs.items()}

# Forward pass through Transformer
outputs = self.transformer(input_ids=packed_input)
outputs = self.transformer(**tokenized_inputs)

# Get embeddings for all tokens
last_hidden_state = outputs.last_hidden_state

unpacked_outputs = pad_packed_sequence(last_hidden_state, batch_first=True)
embeddings = unpacked_outputs[torch.arange(unpacked_outputs.size(0)), idx_positions]

embeddings = last_hidden_state[torch.arange(last_hidden_state.size(0)), idx_positions]
# pass to the MLP
output = self.mlp(embeddings)
return output
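Taken together, this hunk replaces the removed pack/pad round-trip (a `PackedSequence` is an RNN input format that a Hugging Face `AutoModel` forward pass does not accept — note also the `torch.tesnsor` typo in the deleted line) with the tokenizer's own batching and padding. A self-contained sketch of the new path under assumed names (`"roberta-base"` and the toy batch below are illustrative):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint; mirrors the __init__ change earlier in this commit.
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True, add_prefix_space=True)
transformer = AutoModel.from_pretrained("roberta-base")

sentences = [["he", "saw", "the", "cat"], ["the", "dog", "barked"]]
idx_positions = torch.tensor([2, 1])  # one target position per sentence

# Batch-tokenize the pre-split sentences and move tensors to the model's device.
device = next(transformer.parameters()).device
tokenized_inputs = tokenizer(sentences, padding=True, truncation=True,
                             return_tensors="pt", is_split_into_words=True)
tokenized_inputs = {key: val.to(device) for key, val in tokenized_inputs.items()}

with torch.no_grad():
    outputs = transformer(**tokenized_inputs)

# (batch, seq_len, hidden) -> one embedding per sentence via advanced indexing.
last_hidden_state = outputs.last_hidden_state
embeddings = last_hidden_state[torch.arange(last_hidden_state.size(0)), idx_positions]
print(embeddings.shape)  # torch.Size([2, 768]) for a base-size model
```

One caveat worth noting: `idx_positions` here index rows of `last_hidden_state`, i.e. subword positions including the leading special token, not original word positions; when the two differ (a word split into several subwords), the `word_ids()` mapping from the tokenizer sketch above is the usual way to translate between them.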
