diff --git a/torchnlp/encoders/text/__init__.py b/torchnlp/encoders/text/__init__.py
index 645314c..f220350 100755
--- a/torchnlp/encoders/text/__init__.py
+++ b/torchnlp/encoders/text/__init__.py
@@ -15,7 +15,7 @@
 from torchnlp.encoders.text.spacy_encoder import SpacyEncoder
 from torchnlp.encoders.text.static_tokenizer_encoder import StaticTokenizerEncoder
 from torchnlp.encoders.text.subword_encoder import SubwordEncoder
-from torchnlp.encoders.text.text_encoder import BatchedSequences
+from torchnlp.encoders.text.text_encoder import SequenceBatch
 from torchnlp.encoders.text.text_encoder import pad_tensor
 from torchnlp.encoders.text.text_encoder import stack_and_pad_tensors
 from torchnlp.encoders.text.text_encoder import TextEncoder
@@ -28,5 +28,5 @@
     'DEFAULT_RESERVED_TOKENS', 'DEFAULT_SOS_INDEX', 'DEFAULT_SOS_TOKEN', 'DEFAULT_UNKNOWN_INDEX',
     'DEFAULT_UNKNOWN_TOKEN', 'DelimiterEncoder', 'MosesEncoder', 'pad_tensor',
     'stack_and_pad_tensors', 'TextEncoder', 'SpacyEncoder', 'StaticTokenizerEncoder',
-    'SubwordEncoder', 'TreebankEncoder', 'WhitespaceEncoder', 'BatchedSequences'
+    'SubwordEncoder', 'TreebankEncoder', 'WhitespaceEncoder', 'SequenceBatch'
 ]
diff --git a/torchnlp/encoders/text/text_encoder.py b/torchnlp/encoders/text/text_encoder.py
index f06ec78..44fa6c6 100644
--- a/torchnlp/encoders/text/text_encoder.py
+++ b/torchnlp/encoders/text/text_encoder.py
@@ -1,4 +1,4 @@
-from collections import namedtuple
+import typing
 
 import torch
 
@@ -25,7 +25,9 @@ def pad_tensor(tensor, length, padding_index=DEFAULT_PADDING_INDEX):
     return torch.cat((tensor, padding), dim=0)
 
 
-BatchedSequences = namedtuple('BatchedSequences', ['tensor', 'lengths'])
+class SequenceBatch(typing.NamedTuple):
+    tensor: torch.Tensor
+    lengths: torch.Tensor
 
 
 def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):
@@ -37,8 +39,7 @@ def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):
         dim (int, optional): Dimension on to which to concatenate the batch of tensors.
 
     Returns
-        BatchedSequences(torch.Tensor, torch.Tensor): Padded tensors and original lengths of
-            tensors.
+        SequenceBatch: Padded tensors and original lengths of tensors.
     """
     lengths = [tensor.shape[0] for tensor in batch]
     max_len = max(lengths)
@@ -48,7 +49,7 @@ def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):
     for _ in range(dim):
         lengths = lengths.unsqueeze(0)
 
-    return BatchedSequences(padded, lengths)
+    return SequenceBatch(padded, lengths)
 
 
 class TextEncoder(Encoder):
diff --git a/torchnlp/utils.py b/torchnlp/utils.py
index 2b6c266..7b62137 100644
--- a/torchnlp/utils.py
+++ b/torchnlp/utils.py
@@ -7,10 +7,9 @@
 logger = logging.getLogger(__name__)
 
 
-def _get_tensors(object_, seen=set()):
+def _get_tensors(object_, seen=None):
     if torch.is_tensor(object_):
         return [object_]
-
     elif isinstance(object_, (str, float, int)) or id(object_) in seen:
         return []
 
@@ -42,7 +41,7 @@ def get_tensors(object_):
     Returns:
         (list of torch.tensor): List of tensors that are associated with ``object_``.
     """
-    return _get_tensors(object_)
+    return _get_tensors(object_, set())
 
 
 def sampler_to_iterator(dataset, sampler):
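
A minimal sketch of the renamed return type in use, assuming the default `padding_index` pads with zeros (the value of `DEFAULT_PADDING_INDEX` is not shown in this diff):

```python
import torch

from torchnlp.encoders.text import SequenceBatch, stack_and_pad_tensors

# Pad a ragged batch of token-index tensors up to the longest sequence.
batch = stack_and_pad_tensors([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])

# `SequenceBatch` is a `typing.NamedTuple`, so attribute access and tuple
# unpacking both keep working exactly as they did with the old namedtuple;
# the rename only affects code that referenced `BatchedSequences` by name.
assert isinstance(batch, SequenceBatch)
padded, lengths = batch
assert torch.equal(lengths, torch.tensor([3, 2]))
assert padded.shape == (2, 3)  # short row padded out to the max length
```

Beyond the rename, `typing.NamedTuple` carries real type annotations, so static checkers can see that `.tensor` and `.lengths` are `torch.Tensor`s, which the untyped `collections.namedtuple` could not express.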
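
The `torchnlp/utils.py` hunk fixes a classic mutable-default-argument bug: `seen=set()` is evaluated once, at `def` time, so the memo of visited object ids persisted across unrelated `get_tensors()` calls and could silently skip tensors already seen in an earlier call. The diff moves the fresh `set()` into the public wrapper. A standalone sketch of the pitfall and the same fix pattern (hypothetical `collect` functions, not torchnlp code):

```python
def collect_shared(x, seen=set()):
    # Anti-pattern: this set is created once and shared by every call
    # that omits `seen`, so results depend on earlier calls.
    if x in seen:
        return []
    seen.add(x)
    return [x]

def collect_fresh(x):
    # Pattern mirrored from the diff: the public entry point passes a
    # new set on every call; the helper threads it through recursion.
    return _collect(x, set())

def _collect(x, seen):
    if x in seen:
        return []
    seen.add(x)
    return [x]

print(collect_shared(1))  # [1]
print(collect_shared(1))  # [] -- state leaked from the first call
print(collect_fresh(1))   # [1]
print(collect_fresh(1))   # [1]
```

Note that `_get_tensors` keeps a `seen=None` default but never guards against it, so `seen` is effectively a required argument that `get_tensors` now always supplies.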