Skip to content

Commit

Permalink
Build the correct simple_tag vocab when loading a model missing the s…
Browse files Browse the repository at this point in the history
…imple_tag
  • Loading branch information
AngledLuffa committed Nov 21, 2023
1 parent 7b2d5c3 commit a0aa5f5
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions stanza/models/ner/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from stanza.models.common.foundation_cache import NoTransformerFoundationCache
from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE
from stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE, CompositeVocab
from stanza.models.common import utils, loss
from stanza.models.ner.model import NERTagger
from stanza.models.ner.vocab import MultiVocab
Expand Down Expand Up @@ -184,8 +184,16 @@ def load(self, filename, pretrain=None, args=None, foundation_cache=None):
# but the tensors after the base CLF won't have an effect on the inference
# so they can just be ignored anyway
if 'simple_tag' not in self.vocab:
# TODO: build the correct vocab
self.vocab['simple_tag'] = self.vocab['tag']
fake_tag_data = []
fake_len = len(self.vocab['tag'])
for fake_idx in range(fake_len):
items = self.vocab['tag'].items(fake_idx)
for tag in items:
fake_multi = ['O'] * fake_len
fake_multi[fake_idx] = tag.split("-")[-1]
fake_tag_data.append([[fake_multi]])
simple_tagvocab = CompositeVocab(fake_tag_data, self.args['shorthand'], idx=0, sep=None)
self.vocab['simple_tag'] = simple_tagvocab

emb_matrix=None
if pretrain is not None:
Expand Down

0 comments on commit a0aa5f5

Please sign in to comment.