Put the save() code in one place as well, although I'm unsatisfied with the need for hidden circular imports... perhaps a better solution would be a trainer class which knows about the different LemmaClassifier classes
AngledLuffa committed Dec 31, 2023
1 parent 3b7df18 commit 1aa8dd0
Showing 5 changed files with 59 additions and 49 deletions.
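The commit message above floats an alternative to the in-function imports introduced below: keep the subclass knowledge in trainer-level code rather than in the base class. A rough sketch of that idea follows; it is not part of this commit, and the MODEL_CLASSES registry and model_class_for helper are invented names for illustration.

from stanza.models.lemma_classifier.constants import ModelType
from stanza.models.lemma_classifier.model import LemmaClassifierLSTM
from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer

# Map each ModelType to its concrete class at the trainer level, so
# base_model.py never has to import its own subclasses.
MODEL_CLASSES = {
    ModelType.LSTM: LemmaClassifierLSTM,
    ModelType.TRANSFORMER: LemmaClassifierWithTransformer,
}

def model_class_for(checkpoint):
    """Pick the class to reconstruct from a checkpoint saved by LemmaClassifier.save()."""
    return MODEL_CLASSES[checkpoint['model_type']]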
45 changes: 40 additions & 5 deletions stanza/models/lemma_classifier/base_model.py
@@ -6,19 +6,45 @@

 import logging

+from abc import ABC, abstractmethod
+
+import os
+
 import torch
 import torch.nn as nn

 from stanza.models.common.foundation_cache import load_pretrain
 from stanza.models.lemma_classifier.constants import ModelType
-from stanza.models.lemma_classifier.model import LemmaClassifierLSTM
-from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer

 logger = logging.getLogger('stanza.lemmaclassifier')

-class LemmaClassifier(nn.Module):
-    def __init__(self):
-        super(LemmaClassifier, self).__init__()
+class LemmaClassifier(ABC, nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def save(self, save_name, args):
+        """
+        Save the model to the given path, possibly with some args
+        TODO: keep all the relevant args in the model
+        """
+        save_dir = os.path.split(save_name)[0]
+        if save_dir:
+            os.makedirs(save_dir, exist_ok=True)
+        state_dict = {
+            "params": self.state_dict(),
+            "label_decoder": self.label_decoder,
+            "model_type": self.model_type(),
+            "args": args,
+        }
+        torch.save(state_dict, save_name)
+        return state_dict
+
+    @abstractmethod
+    def model_type(self):
+        """
+        return a ModelType
+        """
+
     @staticmethod
     def load(filename, args=None):
@@ -32,6 +58,13 @@ def load(filename, args=None):

         model_type = checkpoint['model_type']
         if model_type is ModelType.LSTM:
+            # TODO: if anyone can suggest a way to avoid this circular import
+            # (or better yet, avoid the load method knowing about subclasses)
+            # please do so
+            # maybe the subclassing is not necessary and we just put
+            # save & load in the trainer
+            from stanza.models.lemma_classifier.model import LemmaClassifierLSTM
+
             saved_args = checkpoint['args']
             # other model args are part of the model and cannot be changed for evaluation or pipeline
             # the file paths might be relevant, though
@@ -70,6 +103,8 @@ def load(filename, args=None):
                                         pt_embedding=embeddings,
                                         label_decoder=checkpoint['label_decoder'])
         elif model_type is ModelType.TRANSFORMER:
+            from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer
+
             output_dim = len(checkpoint['label_decoder'])
             saved_args = checkpoint['args']
             bert_model = saved_args['bert_model']
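The TODO inside load() above exists because model.py and transformer_baseline/model.py now import base_model.py for the LemmaClassifier parent class, so base_model.py can only import its subclasses lazily without creating an import cycle. A minimal two-module sketch of that pattern, with invented module names (not stanza code):

# pkg/base.py -- invented example module
class Base:
    @staticmethod
    def load(kind):
        if kind == "lstm":
            # Imported here rather than at module level: pkg/lstm.py imports
            # pkg/base.py, so a top-level import would be circular.
            from pkg.lstm import LSTMModel
            return LSTMModel()
        raise ValueError(f"unknown kind {kind}")

# pkg/lstm.py -- invented example module
from pkg.base import Base   # safe: pkg/base.py does not import pkg/lstm.py at import time

class LSTMModel(Base):
    pass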
8 changes: 6 additions & 2 deletions stanza/models/lemma_classifier/model.py
@@ -6,9 +6,10 @@

 from stanza.models.common.vocab import UNK_ID
 from stanza.models.lemma_classifier import utils
-from stanza.models.lemma_classifier.constants import *
+from stanza.models.lemma_classifier.base_model import LemmaClassifier
+from stanza.models.lemma_classifier.constants import ModelType

-class LemmaClassifierLSTM(nn.Module):
+class LemmaClassifierLSTM(LemmaClassifier):
     """
     Model architecture:
         Extracts word embeddings over the sentence, passes embeddings into a bi-LSTM to get a sentence encoding.
@@ -106,3 +107,6 @@ def forward(self, pos_index: int, words: List[str]):
         # MLP forward pass
         output = self.mlp(lstm_out)
         return output
+
+    def model_type(self):
+        return ModelType.LSTM
23 changes: 3 additions & 20 deletions stanza/models/lemma_classifier/train_model.py
@@ -85,23 +85,6 @@ def __init__(self, embedding_file: str, hidden_dim: int, use_charlm: bool = Fals
         else:
             raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

-    def save_checkpoint(self, save_name: str, model: LemmaClassifierLSTM, args: Mapping) -> Mapping:
-        """
-        Saves model checkpoint with a current state dict (params) and a label decoder on the dataset.
-        If the save path doesn't exist, it will create it.
-        """
-        save_dir = os.path.split(save_name)[0]
-        if save_dir:
-            os.makedirs(save_dir, exist_ok=True)
-        state_dict = {
-            "params": model.state_dict(),
-            "label_decoder": model.label_decoder,
-            "model_type": ModelType.LSTM,
-            "args": args,
-        }
-        torch.save(state_dict, save_name)
-        return state_dict
-
     def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
         """
         If applicable, this function will update the loss function of the LemmaClassifierLSTM model to become BCEWithLogitsLoss.
@@ -186,14 +169,14 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
                 self.optimizer.step()

             if eval_file:
-                _, _, _, f1 = evaluate_model(self.model, label_decoder, eval_file, is_training=True)
+                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
                 logging.info(f"Weighted f1 for model: {f1}")
                 if f1 > best_f1:
                     best_f1 = f1
-                    self.save_checkpoint(save_name, self.model, args)
+                    self.model.save(save_name, args)
                     logging.info(f"New best model: weighted f1 score of {f1}.")
             else:
-                self.save_checkpoint(save_name, self.model, args)
+                self.model.save(save_name, args)

             logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
23 changes: 3 additions & 20 deletions (transformer baseline trainer)
@@ -66,23 +66,6 @@ def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
         logging.info(f"Using weights {weights} for weighted loss.")
         self.criterion = nn.BCEWithLogitsLoss(weight=weights)

-    def save_checkpoint(self, save_name: str, model: LemmaClassifierWithTransformer, args: Mapping) -> Mapping:
-        """
-        Saves model checkpoint with a current state dict (params) and a label decoder on the dataset.
-        If the save path doesn't exist, it will create it.
-        """
-        save_dir = os.path.split(save_name)[0]
-        if save_dir:
-            os.makedirs(save_dir, exist_ok=True)
-        state_dict = {
-            "params": model.state_dict(),
-            "label_decoder": model.label_decoder,
-            "model_type": ModelType.TRANSFORMER,
-            "args": args,
-        }
-        torch.save(state_dict, save_name)
-        return state_dict
-
     def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, **kwargs):

         """
@@ -153,14 +136,14 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
             logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
             if eval_file:
                 # Evaluate model on dev set to see if it should be saved.
-                _, _, _, f1 = evaluate_model(self.model, label_decoder, eval_file, is_training=True)
+                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
                 logging.info(f"Weighted f1 for model: {f1}")
                 if f1 > best_f1:
                     best_f1 = f1
-                    self.save_checkpoint(save_name, self.model, args)
+                    self.model.save(save_name, args)
                     logging.info(f"New best model: weighted f1 score of {f1}.")
             else:
-                self.save_checkpoint(save_name, self.model, args)
+                self.model.save(save_name, args)


 def main(args=None):
9 changes: 7 additions & 2 deletions stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -7,11 +7,13 @@

 from transformers import AutoTokenizer, AutoModel
 from typing import Mapping, List, Tuple, Any

-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+from stanza.models.lemma_classifier.base_model import LemmaClassifier
+from stanza.models.lemma_classifier.constants import ModelType
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


-class LemmaClassifierWithTransformer(nn.Module):
+class LemmaClassifierWithTransformer(LemmaClassifier):
     def __init__(self, output_dim: int, transformer_name: str, label_decoder: Mapping):
         """
         Model architecture:
@@ -72,3 +74,6 @@ def forward(self, pos_index: int, text: List[str]):
         # pass to the MLP
         output = self.mlp(target_pos_embedding)
         return output
+
+    def model_type(self):
+        return ModelType.TRANSFORMER
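Taken together, the save/load round trip after this commit looks roughly like the sketch below; the path is a made-up placeholder, and `model` and `args` stand in for a trained subclass instance and its args dict from one of the trainers above. Only save() and the static load() come from this commit.

from stanza.models.lemma_classifier.base_model import LemmaClassifier

save_name = "saved_models/lemma_classifier.pt"   # placeholder path, for illustration only
model.save(save_name, args)                      # base-class save(): stores params, label_decoder, model_type(), args
reloaded = LemmaClassifier.load(save_name)       # static load() imports the right subclass based on checkpoint['model_type']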
