diff --git a/stanza/models/mwt/data.py b/stanza/models/mwt/data.py
index dba58ca3a7..b8b0b2f1e9 100644
--- a/stanza/models/mwt/data.py
+++ b/stanza/models/mwt/data.py
@@ -13,11 +13,18 @@
 
 logger = logging.getLogger('stanza')
 
+# enforce that the MWT splitter knows about a couple different alternate apostrophes
+# including covering some potential " typos
+# setting the augmentation to a very low value should be enough to teach it
+# about the unknown characters without messing up the predictions for other text
+APOS = ('"', '’', 'ʼ')
+
 # TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS
 class DataLoader:
     def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
         self.batch_size = batch_size
         self.args = args
+        self.augment_apos = args.get('augment_apos', 0.0)
         self.evaluation = evaluation
         self.doc = doc
 
@@ -25,7 +32,11 @@ def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_u
 
         # handle vocab
         if vocab is None:
+            assert self.evaluation == False # for eval vocab must exist
             self.vocab = self.init_vocab(data)
+            if self.augment_apos > 0 and "'" in self.vocab:
+                for apos in APOS:
+                    self.vocab.add_unit(apos)
         elif expand_unk_vocab:
             self.vocab = DeltaVocab(data, vocab)
         else:
@@ -54,9 +65,18 @@ def init_vocab(self, data):
             vocab = Vocab(data, self.args['shorthand'])
         return vocab
 
+    def maybe_augment_apos(self, datum):
+        if "'" in datum[0] and random.uniform(0,1) < self.augment_apos:
+            replacement = random.choice(APOS)
+            datum = (datum[0].replace("'", replacement), datum[1].replace("'", replacement))
+        return datum
+
+
     def process(self, data):
         processed = []
         for d in data:
+            if not self.evaluation and self.augment_apos > 0:
+                d = self.maybe_augment_apos(d)
             src = list(d[0])
             src = [constant.SOS] + src + [constant.EOS]
             tgt_in, tgt_out = self.prepare_target(self.vocab, d)
diff --git a/stanza/models/mwt/vocab.py b/stanza/models/mwt/vocab.py
index 776a41c3af..0c861e7a49 100644
--- a/stanza/models/mwt/vocab.py
+++ b/stanza/models/mwt/vocab.py
@@ -11,3 +11,9 @@ def build_vocab(self):
 
         self._id2unit = constant.VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))
         self._unit2id = {w:i for i, w in enumerate(self._id2unit)}
+
+    def add_unit(self, unit):
+        if unit in self._unit2id:
+            return
+        self._unit2id[unit] = len(self._id2unit)
+        self._id2unit.append(unit)
diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py
index 8f68cfca66..87eadbb8a0 100644
--- a/stanza/models/mwt_expander.py
+++ b/stanza/models/mwt_expander.py
@@ -61,6 +61,7 @@ def build_argparse():
 
     parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
     parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')
+    parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\'| to |"| |’| |ʼ|')
     parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)')
     parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help="Don't make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)")
 