diff --git a/stanza/models/lemma_classifier/base_model.py b/stanza/models/lemma_classifier/base_model.py
index fb15ce02c3..4e715eae06 100644
--- a/stanza/models/lemma_classifier/base_model.py
+++ b/stanza/models/lemma_classifier/base_model.py
@@ -6,19 +6,45 @@
 import logging
 
+from abc import ABC, abstractmethod
+
+import os
+
 import torch
 import torch.nn as nn
 
 from stanza.models.common.foundation_cache import load_pretrain
 from stanza.models.lemma_classifier.constants import ModelType
-from stanza.models.lemma_classifier.model import LemmaClassifierLSTM
-from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer
 
 logger = logging.getLogger('stanza.lemmaclassifier')
 
 
-class LemmaClassifier(nn.Module):
-    def __init__(self):
-        super(LemmaClassifier, self).__init__()
+class LemmaClassifier(ABC, nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def save(self, save_name, args):
+        """
+        Save the model to the given path, possibly with some args
+
+        TODO: keep all the relevant args in the model
+        """
+        save_dir = os.path.split(save_name)[0]
+        if save_dir:
+            os.makedirs(save_dir, exist_ok=True)
+        state_dict = {
+            "params": self.state_dict(),
+            "label_decoder": self.label_decoder,
+            "model_type": self.model_type(),
+            "args": args,
+        }
+        torch.save(state_dict, save_name)
+        return state_dict
+
+    @abstractmethod
+    def model_type(self):
+        """
+        return a ModelType
+        """
 
     @staticmethod
     def load(filename, args=None):
@@ -32,6 +58,13 @@ def load(filename, args=None):
         model_type = checkpoint['model_type']
 
         if model_type is ModelType.LSTM:
+            # TODO: if anyone can suggest a way to avoid this circular import
+            # (or better yet, avoid the load method knowing about subclasses)
+            # please do so
+            # maybe the subclassing is not necessary and we just put
+            # save & load in the trainer
+            from stanza.models.lemma_classifier.model import LemmaClassifierLSTM
+
             saved_args = checkpoint['args']
             # other model args are part of the model and cannot be changed for evaluation or pipeline
             # the file paths might be relevant, though
@@ -70,6 +103,8 @@ def load(filename, args=None):
                                         pt_embedding=embeddings,
                                         label_decoder=checkpoint['label_decoder'])
         elif model_type is ModelType.TRANSFORMER:
+            from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer
+
             output_dim = len(checkpoint['label_decoder'])
             saved_args = checkpoint['args']
             bert_model = saved_args['bert_model']
diff --git a/stanza/models/lemma_classifier/model.py b/stanza/models/lemma_classifier/model.py
index 76c239d8c7..9b1fb3c766 100644
--- a/stanza/models/lemma_classifier/model.py
+++ b/stanza/models/lemma_classifier/model.py
@@ -6,9 +6,10 @@
 from stanza.models.common.vocab import UNK_ID
 from stanza.models.lemma_classifier import utils
-from stanza.models.lemma_classifier.constants import *
+from stanza.models.lemma_classifier.base_model import LemmaClassifier
+from stanza.models.lemma_classifier.constants import ModelType
 
-class LemmaClassifierLSTM(nn.Module):
+class LemmaClassifierLSTM(LemmaClassifier):
     """
     Model architecture:
 
         Extracts word embeddings over the sentence, passes embeddings into a bi-LSTM to get a sentence encoding.
@@ -106,3 +107,6 @@ def forward(self, pos_index: int, words: List[str]):
         # MLP forward pass
         output = self.mlp(lstm_out)
         return output
+
+    def model_type(self):
+        return ModelType.LSTM
diff --git a/stanza/models/lemma_classifier/train_model.py b/stanza/models/lemma_classifier/train_model.py
index 3569fbabd8..2b313479bd 100644
--- a/stanza/models/lemma_classifier/train_model.py
+++ b/stanza/models/lemma_classifier/train_model.py
@@ -85,23 +85,6 @@ def __init__(self, embedding_file: str, hidden_dim: int, use_charlm: bool = Fals
         else:
             raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")
 
-    def save_checkpoint(self, save_name: str, model: LemmaClassifierLSTM, args: Mapping) -> Mapping:
-        """
-        Saves model checkpoint with a current state dict (params) and a label decoder on the dataset.
-        If the save path doesn't exist, it will create it.
-        """
-        save_dir = os.path.split(save_name)[0]
-        if save_dir:
-            os.makedirs(save_dir, exist_ok=True)
-        state_dict = {
-            "params": model.state_dict(),
-            "label_decoder": model.label_decoder,
-            "model_type": ModelType.LSTM,
-            "args": args,
-        }
-        torch.save(state_dict, save_name)
-        return state_dict
-
     def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
         """
         If applicable, this function will update the loss function of the LemmaClassifierLSTM model to become BCEWithLogitsLoss.
@@ -186,14 +169,14 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
                 self.optimizer.step()
 
             if eval_file:
-                _, _, _, f1 = evaluate_model(self.model, label_decoder, eval_file, is_training=True)
+                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
                 logging.info(f"Weighted f1 for model: {f1}")
                 if f1 > best_f1:
                     best_f1 = f1
-                    self.save_checkpoint(save_name, self.model, args)
+                    self.model.save(save_name, args)
                     logging.info(f"New best model: weighted f1 score of {f1}.")
             else:
-                self.save_checkpoint(save_name, self.model, args)
+                self.model.save(save_name, args)
 
             logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
diff --git a/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py b/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
index 8a7ccb7593..2d4ca293c4 100644
--- a/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
+++ b/stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
@@ -66,23 +66,6 @@ def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
         logging.info(f"Using weights {weights} for weighted loss.")
         self.criterion = nn.BCEWithLogitsLoss(weight=weights)
 
-    def save_checkpoint(self, save_name: str, model: LemmaClassifierWithTransformer, args: Mapping) -> Mapping:
-        """
-        Saves model checkpoint with a current state dict (params) and a label decoder on the dataset.
-        If the save path doesn't exist, it will create it.
-        """
-        save_dir = os.path.split(save_name)[0]
-        if save_dir:
-            os.makedirs(save_dir, exist_ok=True)
-        state_dict = {
-            "params": model.state_dict(),
-            "label_decoder": model.label_decoder,
-            "model_type": ModelType.TRANSFORMER,
-            "args": args,
-        }
-        torch.save(state_dict, save_name)
-        return state_dict
-
     def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, **kwargs):
         """
@@ -153,14 +136,14 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
             logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
             if eval_file:
                 # Evaluate model on dev set to see if it should be saved.
-                _, _, _, f1 = evaluate_model(self.model, label_decoder, eval_file, is_training=True)
+                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
                 logging.info(f"Weighted f1 for model: {f1}")
                 if f1 > best_f1:
                     best_f1 = f1
-                    self.save_checkpoint(save_name, self.model, args)
+                    self.model.save(save_name, args)
                     logging.info(f"New best model: weighted f1 score of {f1}.")
             else:
-                self.save_checkpoint(save_name, self.model, args)
+                self.model.save(save_name, args)
 
 
 def main(args=None):
diff --git a/stanza/models/lemma_classifier/transformer_baseline/model.py b/stanza/models/lemma_classifier/transformer_baseline/model.py
index 7ce8797022..61cdf93781 100644
--- a/stanza/models/lemma_classifier/transformer_baseline/model.py
+++ b/stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -7,11 +7,13 @@
 from transformers import AutoTokenizer, AutoModel
 from typing import Mapping, List, Tuple, Any
 
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+from stanza.models.lemma_classifier.base_model import LemmaClassifier
+from stanza.models.lemma_classifier.constants import ModelType
 
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-class LemmaClassifierWithTransformer(nn.Module):
+class LemmaClassifierWithTransformer(LemmaClassifier):
     def __init__(self, output_dim: int, transformer_name: str, label_decoder: Mapping):
         """
         Model architecture:
@@ -72,3 +74,6 @@ def forward(self, pos_index: int, text: List[str]):
         # pass to the MLP
         output = self.mlp(target_pos_embedding)
         return output
+
+    def model_type(self):
+        return ModelType.TRANSFORMER
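For reference, a minimal usage sketch of the refactored checkpoint API introduced above: the trainers now call self.model.save(save_name, args), the checkpoint records model_type(), and the static LemmaClassifier.load() uses that tag to build the right subclass. The checkpoint path below is a hypothetical placeholder, not a file added by this diff.

from stanza.models.lemma_classifier.base_model import LemmaClassifier

# "lemma_classifier.pt" stands in for a checkpoint previously written by one
# of the trainers via self.model.save(save_name, args); it is illustrative only.
model = LemmaClassifier.load("lemma_classifier.pt")

# load() reads the ModelType stored by save() and constructs either a
# LemmaClassifierLSTM or a LemmaClassifierWithTransformer, so callers do not
# need to know which architecture the checkpoint contains.
print(type(model).__name__, model.model_type())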