Skip to content

Commit

Permalink
Merge branch 'dev' into wl_coref_chains
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Nov 20, 2023
2 parents 6277d92 + 701d2a5 commit 1ccf0a4
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 6 deletions.
15 changes: 15 additions & 0 deletions stanza/models/common/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
A couple of more specific FileNotFoundError subclasses.
The idea is that the caller can catch them and report a more useful resolution for the error.
"""

import errno

class ForwardCharlmNotFoundError(FileNotFoundError):
    """
    FileNotFoundError raised when the forward character model file is missing

    Being a distinct subclass lets a caller catch this case specifically
    and re-raise with advice appropriate to its own options.
    """

    def __init__(self, msg, filename):
        """
        msg: human readable description, exposed as .strerror
        filename: the path that was looked for, exposed as .filename
        """
        # Passing (errno, strerror, filename) is what makes OSError
        # populate the .errno, .strerror, and .filename attributes
        super().__init__(errno.ENOENT, msg, filename)

class BackwardCharlmNotFoundError(FileNotFoundError):
    """
    FileNotFoundError raised when the backward character model file is missing

    Being a distinct subclass lets a caller catch this case specifically
    and re-raise with advice appropriate to its own options.
    """

    def __init__(self, msg, filename):
        """
        msg: human readable description, exposed as .strerror
        filename: the path that was looked for, exposed as .filename
        """
        # Passing (errno, strerror, filename) is what makes OSError
        # populate the .errno, .strerror, and .filename attributes
        super().__init__(errno.ENOENT, msg, filename)
5 changes: 3 additions & 2 deletions stanza/models/ner/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence

from stanza.models.common.data import map_to_ids, get_long_tensor
from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError
from stanza.models.common.packed_lstm import PackedLSTM
from stanza.models.common.dropout import WordDropout, LockedDropout
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
Expand Down Expand Up @@ -95,9 +96,9 @@ def add_unsaved_module(name, module):
if self.args['char'] and self.args['char_emb_dim'] > 0:
if self.args['charlm']:
if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
raise ForwardCharlmNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']), args['charlm_forward_file'])
if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
raise BackwardCharlmNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']), args['charlm_backward_file'])
add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(args['charlm_forward_file'], finetune=False))
add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(args['charlm_backward_file'], finetune=False))
input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
Expand Down
8 changes: 7 additions & 1 deletion stanza/pipeline/ner_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging

from stanza.models.common import doc
from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError
from stanza.models.common.utils import unsort
from stanza.models.ner.data import DataLoader
from stanza.models.ner.trainer import Trainer
Expand Down Expand Up @@ -68,7 +69,12 @@ def _set_up_model(self, config, pipeline, device):
if predict_tagset is not None:
args['predict_tagset'] = predict_tagset

trainer = Trainer(args=args, model_file=model_path, pretrain=pretrain, device=device, foundation_cache=pipeline.foundation_cache)
try:
trainer = Trainer(args=args, model_file=model_path, pretrain=pretrain, device=device, foundation_cache=pipeline.foundation_cache)
except ForwardCharlmNotFoundError as e:
raise ForwardCharlmNotFoundError("Could not find the forward charlm %s. Please specify the correct path with ner_forward_charlm_path" % e.filename, e.filename) from None
except BackwardCharlmNotFoundError as e:
raise BackwardCharlmNotFoundError("Could not find the backward charlm %s. Please specify the correct path with ner_backward_charlm_path" % e.filename, e.filename) from None
self.trainers.append(trainer)

self._trainer = self.trainers[0]
Expand Down
10 changes: 7 additions & 3 deletions stanza/resources/default_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,11 +768,15 @@

def known_nicknames():
    """
    Return a list of all the transformer nicknames

    We return a list so that we can sort them in decreasing key length

    The list holds every value of TRANSFORMER_NICKNAMES plus the generic
    "transformer" nickname, longest first.  Values are not deduplicated.
    """
    nicknames = list(TRANSFORMER_NICKNAMES.values())

    # previously unspecific transformers get "transformer" as the nickname
    nicknames.append("transformer")

    # Longest first.  sorted() is stable even with reverse=True, so
    # nicknames of equal length keep their original relative order.
    # NOTE(review): presumably longest-first lets callers match a longer
    # nickname before a shorter one it contains - confirm against callers
    return sorted(nicknames, key=len, reverse=True)

0 comments on commit 1ccf0a4

Please sign in to comment.