Skip to content

Commit

Permalink
Add an apostrophe augmentation for a couple different alternate or ty…
Browse files Browse the repository at this point in the history
…po apostrophes
  • Loading branch information
AngledLuffa committed Nov 29, 2024
1 parent ddaba93 commit 5cac28d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
20 changes: 20 additions & 0 deletions stanza/models/mwt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,30 @@

logger = logging.getLogger('stanza')

# Alternate apostrophe characters the MWT splitter should know about:
#   '"'  straight double quote (a plausible typo for an apostrophe)
#   '’'  right single quotation mark (U+2019)
#   'ʼ'  modifier letter apostrophe (U+02BC)
# At training time, maybe_augment_apos occasionally swaps the ASCII
# apostrophe for one of these.  Keeping the augmentation rate very low
# should be enough to teach the model about the otherwise unknown
# characters without messing up the predictions for other text.
APOS = ('"', '’', 'ʼ')

# TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS
class DataLoader:
def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
self.batch_size = batch_size
self.args = args
self.augment_apos = args.get('augment_apos', 0.0)
self.evaluation = evaluation
self.doc = doc

data = self.load_doc(self.doc, evaluation=self.evaluation)

# handle vocab
if vocab is None:
assert self.evaluation == False # for eval vocab must exist
self.vocab = self.init_vocab(data)
if self.augment_apos > 0 and "'" in self.vocab:
for apos in APOS:
self.vocab.add_unit(apos)
elif expand_unk_vocab:
self.vocab = DeltaVocab(data, vocab)
else:
Expand Down Expand Up @@ -54,9 +65,18 @@ def init_vocab(self, data):
vocab = Vocab(data, self.args['shorthand'])
return vocab

def maybe_augment_apos(self, datum):
    """Randomly swap ASCII apostrophes in ``datum`` for an alternate form.

    ``datum`` is indexed as a pair of strings.  If the first string has
    no ``'``, or the random draw is not below ``self.augment_apos``, the
    input object is returned untouched.  Otherwise one replacement is
    drawn from ``APOS`` and applied to every ``'`` in both strings, and a
    new 2-tuple is returned.
    """
    source_text = datum[0]
    # Only consume a random number when there is something to replace.
    if "'" not in source_text:
        return datum
    if not (random.uniform(0, 1) < self.augment_apos):
        return datum

    # The same alternate apostrophe is used for both strings so they
    # stay consistent with each other.
    alt_apos = random.choice(APOS)
    expansion_text = datum[1]
    return (
        source_text.replace("'", alt_apos),
        expansion_text.replace("'", alt_apos),
    )


def process(self, data):
processed = []
for d in data:
if not self.evaluation and self.augment_apos > 0:
d = self.maybe_augment_apos(d)
src = list(d[0])
src = [constant.SOS] + src + [constant.EOS]
tgt_in, tgt_out = self.prepare_target(self.vocab, d)
Expand Down
6 changes: 6 additions & 0 deletions stanza/models/mwt/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ def build_vocab(self):

self._id2unit = constant.VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))
self._unit2id = {w:i for i, w in enumerate(self._id2unit)}

def add_unit(self, unit):
    """Register ``unit`` in the vocab at the next free index.

    Does nothing if ``unit`` already has an id.  Otherwise the unit is
    appended to ``_id2unit`` and mapped to its position in ``_unit2id``.
    """
    if unit not in self._unit2id:
        next_index = len(self._id2unit)
        self._id2unit.append(unit)
        self._unit2id[unit] = next_index
1 change: 1 addition & 0 deletions stanza/models/mwt_expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def build_argparse():
parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')

parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\'| to |"| |’| |ʼ|')
parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)')
parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help="Don't make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)")

Expand Down

0 comments on commit 5cac28d

Please sign in to comment.