From dff295db562d6bc693d0e85c1ac0a8bfc717d778 Mon Sep 17 00:00:00 2001
From: Evan Smal
Date: Tue, 28 Jan 2025 22:27:51 +0000
Subject: [PATCH] Update memory config when using `view` op with height sharded tensors

This fixes a problem with ResNet50 (#17247) where the output memory config
was not being updated to match the new output width.
---
 .../operations/experimental/reshape/view.cpp | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp b/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp
index cc857ae0e15..e0666c09bda 100644
--- a/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp
@@ -13,18 +13,32 @@ namespace ttnn::operations::experimental::reshape { +static MemoryConfig infer_output_memory_config( + const MemoryConfig& input_memory_config, const ttnn::SimpleShape& output_logical_shape) { + if (input_memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + auto shard_spec = input_memory_config.shard_spec.value(); + shard_spec.shape[1] = output_logical_shape[-1]; // update output shard to match new shard width + return MemoryConfig{input_memory_config.memory_layout, input_memory_config.buffer_type, shard_spec}; + } else { + return input_memory_config; + } +} + Tensor tensor_reshape( const Tensor& input_tensor, const ttnn::SimpleShape& new_logical_shape, const ttnn::SimpleShape& new_padded_shape) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_logical_shape, new_padded_shape); + + const auto output_memory_config = infer_output_memory_config(input_tensor.memory_config(), new_logical_shape); auto new_spec = ttnn::TensorSpec( new_logical_shape, TensorLayout::fromPaddedShape( input_tensor.get_dtype(), input_tensor.get_tensor_spec().page_config(), - input_tensor.memory_config(), + output_memory_config, new_logical_shape, new_padded_shape)); + auto output = 
std::visit( [&input_tensor, &new_spec, &new_logical_shape, &new_padded_shape](auto&& storage) -> Tensor { using T = std::decay_t;