Skip to content

Commit

Permalink
Convert from words to numbers at inference time rather than ahead of …
Browse files Browse the repository at this point in the history
…time. Will be a tiny bit slower when training, but will allow for making edits per batch
  • Loading branch information
AngledLuffa committed Nov 29, 2024
1 parent ee8891c commit 7bd171c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions stanza/models/mwt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

logger = logging.getLogger('stanza')

# TODO: can wrap this in a Pytorch DataLoader, such as what was done for POS
class DataLoader:
def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
self.batch_size = batch_size
Expand All @@ -36,7 +37,6 @@ def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_u
data = random.sample(data, keep)
logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

data = self.preprocess(data)
# shuffle for training
if not self.evaluation:
indices = list(range(len(data)))
Expand All @@ -54,7 +54,7 @@ def init_vocab(self, data):
vocab = Vocab(data, self.args['shorthand'])
return vocab

def preprocess(self, data):
def process(self, data):
processed = []
for d in data:
src = list(d[0])
Expand Down Expand Up @@ -83,6 +83,7 @@ def __getitem__(self, key):
if key < 0 or key >= len(self.data):
raise IndexError
batch = self.data[key]
batch = self.process(batch)
batch_size = len(batch)
batch = list(zip(*batch))
assert len(batch) == 4
Expand Down

0 comments on commit 7bd171c

Please sign in to comment.