Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Add typing to stack_and_pad_tensors and fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
PetrochukM committed Sep 4, 2020
1 parent 93828fb commit db9d39d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions torchnlp/encoders/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchnlp.encoders.text.spacy_encoder import SpacyEncoder
from torchnlp.encoders.text.static_tokenizer_encoder import StaticTokenizerEncoder
from torchnlp.encoders.text.subword_encoder import SubwordEncoder
from torchnlp.encoders.text.text_encoder import BatchedSequences
from torchnlp.encoders.text.text_encoder import SequenceBatch
from torchnlp.encoders.text.text_encoder import pad_tensor
from torchnlp.encoders.text.text_encoder import stack_and_pad_tensors
from torchnlp.encoders.text.text_encoder import TextEncoder
Expand All @@ -28,5 +28,5 @@
'DEFAULT_RESERVED_TOKENS', 'DEFAULT_SOS_INDEX', 'DEFAULT_SOS_TOKEN', 'DEFAULT_UNKNOWN_INDEX',
'DEFAULT_UNKNOWN_TOKEN', 'DelimiterEncoder', 'MosesEncoder', 'pad_tensor',
'stack_and_pad_tensors', 'TextEncoder', 'SpacyEncoder', 'StaticTokenizerEncoder',
'SubwordEncoder', 'TreebankEncoder', 'WhitespaceEncoder', 'BatchedSequences'
'SubwordEncoder', 'TreebankEncoder', 'WhitespaceEncoder', 'SequenceBatch'
]
11 changes: 6 additions & 5 deletions torchnlp/encoders/text/text_encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import namedtuple
import typing

import torch

Expand All @@ -25,7 +25,9 @@ def pad_tensor(tensor, length, padding_index=DEFAULT_PADDING_INDEX):
return torch.cat((tensor, padding), dim=0)


BatchedSequences = namedtuple('BatchedSequences', ['tensor', 'lengths'])
# Named 2-tuple pairing a padded batch tensor with each sequence's
# pre-padding length (lengths come from ``tensor.shape[0]`` per item,
# see ``stack_and_pad_tensors``).
#   tensor:  the stacked-and-padded batch (torch.Tensor)
#   lengths: original length of every sequence before padding (torch.Tensor)
SequenceBatch = typing.NamedTuple(
    'SequenceBatch',
    [('tensor', torch.Tensor), ('lengths', torch.Tensor)],
)


def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):
Expand All @@ -37,8 +39,7 @@ def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):
dim (int, optional): Dimension on to which to concatenate the batch of tensors.
Returns
BatchedSequences(torch.Tensor, torch.Tensor): Padded tensors and original lengths of
tensors.
SequenceBatch: Padded tensors and original lengths of tensors.
"""
lengths = [tensor.shape[0] for tensor in batch]
max_len = max(lengths)
Expand All @@ -48,7 +49,7 @@ def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):
for _ in range(dim):
lengths = lengths.unsqueeze(0)

return BatchedSequences(padded, lengths)
return SequenceBatch(padded, lengths)


class TextEncoder(Encoder):
Expand Down
5 changes: 2 additions & 3 deletions torchnlp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
logger = logging.getLogger(__name__)


def _get_tensors(object_, seen=set()):
def _get_tensors(object_, seen=None):
if torch.is_tensor(object_):
return [object_]

elif isinstance(object_, (str, float, int)) or id(object_) in seen:
return []

Expand Down Expand Up @@ -42,7 +41,7 @@ def get_tensors(object_):
Returns:
(list of torch.tensor): List of tensors that are associated with ``object_``.
"""
return _get_tensors(object_)
return _get_tensors(object_, set())


def sampler_to_iterator(dataset, sampler):
Expand Down

0 comments on commit db9d39d

Please sign in to comment.