Skip to content

Commit

Permalink
Add an apostrophe augmentation for a few different alternate or typo …
Browse files Browse the repository at this point in the history
…apostrophes

During training, the apostrophes in an example are occasionally replaced with one of the other forms

The overall effect of this mostly averages out to neutral, although
oddly the Marathi treebank is significantly improved.
These treebanks are for UD 2.15; unlisted treebanks were unchanged.
In the table, diff = orig - aug, so a negative diff means the augmentation improved the score.

                                   orig   aug  diff
UD_Ancient_Hebrew-PTNK            98.28 98.08  0.20
UD_Arabic-PADT                    99.44 99.47 -0.03
UD_Coptic-Scriptorium             93.56 93.29  0.27
UD_Galician-TreeGal               99.45 99.36  0.09
UD_Georgian-GLC                   99.92 99.91  0.01
UD_Hebrew-HTB                     98.12 98.17 -0.05
UD_Hebrew-IAHLTknesset            98.72 98.75 -0.03
UD_Hebrew-IAHLTwiki               98.69 98.73 -0.04
UD_Italian-Old                    99.79 99.8  -0.01
UD_Italian-ParlaMint              99.93 99.9   0.03
UD_Italian-TWITTIRO               99.57 99.65 -0.08
UD_Latin-UDante                   99.7  99.75 -0.05
UD_Ligurian-GLT                   99.39 99.61 -0.22
UD_Maghrebi_Arabic_French-Arabizi 97.98 98.19 -0.21
UD_Marathi-UFAL                   95.18 96.23 -1.05
UD_Scottish_Gaelic-ARCOSG         99.89 99.87  0.02
UD_Spanish-GSD                   100.0  99.99  0.01
UD_Tamil-TTB                      95.43 95.44 -0.01
UD_Turkish-BOUN                   99.87 99.85  0.02
UD_Turkish-IMST                   99.65 99.54  0.11
UD_Welsh-CCG                      99.86 99.83  0.03
UD_Wolof-WTB                      99.83 99.84 -0.01
  total                                       -1.00
  • Loading branch information
AngledLuffa committed Nov 30, 2024
1 parent 7bd171c commit dad0fb8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
26 changes: 26 additions & 0 deletions stanza/models/mwt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from collections import Counter
import logging

import torch

import stanza.models.common.seq2seq_constant as constant
Expand All @@ -13,19 +14,32 @@

logger = logging.getLogger('stanza')

# Apostrophe variants the MWT splitter should know about: the ASCII
# apostrophe, several alternate apostrophe characters, and the double
# quote, which covers some potential " typos.
# Setting the augmentation to a very low value should be enough to teach it
# about the unknown characters without messing up the predictions for other text
#
# The non-ASCII entries are written as escapes so that tools which normalize
# look-alike characters cannot silently turn them into a plain ' and
# break the literal.
APOS = (
    '"',       # U+0022 QUOTATION MARK
    "'",       # U+0027 APOSTROPHE
    '\u02BC',  # U+02BC MODIFIER LETTER APOSTROPHE
    '\u02CA',  # U+02CA MODIFIER LETTER ACUTE ACCENT
    '\u055A',  # U+055A ARMENIAN APOSTROPHE
    '\u07F4',  # U+07F4 NKO HIGH TONE APOSTROPHE
    '\u2019',  # U+2019 RIGHT SINGLE QUOTATION MARK
    '\uFF07',  # U+FF07 FULLWIDTH APOSTROPHE
)

# TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS
class DataLoader:
def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
self.batch_size = batch_size
self.args = args
self.augment_apos = args.get('augment_apos', 0.0)
self.evaluation = evaluation
self.doc = doc

data = self.load_doc(self.doc, evaluation=self.evaluation)

# handle vocab
if vocab is None:
assert self.evaluation == False # for eval vocab must exist
self.vocab = self.init_vocab(data)
if self.augment_apos > 0 and any(x in self.vocab for x in APOS):
for apos in APOS:
self.vocab.add_unit(apos)
elif expand_unk_vocab:
self.vocab = DeltaVocab(data, vocab)
else:
Expand Down Expand Up @@ -54,9 +68,21 @@ def init_vocab(self, data):
vocab = Vocab(data, self.args['shorthand'])
return vocab

def maybe_augment_apos(self, datum):
    """
    Occasionally swap the apostrophe used in a training item for another variant

    Only the first entry of APOS which occurs in datum[0] is considered.
    With probability self.augment_apos, every occurrence of that character
    in both datum[0] and datum[1] is replaced with a randomly chosen entry
    of APOS.  The random choice may pick the same character, in which case
    the text is unchanged.

    Returns datum itself when nothing is replaced, otherwise a new 2-tuple
    holding the rewritten strings.
    """
    text = datum[0]
    found = next((apos for apos in APOS if apos in text), None)
    if found is None:
        # no apostrophe variant in this item, nothing to augment
        return datum

    # one roll per item: if it fails, later APOS entries are not tried
    if random.uniform(0,1) < self.augment_apos:
        substitute = random.choice(APOS)
        return (text.replace(found, substitute), datum[1].replace(found, substitute))
    return datum


def process(self, data):
processed = []
for d in data:
if not self.evaluation and self.augment_apos > 0:
d = self.maybe_augment_apos(d)
src = list(d[0])
src = [constant.SOS] + src + [constant.EOS]
tgt_in, tgt_out = self.prepare_target(self.vocab, d)
Expand Down
6 changes: 6 additions & 0 deletions stanza/models/mwt/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ def build_vocab(self):

self._id2unit = constant.VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True))
self._unit2id = {w:i for i, w in enumerate(self._id2unit)}

def add_unit(self, unit):
    """
    Append a unit to the vocab unless it is already known

    A new unit receives the next free index, which is the current length
    of the id -> unit list.  Units already present keep their index.
    """
    if unit not in self._unit2id:
        next_id = len(self._id2unit)
        self._id2unit.append(unit)
        self._unit2id[unit] = next_id
1 change: 1 addition & 0 deletions stanza/models/mwt_expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def build_argparse():
parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')

parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\'| to |"| |’| |ʼ|')
parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)')
parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help="Don't make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)")

Expand Down

0 comments on commit dad0fb8

Please sign in to comment.