Commit

Add attention layer number of heads to load function in base_model.py
SecroLoL committed Jan 14, 2024
1 parent 527b002 commit 865d05d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion stanza/models/lemma_classifier/base_model.py
@@ -87,6 +87,7 @@ def load(filename, args=None):
                             label_decoder=checkpoint['label_decoder'],
                             upos_emb_dim=saved_args['upos_emb_dim'],
                             upos_to_id=checkpoint['upos_to_id'],
+                            num_heads=saved_args['num_heads'],
                             charlm=True,
                             charlm_forward_file=saved_args['charlm_forward_file'],
                             charlm_backward_file=saved_args['charlm_backward_file'])
@@ -100,7 +101,8 @@ def load(filename, args=None):
                             pt_embedding=word_embeddings,
                             label_decoder=checkpoint['label_decoder'],
                             upos_emb_dim=saved_args['upos_emb_dim'],
-                            upos_to_id=checkpoint['upos_to_id'])
+                            upos_to_id=checkpoint['upos_to_id'],
+                            num_heads=saved_args['num_heads'])
         elif model_type is ModelType.TRANSFORMER:
             from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer
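For context, the reason num_heads has to be threaded through load() is that the attention layer's shape is fixed by that hyperparameter at construction time, so the value stored alongside the checkpoint must be passed back into the model constructor before the saved weights can be restored. Below is a minimal, hypothetical sketch of that save/load pattern; the class name TinyAttentionClassifier and the save/load helpers are illustrative only, not the actual Stanza code.

```python
import torch
import torch.nn as nn

class TinyAttentionClassifier(nn.Module):
    def __init__(self, emb_dim, num_heads, num_labels):
        super().__init__()
        # embed_dim must be divisible by num_heads, so the value used at
        # training time has to be restored exactly at load time
        self.attn = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads)
        self.out = nn.Linear(emb_dim, num_labels)

    def forward(self, x):
        attended, _ = self.attn(x, x, x)
        return self.out(attended.mean(dim=0))

def save(model, args, filename):
    # store the hyperparameters next to the weights, analogous to saved_args
    # in base_model.py (hypothetical layout, not Stanza's checkpoint format)
    torch.save({"params": model.state_dict(), "args": args}, filename)

def load(filename):
    checkpoint = torch.load(filename)
    saved_args = checkpoint["args"]
    # rebuild the model with the saved num_heads before restoring weights;
    # this is the field the commit wires through in the Stanza load function
    model = TinyAttentionClassifier(emb_dim=saved_args["emb_dim"],
                                    num_heads=saved_args["num_heads"],
                                    num_labels=saved_args["num_labels"])
    model.load_state_dict(checkpoint["params"])
    return model

# usage: round-trip a model through a checkpoint file
model = TinyAttentionClassifier(emb_dim=32, num_heads=4, num_labels=2)
save(model, {"emb_dim": 32, "num_heads": 4, "num_labels": 2}, "tiny.pt")
reloaded = load("tiny.pt")
```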
