Linking tensor.reshape to ttnn.reshape (#16377)

### Ticket
#13745

### Problem description
tensor.reshape currently behaves as a view-style reshape; its behaviour should
match ttnn.reshape.

### What's changed
tensor.reshape is now linked to ttnn.reshape, and the previous view behaviour
is exposed as an experimental operation named view. This is the same change as
#15669, which was reverted.
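
A minimal sketch of the intended semantics (illustrative shapes only; assumes an open device handle, e.g. `device = ttnn.open_device(device_id=0)`):

```python
import torch
import ttnn

torch_input = torch.rand((1, 1, 64, 32), dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# tensor.reshape now dispatches to ttnn.reshape, so these two are equivalent;
# ttnn.reshape may pad or move data if the requested shape needs it.
a = ttnn.reshape(tt_input, (1, 1, 32, 64))
b = tt_input.reshape(1, 1, 32, 64)

# The old view-only behaviour is kept as an experimental op: it reinterprets
# the existing buffer without moving data, so the shape must be compatible.
c = ttnn.experimental.view(tt_input, 1, 1, 32, 64)
```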

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12635237980
- [x] T3K unit tests
https://github.com/tenstorrent/tt-metal/actions/runs/12635247860
- [x] Nightly model and ttnn tests
https://github.com/tenstorrent/tt-metal/actions/runs/12635313084
- [x] Single card demo tests
https://github.com/tenstorrent/tt-metal/actions/runs/12653046344
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
nardoTT authored Jan 16, 2025
1 parent c4198c1 commit b97973e
Showing 25 changed files with 399 additions and 176 deletions.
7 changes: 5 additions & 2 deletions tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp
@@ -14,6 +14,9 @@
#include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp"
#include <tt-metalium/host_api.hpp>

#include "ttnn/operations/functions.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/reshape/view.hpp"

using namespace tt;
using namespace tt_metal;
using namespace constants;
@@ -37,8 +40,8 @@ int main(int argc, char** argv) {
////////////////////////////////////////////////////////////////////////////
ttnn::SimpleShape shape{1, 32, 61, 32};
// Allocates a DRAM buffer on device populated with values specified by initialize
-Tensor a = ttnn::arange(/*start=*/0, /*stop=*/shape.volume(), /*step=*/1, DataType::BFLOAT16)
-    .reshape(shape)
+Tensor a = ttnn::experimental::view(
+    ttnn::arange(/*start=*/0, /*stop=*/shape.volume(), /*step=*/1, DataType::BFLOAT16), shape)
.to(device);
Tensor b = ttnn::tilize_with_zero_padding(a);
Tensor c = b.cpu();
7 changes: 5 additions & 2 deletions tests/tt_eager/tensors/test_async_tensor_apis.cpp
@@ -18,6 +18,8 @@
#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"

#include "ttnn/cpp/ttnn/operations/experimental/reshape/view.hpp"

namespace tt::tt_metal {
namespace {

@@ -58,7 +60,7 @@ TEST_F(DispatchFixture, TestTensorOwnershipSanity) {
},
host_tensor.get_storage());
// Send tensor to device, read it back and copy it to empty tensor initialized by main thread
-Tensor reshaped_tensor = host_tensor.reshape(ttnn::SimpleShape{1, 1, 32, 128});
+Tensor reshaped_tensor = ttnn::experimental::view(host_tensor, ttnn::SimpleShape{1, 1, 32, 128});
auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device);
auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR);
readback_tensor.set_storage(thread_local_tensor.get_storage());
@@ -285,7 +287,8 @@ TEST_F(DispatchFixture, TestTensorAsyncDataMovement) {
},
host_tensor.get_storage());

-Tensor reshaped_tensor = host_tensor.reshape(ttnn::SimpleShape{1, 1, 32, tensor_stop / 32});
+Tensor reshaped_tensor =
+    ttnn::experimental::view(host_tensor, ttnn::SimpleShape{1, 1, 32, tensor_stop / 32});
auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device);
auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR);
log_info(LogTest, "Worker populating empty host readback_tensor");
5 changes: 3 additions & 2 deletions tests/tt_eager/tensors/test_copy_and_move.cpp
@@ -11,6 +11,7 @@
#include "ttnn/tensor/tensor_impl.hpp"
#include <tt-metalium/host_api.hpp>
#include "ttnn/operations/functions.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/reshape/view.hpp"

using namespace tt;
using namespace tt_metal;
@@ -37,8 +38,8 @@ bool test_tensor_copy_semantics(IDevice* device) {
pass &= dev_a_data == dev_a_copy_data;

// host tensor updated with host tensor copy assignment
-Tensor host_c = ttnn::arange(/*start=*/0, /*stop=*/single_tile_shape.volume(), /*step=*/1)
-    .reshape(single_tile_shape)
+Tensor host_c = ttnn::experimental::view(
+    ttnn::arange(/*start=*/0, /*stop=*/single_tile_shape.volume(), /*step=*/1), single_tile_shape)
.to(Layout::TILE);
Tensor host_c_copy = ttnn::random::random(single_tile_shape).to(Layout::TILE);
host_c_copy = host_c;
@@ -25,6 +25,7 @@

#include <tt-metalium/mesh_device.hpp>
#include <tt-metalium/mesh_device_view.hpp>
#include "ttnn/cpp/ttnn/operations/experimental/reshape/view.hpp"

#include <tt-metalium/tile.hpp>

@@ -1659,10 +1660,12 @@ bool RunMultiInputReaderTestPropagateFullTensorIn(
TwoInputReaderKernelWriteMode test_writeback_mode) {
auto logical_shape = tensor_shape.logical_shape();
auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies<uint32_t>());
-Tensor input_tensor0 = ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout);
-Tensor input_tensor1 = ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout);
-Tensor output_tensor0 = ttnn::ones(tensor_shape, DataType::UINT32, layout).reshape(tensor_shape);
-Tensor output_tensor1 = ttnn::ones(tensor_shape, DataType::UINT32, layout).reshape(tensor_shape);
+Tensor input_tensor0 =
+    ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
+Tensor input_tensor1 =
+    ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
+Tensor output_tensor0 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape);
+Tensor output_tensor1 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape);
input_tensor0.set_tensor_spec(TensorSpec(
logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config)));
input_tensor1.set_tensor_spec(TensorSpec(
@@ -1956,10 +1959,14 @@ TEST(WorkerCclCommandProcessingKernelFabricUnicastMode, MultiInputReader_SingleP
MemoryConfig const out1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM);

auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies<uint32_t>());
-Tensor input_tensor0 = ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout);
-Tensor input_tensor1 = ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout);
-Tensor output_tensor0 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape);
-Tensor output_tensor1 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape);
+Tensor input_tensor0 =
+    ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
+Tensor input_tensor1 =
+    ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
+Tensor output_tensor0 =
+    ttnn::experimental::view(ttnn::ones(tensor_shape.value, DataType::UINT32, layout), tensor_shape);
+Tensor output_tensor1 =
+    ttnn::experimental::view(ttnn::ones(tensor_shape.value, DataType::UINT32, layout), tensor_shape);

input_tensor0.set_tensor_spec(TensorSpec(
logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config)));
@@ -2027,10 +2034,14 @@ TEST(WorkerCclCommandProcessingKernelFabricUnicastMode, MultiInputReader_SingleP
MemoryConfig const out1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM);

auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies<uint32_t>());
-Tensor input_tensor0 = ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout);
-Tensor input_tensor1 = ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout);
-Tensor output_tensor0 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape);
-Tensor output_tensor1 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape);
+Tensor input_tensor0 =
+    ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
+Tensor input_tensor1 =
+    ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
+Tensor output_tensor0 =
+    ttnn::experimental::view(ttnn::ones(tensor_shape.value, DataType::UINT32, layout), tensor_shape);
+Tensor output_tensor1 =
+    ttnn::experimental::view(ttnn::ones(tensor_shape.value, DataType::UINT32, layout), tensor_shape);

input_tensor0.set_tensor_spec(TensorSpec(
logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config)));
@@ -2102,10 +2113,14 @@ void RunFabricMcastFullTensorPropagateTest(
MemoryConfig const out1_memory_config = MemoryConfig(TensorMemoryLayout::INTERLEAVED, BufferType::DRAM);

auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies<uint32_t>());
-Tensor input_tensor1 = ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout);
-Tensor input_tensor0 = ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout);
-Tensor output_tensor1 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape);
-Tensor output_tensor0 = ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape);
+Tensor input_tensor1 =
+    ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
+Tensor input_tensor0 =
+    ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout);
+Tensor output_tensor1 =
+    ttnn::experimental::view(ttnn::ones(tensor_shape.value, DataType::UINT32, layout), tensor_shape);
+Tensor output_tensor0 =
+    ttnn::experimental::view(ttnn::ones(tensor_shape.value, DataType::UINT32, layout), tensor_shape);
input_tensor0.set_tensor_spec(TensorSpec(
logical_shape, TensorLayout(DataType::UINT32, PageConfig(layout, tt_metal::Tile()), in0_memory_config)));
input_tensor1.set_tensor_spec(TensorSpec(
@@ -2328,9 +2343,11 @@ bool RunPipelinedWorkersTest(
host_tensors.reserve(num_tensors);
device_tensors.reserve(num_tensors);
auto num_elems = std::reduce(logical_shape.cbegin(), logical_shape.cend(), 1, std::multiplies<uint32_t>());
-host_tensors.push_back(ttnn::arange(0, num_elems, 1, DataType::UINT32).reshape(tensor_shape).to(layout));
+host_tensors.push_back(
+    ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout));
for (size_t i = 1; i < num_tensors; ++i) {
-host_tensors.push_back(ttnn::ones(tensor_shape.value, DataType::UINT32, layout).reshape(tensor_shape));
+host_tensors.push_back(
+    ttnn::experimental::view(ttnn::ones(tensor_shape.value, DataType::UINT32, layout), tensor_shape));
}
TT_FATAL(mem_configs.size() == num_tensors, "Must have a memory config for each tensor");
for (size_t i = 0; i < num_tensors; i++) {
@@ -2844,7 +2861,7 @@ TEST(CclAsyncOp, ReduceScatterSmall_PersistentFabric) {
for (size_t i = 0; i < num_devices; i++) {
// host_input_tensors.push_back(ttnn::numpy::random::uniform(bfloat16(-1.0f), bfloat16(1.0f) ,
// {logical_shape[0],logical_shape[1],logical_shape[2],logical_shape[3]}, layout).to(devices[i]));
-auto t = ttnn::arange(0, num_elems, 1, DataType::BFLOAT16).reshape(input_shape).to(layout);
+auto t = ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::BFLOAT16), input_shape).to(layout);
t.set_tensor_spec(TensorSpec(
logical_shape, TensorLayout(DataType::BFLOAT16, PageConfig(layout, tt_metal::Tile()), in_memory_config)));

@@ -2955,7 +2972,7 @@ void run_all_gather_with_persistent_fabric(const size_t dim, const size_t num_li
size_t page_size = tile_size(DataFormat::Float16);
std::vector<Tensor> device_input_tensors;
for (size_t i = 0; i < num_devices; i++) {
-auto t = ttnn::arange(0, num_elems, 1).reshape(input_shape).to(layout);
+auto t = ttnn::experimental::view(ttnn::arange(0, num_elems, 1), input_shape).to(layout);
t.set_tensor_spec(TensorSpec(
logical_shape, TensorLayout(DataType::BFLOAT16, PageConfig(layout, tt_metal::Tile()), in_memory_config)));

4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/test_reshape.py
@@ -36,7 +36,7 @@ def test_reshape_sharded_rm(device, n, c, h, w):
torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=sharded_mem_config
)

-tt_output_tensor = tt_input_tensor.reshape(n, c, h * 2, w // 2)
+tt_output_tensor = ttnn.experimental.view(tt_input_tensor, n, c, h * 2, w // 2)

sharded_mem_config = ttnn.create_sharded_memory_config(
tt_output_tensor.shape,
@@ -473,7 +473,7 @@ def test_reshape_zero_element(input_shape, output_shape, layout, ttnn_reshape, u
if ttnn_reshape:
tt_output_tensor = ttnn.reshape(tt_input_tensor, output_shape)
else:
-tt_output_tensor = tt_input_tensor.reshape(output_shape)
+tt_output_tensor = ttnn.experimental.view(tt_input_tensor, output_shape)
tt_output_tensor = ttnn.from_device(tt_output_tensor)
tt_output_tensor = ttnn.to_torch(tt_output_tensor)
assert tt_output_tensor.shape == torch.Size(output_shape)
33 changes: 33 additions & 0 deletions tests/ttnn/unit_tests/test_reshape_transpose.py
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc


def test_transpose_with_reshape(device):
    # Create input tensor
    torch_input = torch.rand((1, 1, 2048, 512), dtype=torch.bfloat16)

    # TT operations
    tt_input = ttnn.from_torch(
        torch_input,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    tt_input = tt_input.reshape(1, 2048, 4, 128)
    tt_output = ttnn.transpose(tt_input, 1, 2)

    # Convert back to PyTorch for comparison
    tt_result = ttnn.to_torch(tt_output)

    # PyTorch reference operations
    torch_ref = torch_input.view(1, 2048, 4, 128)
    torch_ref = torch_ref.transpose(1, 2)

    # Compare results
    assert_with_pcc(torch_ref, tt_result, 0.9999)
3 changes: 3 additions & 0 deletions ttnn/CMakeLists.txt
@@ -582,6 +582,8 @@ set(TTNN_OP_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/dropout/device/dropout_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/dropout/dropout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/dropout/dropout_pybind.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/reshape/view.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/reshape/view_pybind.cpp
)

set(TTNN_SUBLIBRARIES
@@ -615,6 +617,7 @@ set(TTNN_SUBLIBRARIES
ttnn/operations/experimental/ssm
ttnn/operations/experimental/transformer
ttnn/operations/experimental/dropout
+ttnn/operations/experimental/reshape
ttnn/operations/full_like
ttnn/operations/full
ttnn/operations/index_fill
9 changes: 6 additions & 3 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -21,6 +21,9 @@
#include "ttnn/tensor/tensor_ops.hpp"
#include "tools/profiler/op_profiler.hpp"

#include "ttnn/common/constants.hpp"
#include "ttnn/operations/core/core.hpp"

using namespace tt::tt_metal;

namespace py = pybind11;
@@ -1657,7 +1660,7 @@ void pytensor_module(py::module& m_tensor) {
.def(
"reshape",
[](Tensor& self, int N, int C, int H, int W) {
-    return self.reshape(infer_dims_for_reshape(self, ttnn::SmallVector<int>{N, C, H, W}));
+    return ttnn::reshape(self, infer_dims_for_reshape(self, ttnn::SmallVector<int>{N, C, H, W}));
},
R"doc(
Reshapes TT tensor
@@ -1668,7 +1671,7 @@
)doc")
.def(
"reshape",
-    [](Tensor& self, const ttnn::Shape& shape) -> Tensor { return self.reshape(shape); },
+    [](Tensor& self, const ttnn::Shape& shape) -> Tensor { return ttnn::reshape(self, shape); },
R"doc(
Reshapes TT tensor
@@ -1679,7 +1682,7 @@
.def(
"reshape",
[](Tensor& self, const ttnn::SmallVector<int32_t>& shape) -> Tensor {
-    return self.reshape(infer_dims_for_reshape(self, shape));
+    return ttnn::reshape(self, infer_dims_for_reshape(self, shape));
},
R"doc(
Reshapes TT tensor
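
The three bindings above now route the Python-side Tensor.reshape through ttnn::reshape while still calling infer_dims_for_reshape, so dimension inference keeps working. A rough usage sketch (assuming PyTorch-style -1 inference, which is what infer_dims_for_reshape implies, and an open `device` handle):

```python
import torch
import ttnn

torch_input = torch.rand((1, 1, 32, 128), dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Both overloads now call ttnn::reshape under the hood:
r1 = tt_input.reshape(1, 1, 128, 32)    # (N, C, H, W) overload
r2 = tt_input.reshape([1, 1, 128, -1])  # vector overload; -1 is inferred
```
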
5 changes: 3 additions & 2 deletions ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -12,6 +12,7 @@
#include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
#include <tt-metalium/constants.hpp>
#include "cpp/ttnn/operations/experimental/reshape/view.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/types.hpp"

@@ -105,7 +106,7 @@ Tensor to_layout_impl(
SmallVector<uint32_t> new_padded_shape(2, 1);
new_padded_shape[1] = tensor.get_padded_shape()[-1];
new_padded_shape[0] = tensor.get_padded_shape()[-2];
-tensor = tensor.reshape(tensor.get_logical_shape(), SimpleShape(new_padded_shape));
+tensor = ttnn::experimental::view(tensor, tensor.get_logical_shape(), SimpleShape(new_padded_shape));
}
}

@@ -202,7 +203,7 @@ Tensor to_layout_impl(
tensor =
tensor.pad(ttnn::SimpleShape(padded_output_shape), ttnn::SimpleShape(std::move(padded_input_start)), 0);
tensor = device ? tensor.to(layout, device) : tensor.to(layout);
-return tensor.reshape(output_shape, padded_output_shape);
+return ttnn::experimental::view(tensor, output_shape, padded_output_shape);
} else {
TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout);
}
8 changes: 6 additions & 2 deletions ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
@@ -14,6 +14,8 @@
#include "cpp/ttnn/operations/data_movement/pad/pad.hpp"
#include <tt-metalium/constants.hpp>

#include "cpp/ttnn/operations/experimental/reshape/view.hpp"

#include "fold.hpp"

namespace ttnn::operations::data_movement {
@@ -216,7 +218,8 @@ std::vector<Tensor> fold_with_transpose_sharded_(
// reshape
n = tt_output_tensor.shape()[0], w = tt_output_tensor.shape()[1], c = tt_output_tensor.shape()[2],
h = tt_output_tensor.shape()[3];
-tt_output_tensor = tt_output_tensor.reshape(ttnn::SimpleShape{n, (w / stride_w), (c * stride_w), h});
+tt_output_tensor =
+    ttnn::experimental::view(tt_output_tensor, ttnn::SimpleShape{n, (w / stride_w), (c * stride_w), h});

tt::log_debug("reshape_hc_output: {}", tt_output_tensor.shape());

@@ -229,7 +232,8 @@
// reshape
n = tt_output_tensor.shape()[0], w = tt_output_tensor.shape()[1], h = tt_output_tensor.shape()[2],
c = tt_output_tensor.shape()[3];
-tt_output_tensor = tt_output_tensor.reshape(ttnn::SimpleShape{n, w, (h / stride_h), (c * stride_h)});
+tt_output_tensor =
+    ttnn::experimental::view(tt_output_tensor, ttnn::SimpleShape{n, w, (h / stride_h), (c * stride_h)});

tt::log_debug("reshape_hw_output: {}", tt_output_tensor.shape());

@@ -26,10 +26,10 @@ void ReshapeDeviceOperation::validate(const std::vector<Tensor>& input_tensors)

TT_FATAL(
input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"Reshape does not currently support sharding");
"Use ttnn::reshape for reshaping sharded inputs");
TT_FATAL(
this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
"Reshape does not currently support sharding");
"Reshape does not currently support sharding. Use ttnn::reshape for reshaping sharded inputs");

if (input_tensor_a.get_layout() == Layout::TILE) {
TT_FATAL(input_tensor_a.volume() % TILE_HW == 0, "Error");
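
The reworded asserts above direct sharded inputs to ttnn::reshape, which is the supported entry point for them rather than this device op. A sketch of that usage, mirroring the sharded test earlier in this diff (sharding parameters here are illustrative, and an open `device` is assumed):

```python
import torch
import ttnn

torch_input = torch.rand((16, 128, 128, 32), dtype=torch.bfloat16)
sharded_cfg = ttnn.create_sharded_memory_config(
    torch_input.shape,
    core_grid=ttnn.CoreGrid(y=8, x=8),
    strategy=ttnn.ShardStrategy.HEIGHT,
)
tt_input = ttnn.from_torch(
    torch_input, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=sharded_cfg
)

# The device Reshape op rejects sharded tensors; ttnn.reshape handles them:
out = ttnn.reshape(tt_input, (16, 128, 256, 16))
```
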
@@ -11,6 +11,8 @@
#include "ttnn/tensor/tensor_utils.hpp"
#include "device/reshape_op.hpp"

#include "cpp/ttnn/operations/experimental/reshape/view.hpp"

namespace ttnn::operations::data_movement {

namespace detail {
@@ -57,7 +59,7 @@ ttnn::Tensor ReshapeOperation::invoke(
padded_output_shape[3] == input_tensor.get_padded_shape()[3])) {
// Don't need to do a check here to see the H and W both divisible by 32
// since handled within the tensor reshape method
-    return input_tensor.reshape(output_shape);
+    return ttnn::experimental::view(input_tensor, output_shape);
}
if (input_tensor.get_padded_shape() == padded_output_shape) {
return ttnn::operations::experimental::auto_format::AutoFormat::move_tensor_to_mem_config(
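
In the invoke path above, when the requested shape leaves the padded last dimensions compatible with the input, ReshapeOperation falls back to ttnn::experimental::view, i.e. a zero-copy relabelling of the buffer. A sketch of a reshape that should take this fast path (illustrative shapes; assumes an open `device`):

```python
import torch
import ttnn

torch_input = torch.rand((1, 2, 32, 32), dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Only the leading dims are regrouped; the last two dims (and hence the tile
# padding) are unchanged, so this can resolve to a view of the same buffer.
out = ttnn.reshape(tt_input, (2, 1, 32, 32))
```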