Skip to content

Commit

Permalink
Turn the magic ordering of the return value for the DataLoader into a…
Browse files Browse the repository at this point in the history
… namedtuple. It is not actually necessary to edit things downstream - callers that unpack the returned batch positionally should all still work, since a namedtuple is still a tuple.
  • Loading branch information
AngledLuffa committed Oct 23, 2023
1 parent db73192 commit 71253bc
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion stanza/models/pos/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import namedtuple
import random
import logging
import torch
Expand All @@ -11,6 +12,8 @@

logger = logging.getLogger('stanza')

# Named container for the 13 values returned by DataLoader.__getitem__.
# The fields are listed in the same positional order as the plain tuple this
# replaces, so existing callers that unpack a batch positionally keep working.
# Note: the `lens` field is populated from the local variable `sentlens`.
DataBatch = namedtuple("DataBatch", "words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text")

class DataLoader:
def __init__(self, doc, batch_size, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None):
self.batch_size = batch_size
Expand Down Expand Up @@ -125,7 +128,7 @@ def __getitem__(self, key):
pretrained = get_long_tensor(batch[5], batch_size)
text = batch[6]
sentlens = [len(x) for x in batch[0]]
return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, orig_idx, word_orig_idx, sentlens, word_lens, text
return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, orig_idx, word_orig_idx, sentlens, word_lens, text)

def __iter__(self):
for i in range(self.__len__()):
Expand Down

0 comments on commit 71253bc

Please sign in to comment.