From 865d05dbd53616e88c36f4ee365396e34c50a6dc Mon Sep 17 00:00:00 2001
From: Alex Shan
Date: Sat, 13 Jan 2024 17:08:31 -0800
Subject: [PATCH] Add attention layer number of heads to load function in
 base_model.py

---
 stanza/models/lemma_classifier/base_model.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/stanza/models/lemma_classifier/base_model.py b/stanza/models/lemma_classifier/base_model.py
index a1c7e3607b..37098606ef 100644
--- a/stanza/models/lemma_classifier/base_model.py
+++ b/stanza/models/lemma_classifier/base_model.py
@@ -87,6 +87,7 @@ def load(filename, args=None):
                                             label_decoder=checkpoint['label_decoder'],
                                             upos_emb_dim=saved_args['upos_emb_dim'],
                                             upos_to_id=checkpoint['upos_to_id'],
+                                            num_heads=saved_args['num_heads'],
                                             charlm=True,
                                             charlm_forward_file=saved_args['charlm_forward_file'],
                                             charlm_backward_file=saved_args['charlm_backward_file'])
@@ -100,7 +101,8 @@ def load(filename, args=None):
                                             pt_embedding=word_embeddings,
                                             label_decoder=checkpoint['label_decoder'],
                                             upos_emb_dim=saved_args['upos_emb_dim'],
-                                            upos_to_id=checkpoint['upos_to_id'])
+                                            upos_to_id=checkpoint['upos_to_id'],
+                                            num_heads=saved_args['num_heads'])
 
         elif model_type is ModelType.TRANSFORMER:
             from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer
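
Context for the patch above: the load function rebuilds the model from the constructor
arguments stored in the checkpoint, so any hyperparameter that shapes the architecture
(here, the attention layer's number of heads) must be passed through at load time.
The following is a minimal standalone sketch of that pattern, not stanza's actual API;
the class, file name, and argument names (TinyAttnClassifier, hidden_dim, num_labels,
"tiny_attn.pt") are hypothetical and only illustrate why num_heads has to come from
saved_args when the model is re-instantiated.

    import torch
    import torch.nn as nn

    class TinyAttnClassifier(nn.Module):
        # Hypothetical stand-in for a classifier with a multi-head attention layer.
        def __init__(self, hidden_dim: int, num_heads: int, num_labels: int):
            super().__init__()
            self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
            self.out = nn.Linear(hidden_dim, num_labels)

        def forward(self, x):
            attended, _ = self.attn(x, x, x)
            return self.out(attended.mean(dim=1))

    def save(model, args, path):
        # Persist both the weights and the constructor args, mirroring the
        # checkpoint / saved_args pattern used in the diff above.
        torch.save({"params": model.state_dict(), "args": args}, path)

    def load(path):
        checkpoint = torch.load(path)
        saved_args = checkpoint["args"]
        # Without num_heads=saved_args["num_heads"], the module would be rebuilt
        # with some default head count: the state_dict can still load (parameter
        # shapes depend only on hidden_dim), but attention would then be computed
        # with the wrong number of heads, which is the bug the patch avoids.
        model = TinyAttnClassifier(hidden_dim=saved_args["hidden_dim"],
                                   num_heads=saved_args["num_heads"],
                                   num_labels=saved_args["num_labels"])
        model.load_state_dict(checkpoint["params"])
        return model

    if __name__ == "__main__":
        args = {"hidden_dim": 64, "num_heads": 4, "num_labels": 2}
        model = TinyAttnClassifier(**args)
        save(model, args, "tiny_attn.pt")
        restored = load("tiny_attn.pt")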