diff --git a/torchnlp/nn/attention.py b/torchnlp/nn/attention.py index 286e7b8..769bf1c 100755 --- a/torchnlp/nn/attention.py +++ b/torchnlp/nn/attention.py @@ -61,9 +61,9 @@ def forward(self, query, context): query_len = context.size(1) if self.attention_type == "general": - query = query.view(batch_size * output_len, dimensions) + query = query.reshape(batch_size * output_len, dimensions) query = self.linear_in(query) - query = query.view(batch_size, output_len, dimensions) + query = query.reshape(batch_size, output_len, dimensions) # TODO: Include mask on PADDING_INDEX?