Refactor the loading code from evaluate_models into a separate module. This will represent a base model for the two types of LemmaClassifier. Doing so will make it much easier to load as part of a Pipeline.
AngledLuffa committed Dec 25, 2023
1 parent 99d8b36 commit f29ec46
Showing 4 changed files with 140 additions and 135 deletions.
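
The new LemmaClassifier.load entry point (added in base_model.py below) returns the rebuilt model together with its label decoder. A minimal usage sketch of the intended calling pattern; the paths and the args dict are hypothetical placeholders, not part of this commit:

    from stanza.models.lemma_classifier.base_model import LemmaClassifier

    # hypothetical paths for a saved LSTM classifier and its pretrained word vectors
    args = {'wordvec_pretrain_file': 'saved_models/pretrain.pt'}
    model, label_decoder = LemmaClassifier.load('saved_models/lemma_classifier.pt', args)
    model.eval()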
80 changes: 80 additions & 0 deletions stanza/models/lemma_classifier/base_model.py
@@ -0,0 +1,80 @@
"""
Base class for the LemmaClassifier types.
Versions include LSTM and Transformer varieties
"""

import logging

import torch
import torch.nn as nn

from stanza.models.common.foundation_cache import load_pretrain
from stanza.models.lemma_classifier.constants import ModelType
from stanza.models.lemma_classifier.model import LemmaClassifierLSTM
from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer

logger = logging.getLogger('stanza.lemmaclassifier')

class LemmaClassifier(nn.Module):
    def __init__(self):
        super(LemmaClassifier, self).__init__()

    @staticmethod
    def load(filename, args=None):
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage)
        except BaseException:
            logger.exception("Cannot load model from %s", filename)
            raise

        logger.debug("Loading LemmaClassifier model from %s", filename)

        model_type = checkpoint['model_type']
        if model_type is ModelType.LSTM:
            saved_args = checkpoint['args']
            # other model args are part of the model and cannot be changed for evaluation or pipeline
            # the file paths might be relevant, though
            keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file']
            for arg in keep_args:
                if args.get(arg, None) is not None:
                    saved_args[arg] = args[arg]

            # TODO: refactor loading the pretrain (also done in the trainer)
            pt = load_pretrain(args['wordvec_pretrain_file'])
            emb_matrix = pt.emb
            embeddings = nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix))
            vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pt.vocab) }
            vocab_size = emb_matrix.shape[0]
            embedding_dim = emb_matrix.shape[1]

            if saved_args['use_charlm']:
                # Evaluate charlm
                model = LemmaClassifierLSTM(vocab_size=vocab_size,
                                            embedding_dim=embedding_dim,
                                            hidden_dim=saved_args['hidden_dim'],
                                            output_dim=saved_args['output_dim'],
                                            vocab_map=vocab_map,
                                            pt_embedding=embeddings,
                                            charlm=True,
                                            charlm_forward_file=saved_args['charlm_forward_file'],
                                            charlm_backward_file=saved_args['charlm_backward_file'])
            else:
                # Evaluate standard model (bi-LSTM with GloVe embeddings, no charlm)
                model = LemmaClassifierLSTM(vocab_size=vocab_size,
                                            embedding_dim=embedding_dim,
                                            hidden_dim=saved_args['hidden_dim'],
                                            output_dim=saved_args['output_dim'],
                                            vocab_map=vocab_map,
                                            pt_embedding=embeddings)
        elif model_type is ModelType.TRANSFORMER:
            saved_args = checkpoint['args']
            output_dim = saved_args['output_dim']
            bert_model = saved_args['bert_model']
            model = LemmaClassifierWithTransformer(output_dim=output_dim, transformer_name=bert_model)
        else:
            raise ValueError("Unknown model type %s" % model_type)

        model.load_state_dict(checkpoint['params'])
        # TODO: make the label_decoder part of the model itself
        return model, checkpoint['label_decoder']
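
For reference, the checkpoint that load() unpacks is the dict written by save_checkpoint in train_model.py further down. A sketch of its layout, with key names taken from this diff and values elided:

    checkpoint = {
        "params": ...,          # model weights, restored via model.load_state_dict
        "label_decoder": ...,   # mapping used to interpret predicted classes
        "model_type": ...,      # ModelType.LSTM or ModelType.TRANSFORMER, selects the branch in load()
        "args": ...,            # training-time args used to rebuild the architecture
    }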
74 changes: 10 additions & 64 deletions stanza/models/lemma_classifier/evaluate_models.py
@@ -20,10 +20,9 @@

import stanza

from stanza.models.common.foundation_cache import load_pretrain
from stanza.models.common.utils import default_device
from stanza.models.lemma_classifier import utils
from stanza.models.lemma_classifier.constants import *
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.model import LemmaClassifierLSTM
from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer

@@ -127,7 +126,7 @@ def model_predict(model: nn.Module, position_idx: int, words: List[str]) -> int:
    return predicted_class


def evaluate_model(model: nn.Module, model_path: str, eval_path: str, verbose: bool = True, is_training = False) -> Tuple[Mapping, Mapping, float, float]:
def evaluate_model(model: nn.Module, label_decoder: Mapping, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]:
"""
Helper function for model evaluation
Expand All @@ -147,22 +146,18 @@ def evaluate_model(model: nn.Module, model_path: str, eval_path: str, verbose: b
"""
# load model
device = default_device()

model_state = torch.load(model_path)
model.load_state_dict(model_state['params'])
model.to(device)

if not is_training:
model.eval() # set to eval mode

# load in eval data
label_decoder = model_state['label_decoder']
text_batches, index_batches, label_batches, _, label_decoder = utils.load_dataset(eval_path, label_decoder=label_decoder)

index_batches = torch.tensor(index_batches, device=device)
label_batches = torch.tensor(label_batches, device=device)

logging.info(f"Evaluating model from {model_path} on evaluation file {eval_path}")
logging.info(f"Evaluating on evaluation file {eval_path}")

correct = 0
gold_tags, pred_tags = [label_batches], []
@@ -207,64 +202,15 @@ def main(args=None):
    args = parser.parse_args(args)

    logging.info("Running training script with the following args:")
    for arg in vars(args):
        logging.info(f"{arg}: {getattr(args, arg)}")
    args = vars(args)
    for arg in args:
        logging.info(f"{arg}: {args[arg]}")
    logging.info("------------------------------------------------------------")

    vocab_size = args.vocab_size
    embedding_dim = args.embedding_dim
    hidden_dim = args.hidden_dim
    output_dim = args.output_dim
    wordvec_pretrain_file = args.wordvec_pretrain_file
    use_charlm = args.charlm
    forward_charlm_file = args.charlm_forward_file
    backward_charlm_file = args.charlm_backward_file
    save_name = args.save_name
    model_type = args.model_type
    eval_path = args.eval_file

    if model_type.lower() == "lstm":
        # TODO: refactor
        pt = load_pretrain(wordvec_pretrain_file)
        emb_matrix = pt.emb
        embeddings = nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix))
        vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pt.vocab) }
        vocab_size = emb_matrix.shape[0]
        embedding_dim = emb_matrix.shape[1]

        if use_charlm:
            # Evaluate charlm
            model = LemmaClassifierLSTM(vocab_size=vocab_size,
                                        embedding_dim=embedding_dim,
                                        hidden_dim=hidden_dim,
                                        output_dim=output_dim,
                                        vocab_map=vocab_map,
                                        pt_embedding=embeddings,
                                        charlm=True,
                                        charlm_forward_file=forward_charlm_file,
                                        charlm_backward_file=backward_charlm_file)
        else:
            # Evaluate standard model (bi-LSTM with GloVe embeddings, no charlm)
            model = LemmaClassifierLSTM(vocab_size=vocab_size,
                                        embedding_dim=embedding_dim,
                                        hidden_dim=hidden_dim,
                                        output_dim=output_dim,
                                        vocab_map=vocab_map,
                                        pt_embedding=embeddings)
    elif model_type.lower() == "roberta":
        # Evaluate Transformer (BERT or ROBERTA)
        model = LemmaClassifierWithTransformer(output_dim=output_dim, transformer_name="roberta-base")
    elif model_type.lower() == "bert":
        # Evaluate Transformer (BERT or ROBERTA)
        model = LemmaClassifierWithTransformer(output_dim=output_dim, transformer_name="bert-base-uncased")
    elif model_type.lower() == "transformer":
        model = LemmaClassifierWithTransformer(output_dim=output_dim, transformer_name=args.bert_model)
    else:
        raise ValueError("Unknown model type %s" % model_type)

    logging.info(f"Attempting evaluation of model from {save_name} on file {eval_path}")

    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, save_name, eval_path)
    logging.info(f"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}")
    model, label_decoder = LemmaClassifier.load(args['save_name'], args)

    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, label_decoder, args['eval_file'])

    logging.info(f"MCC Results: {dict(mcc_results)}")
    logging.info("______________________________________________")
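
With this refactor, evaluate_model no longer loads the checkpoint itself; the caller loads the model via LemmaClassifier.load and passes the label decoder through. A sketch of the new calling convention, with hypothetical file paths:

    model, label_decoder = LemmaClassifier.load('lemma_classifier.pt', args)   # hypothetical path
    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, label_decoder, 'dev_file.txt')   # hypothetical path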
58 changes: 24 additions & 34 deletions stanza/models/lemma_classifier/train_model.py
@@ -29,7 +29,7 @@ class LemmaClassifierTrainer():
    Class to assist with training a LemmaClassifierLSTM
    """

    def __init__(self, vocab_size: int, embedding_file: str, embedding_dim: int, hidden_dim: int, output_dim: int = 2, use_charlm: bool = False, **kwargs):
    def __init__(self, vocab_size: int, embedding_file: str, embedding_dim: int, hidden_dim: int, output_dim: int = 2, use_charlm: bool = False, eval_file: str = None, **kwargs):
        """
        Initializes the LemmaClassifierTrainer class.
@@ -40,13 +40,13 @@ def __init__(self, vocab_size: int, embedding_file: str, embedding_dim: int, hid
            hidden_dim (int): Size of hidden vectors in LSTM layers
            output_dim (int, optional): Size of output vector from MLP layer. Defaults to 2.
            use_charlm (bool, optional): Whether to use charlm embeddings as well. Defaults to False.
            eval_file (str): File used as dev set to evaluate which model gets saved
        Kwargs:
            forward_charlm_file (str): Path to the forward pass embeddings for the charlm
            backward_charlm_file (str): Path to the backward pass embeddings for the charlm
            lr (float): Learning rate, defaults to 0.001.
            loss_func (str): Which loss function to use (either 'ce' or 'weighted_bce')
            eval_file (str): File used as dev set to evaluate which model gets saved
        Raises:
            FileNotFoundError: If the forward charlm file is not present
@@ -90,7 +90,7 @@ def __init__(self, vocab_size: int, embedding_file: str, embedding_dim: int, hid

        self.optimizer = optim.Adam(self.model.parameters(), lr=kwargs.get("lr", 0.001))

    def save_checkpoint(self, save_name: str, state_dict: Mapping, label_decoder: Mapping) -> Mapping:
    def save_checkpoint(self, save_name: str, state_dict: Mapping, label_decoder: Mapping, args: Mapping) -> Mapping:
        """
        Saves model checkpoint with a current state dict (params) and a label decoder on the dataset.
        If the save path doesn't exist, it will create it.
@@ -102,6 +102,7 @@ def save_checkpoint(self, save_name: str, state_dict: Mapping, label_decoder: Ma
            "params": state_dict,
            "label_decoder": label_decoder,
            "model_type": ModelType.LSTM,
            "args": args,
        }
        torch.save(state_dict, save_name)
        return state_dict
@@ -120,19 +121,7 @@ def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
logging.info(f"Using weights {weights} for weighted loss.")
self.criterion = nn.BCEWithLogitsLoss(weight=weights)

def update_best_checkpoint(self, state_dict: Mapping, best_model: Mapping, best_f1: float, save_name: str, eval_path: str) -> Tuple[Mapping, float]:
"""
Attempts to update the best available version of the model by evaluating the current model's state against the existing
best model on the dev set. The model with a better weighted F1 will be chosen.
"""
_, _, _, f1 = evaluate_model(self.model, save_name, eval_path, is_training=True)
logging.info(f"Weighted f1 for model: {f1}")
if f1 > best_f1:
best_model = state_dict
logging.info(f"New best model: weighted f1 score of {f1}.")
return best_model, max(f1, best_f1)

def train(self, texts_batch: List[List[str]], positions_batch: List[int], labels_batch: List[int], num_epochs: int, save_name: str, **kwargs) -> None:
def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, **kwargs) -> None:

"""
Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.
Expand Down Expand Up @@ -170,7 +159,7 @@ def train(self, texts_batch: List[List[str]], positions_batch: List[int], labels
            self.configure_weighted_loss(label_decoder, counts)

        # Put the criterion on GPU too
        logging.info(f"Criterion on {next(self.model.parameters()).device}")
        logging.debug(f"Criterion on {next(self.model.parameters()).device}")
        self.criterion = self.criterion.to(next(self.model.parameters()).device)

        best_model, best_f1 = None, float("-inf") # Used for saving checkpoints of the model
@@ -180,10 +169,10 @@ def train(self, texts_batch: List[List[str]], positions_batch: List[int], labels
            for texts, position, label in tqdm(zip(texts_batch, positions_batch, labels_batch), total=len(texts_batch)):
                if position < 0 or position > len(texts) - 1: # validate position index
                    raise ValueError(f"Found position {position} in text: {texts}, which is not possible.")

                self.optimizer.zero_grad()
                output = self.model(position, texts)

                # Compute loss, which is different if using CE or BCEWithLogitsLoss
                if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
                    # TODO: three classes?
@@ -196,20 +185,19 @@ def train(self, texts_batch: List[List[str]], positions_batch: List[int], labels

                loss.backward()
                self.optimizer.step()

            # Evaluate model on dev set to see if it should be saved.
            state_dict = self.save_checkpoint(save_name, self.model.state_dict(), label_decoder)
            logging.info(f"Saved temp model state dict for epoch [{epoch + 1}/{num_epochs}] to {save_name}")

            if kwargs.get("eval_file"):
                best_model, best_f1 = self.update_best_checkpoint(state_dict, best_model, best_f1, save_name, kwargs.get("eval_file"))


            if eval_file:
                _, _, _, f1 = evaluate_model(self.model, label_decoder, eval_file, is_training=True)
                logging.info(f"Weighted f1 for model: {f1}")
                if f1 > best_f1:
                    best_f1 = f1
                    self.save_checkpoint(save_name, self.model.state_dict(), label_decoder, args)
                    logging.info(f"New best model: weighted f1 score of {f1}.")
            else:
                self.save_checkpoint(save_name, self.model.state_dict(), label_decoder, args)

            logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
            logging.info("Embedding norm: %s", torch.linalg.norm(self.model.embedding.weight))

        # Save the best model from training
        self.save_checkpoint(save_name, best_model.get("params"), best_model.get("label_decoder"))
        logging.info(f"Saved final model state dict to {save_name} (weighted F1: {best_f1}).")

def build_argparse():
    parser = argparse.ArgumentParser()
@@ -249,14 +237,16 @@ def main(args=None):
    weighted_loss = args.weighted_loss
    eval_file = args.eval_file

    args = vars(args)

    if os.path.exists(save_name):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logging.info("Running training script with the following args:")
    for arg in vars(args):
        logging.info(f"{arg}: {getattr(args, arg)}")
    for arg in args:
        logging.info(f"{arg}: {args[arg]}")
    logging.info("------------------------------------------------------------")

    trainer = LemmaClassifierTrainer(vocab_size=vocab_size,
@@ -272,7 +262,7 @@
                                     )

    trainer.train(
        [], [], [], num_epochs=num_epochs, save_name=save_name, train_path=train_file, eval_file=eval_file
        num_epochs=num_epochs, save_name=save_name, args=args, train_path=train_file, eval_file=eval_file
    )

if __name__ == "__main__":
