Skip to content

Commit

Permalink
Add a test of loading. Use it to fix loading bug
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Jan 13, 2024
1 parent c346ff6 commit 6e12731
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
8 changes: 5 additions & 3 deletions stanza/models/lemma_classifier/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def load(filename, args=None):
# the file paths might be relevant, though
keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file']
for arg in keep_args:
if args.get(arg, None) is not None:
if args is not None and args.get(arg, None) is not None:
saved_args[arg] = args[arg]

# TODO: refactor loading the pretrain (also done in the trainer)
pt = load_pretrain(args['wordvec_pretrain_file'])
pt = load_pretrain(saved_args['wordvec_pretrain_file'])
emb_matrix = pt.emb
word_embeddings = nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix))
vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pt.vocab) }
Expand Down Expand Up @@ -98,7 +98,9 @@ def load(filename, args=None):
output_dim=len(checkpoint['label_decoder']),
vocab_map=vocab_map,
pt_embedding=word_embeddings,
label_decoder=checkpoint['label_decoder'])
label_decoder=checkpoint['label_decoder'],
upos_emb_dim=saved_args['upos_emb_dim'],
upos_to_id=checkpoint['upos_to_id'])
elif model_type is ModelType.TRANSFORMER:
from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer

Expand Down
3 changes: 3 additions & 0 deletions stanza/tests/lemma_classifier/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

from stanza.models.lemma_classifier import train_model
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.evaluate_models import evaluate_model
from stanza.models.lemma_classifier.transformer_baseline import baseline_trainer

Expand All @@ -28,6 +29,8 @@ def test_train_lstm(tmp_path, pretrain_file):
'--train_file', train_file,
'--eval_file', eval_file]
trainer = train_model.main(train_args)

model = LemmaClassifier.load(save_name, None)
evaluate_model(trainer.model, eval_file)

def test_train_transformer(tmp_path, pretrain_file):
Expand Down

0 comments on commit 6e12731

Please sign in to comment.