diff --git a/stanza/models/mwt/data.py b/stanza/models/mwt/data.py
index dba58ca3a7..1c67eea94a 100644
--- a/stanza/models/mwt/data.py
+++ b/stanza/models/mwt/data.py
@@ -3,6 +3,7 @@
 import os
 from collections import Counter
 import logging
+
 import torch
 
 import stanza.models.common.seq2seq_constant as constant
@@ -13,11 +14,20 @@
 
 logger = logging.getLogger('stanza')
 
+# enforce that the MWT splitter knows about a couple different alternate apostrophes
+# including covering some potential " typos
+# setting the augmentation to a very low value should be enough to teach it
+# about the unknown characters without messing up the predictions for other text
+#
+# 0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07
+APOS = ('"', "'", 'ʼ', 'ˊ', '՚', 'ߴ', '’', '＇')
+
 # TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS
 class DataLoader:
     def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
         self.batch_size = batch_size
         self.args = args
+        self.augment_apos = args.get('augment_apos', 0.0)
         self.evaluation = evaluation
         self.doc = doc
@@ -25,7 +35,11 @@ def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_u
 
         # handle vocab
         if vocab is None:
+            assert self.evaluation == False # for eval vocab must exist
             self.vocab = self.init_vocab(data)
+            if self.augment_apos > 0 and any(x in self.vocab for x in APOS):
+                for apos in APOS:
+                    self.vocab.add_unit(apos)
         elif expand_unk_vocab:
             self.vocab = DeltaVocab(data, vocab)
         else:
@@ -54,9 +68,21 @@ def init_vocab(self, data):
         vocab = Vocab(data, self.args['shorthand'])
         return vocab
 
+    def maybe_augment_apos(self, datum):
+        for original in APOS:
+            if original in datum[0]:
+                if random.uniform(0,1) < self.augment_apos:
+                    replacement = random.choice(APOS)
+                    datum = (datum[0].replace(original, replacement), datum[1].replace(original, replacement))
+                break
+        return datum
+
+
     def process(self, data):
         processed = []
         for d in data:
+            if not self.evaluation and self.augment_apos > 0:
+                d = self.maybe_augment_apos(d)
             src = list(d[0])
             src = [constant.SOS] + src + [constant.EOS]
             tgt_in, tgt_out = self.prepare_target(self.vocab, d)
diff --git a/stanza/models/mwt/vocab.py b/stanza/models/mwt/vocab.py
index 776a41c3af..0c861e7a49 100644
--- a/stanza/models/mwt/vocab.py
+++ b/stanza/models/mwt/vocab.py
@@ -11,3 +11,9 @@ def build_vocab(self):
         self._id2unit = constant.VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))
         self._unit2id = {w:i for i, w in enumerate(self._id2unit)}
+
+    def add_unit(self, unit):
+        if unit in self._unit2id:
+            return
+        self._unit2id[unit] = len(self._id2unit)
+        self._id2unit.append(unit)
 
diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py
index 8f68cfca66..87eadbb8a0 100644
--- a/stanza/models/mwt_expander.py
+++ b/stanza/models/mwt_expander.py
@@ -61,6 +61,7 @@ def build_argparse():
     parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
     parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')
+    parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\'| to |"| |’| |ʼ|')
    parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)')
    parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help="Don't make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)")