diff --git a/stanza/models/lemma_classifier/base_model.py b/stanza/models/lemma_classifier/base_model.py index a1c7e3607b..37098606ef 100644 --- a/stanza/models/lemma_classifier/base_model.py +++ b/stanza/models/lemma_classifier/base_model.py @@ -87,6 +87,7 @@ def load(filename, args=None): label_decoder=checkpoint['label_decoder'], upos_emb_dim=saved_args['upos_emb_dim'], upos_to_id=checkpoint['upos_to_id'], + num_heads=saved_args['num_heads'], charlm=True, charlm_forward_file=saved_args['charlm_forward_file'], charlm_backward_file=saved_args['charlm_backward_file']) @@ -100,7 +101,8 @@ def load(filename, args=None): pt_embedding=word_embeddings, label_decoder=checkpoint['label_decoder'], upos_emb_dim=saved_args['upos_emb_dim'], - upos_to_id=checkpoint['upos_to_id']) + upos_to_id=checkpoint['upos_to_id'], + num_heads=saved_args['num_heads']) elif model_type is ModelType.TRANSFORMER: from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer