#0: revert transpose changes for now
sjameelTT committed Dec 12, 2024
1 parent e6d86ea commit 2a2820b
Showing 1 changed file with 37 additions and 13 deletions.
ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp (37 additions & 13 deletions)
@@ -11,7 +11,6 @@
#include "ttnn/cpp/ttnn/operations/copy.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp"

// FIXME: ARCH_NAME specific include
#include "noc/noc_parameters.h" // DRAM_ALIGNMENT
@@ -57,14 +56,20 @@ inline Tensor transpose_(
TransposeOpDim transpose_dim,
const MemoryConfig& output_mem_config,
const std::optional<float>& pad_value) {
+ bool tiled_only = false;
+ constexpr uint32_t FACE_WIDTH =
+ tt::constants::FACE_WIDTH; // this is a highly restrictive constraint on the RM transpose_wh kernel, and with
+ // all the other bugs/limitations we should rewrite it
// use device->get_allocator_alignment when the it reflects the alignment of the buffer and doesn't just default to
// DRAM
auto BUFFER_ALIGNMENT = a.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? DRAM_ALIGNMENT : L1_ALIGNMENT;
uint32_t W = a.get_padded_shape()[-1];
uint32_t H = a.get_padded_shape()[-2];
switch (transpose_dim) {
case TransposeOpDim::HC:
- if ((a.get_layout() == Layout::ROW_MAJOR) && ((W * a.element_size()) % BUFFER_ALIGNMENT != 0)) { //
- return ttnn::prim::permute(
- (const ttnn::Tensor)a, ttnn::SmallVector<uint32_t>({0, 2, 1, 3}), output_mem_config, std::nullopt);
+ tiled_only = a.get_layout() == Layout::TILE;
+ if ((!tiled_only) && ((W * a.element_size()) % BUFFER_ALIGNMENT != 0)) { //
+ tiled_only = true;
}
break;
// bubble dim around to make it possible as these implementations don't have a kernel
@@ -78,20 +83,39 @@
return ttnn::permute(
(const ttnn::Tensor)a, ttnn::SmallVector<int64_t>({0, 3, 2, 1}), output_mem_config, pad_value);
case TransposeOpDim::CN:
- if (a.get_layout() == Layout::ROW_MAJOR) {
- return ttnn::prim::permute(
- (const ttnn::Tensor)a, ttnn::SmallVector<uint32_t>({1, 0, 2, 3}), output_mem_config, std::nullopt);
- }
+ tiled_only = true; // CN only has a tiled implementation at the moment
break;
- case TransposeOpDim::WH:
- if (!a.is_sharded() && a.layout() == Layout::ROW_MAJOR) {
- return ttnn::prim::permute(
- (const ttnn::Tensor)a, ttnn::SmallVector<uint32_t>({0, 1, 3, 2}), output_mem_config, std::nullopt);
+ case TransposeOpDim::WH: // THIS NEEDS TO BE FIXED
+ if (((W * a.element_size()) % FACE_WIDTH != 0) || ((H * a.element_size()) % FACE_WIDTH != 0)) {
+ tiled_only = true;
+ } else if (a.device()->arch() == tt::ARCH::GRAYSKULL) {
+ tiled_only = a.shape()[-2] > 256; // hangs right now past this dimension, #13660 will turn it from a
+ // hang into a PCC issue for GS and improve perf for WH
+ } else if (
+ !a.is_sharded() && a.layout() == Layout::ROW_MAJOR &&
+ !rm_enough_available_space(
+ a)) { // rm is L1 intensive, if it overflows we can do tiled which allocates much smaller CBs
+ tiled_only = true;
}
break;
default: break;
}
- return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
+ if (a.get_layout() == Layout::ROW_MAJOR) {
+ // the assorted cases where only tiled works right now (HC with stick width constraint, WH with stick width
+ // constraint, CN).
+ if (tiled_only) {
+ // convert to tiled
+ Tensor b = ttnn::to_layout(a, Layout::TILE, std::nullopt, std::nullopt, (Device*)nullptr);
+ // run the transpose.
+ b = operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {b}).at(0);
+ // back to original layout
+ b = ttnn::to_layout(b, a.get_layout(), std::nullopt, std::nullopt, (Device*)nullptr);
+ return b;
+ }
+ return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
+ } else {
+ return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0);
+ }
}

ttnn::Tensor transpose_nd(
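For reference, the row-major handling this revert restores follows a simple fallback pattern: when the requested transpose case only has a tiled kernel (tracked by the tiled_only flag above), round-trip the tensor through the tiled layout and convert back afterwards. Below is a minimal, self-contained sketch of that pattern; the types and helper names (Tensor, Layout, to_layout, run_transpose) are illustrative placeholders, not the ttnn API.

```cpp
// Minimal sketch of the tiled-fallback pattern used in transpose_ above.
// All names here are placeholders for illustration, not ttnn types.
#include <functional>

enum class Layout { ROW_MAJOR, TILE };

struct Tensor {
    Layout layout = Layout::ROW_MAJOR;
};

// Placeholder for a real layout conversion (ttnn::to_layout in the diff).
Tensor to_layout(Tensor t, Layout target) {
    t.layout = target;
    return t;
}

// If the requested case only has a tiled kernel, convert the row-major input
// to the tiled layout, run the op there, and convert back so the fallback is
// invisible to the caller; otherwise run the op directly on the input.
Tensor transpose_with_fallback(
    const Tensor& a, bool tiled_only, const std::function<Tensor(const Tensor&)>& run_transpose) {
    if (a.layout == Layout::ROW_MAJOR && tiled_only) {
        Tensor b = to_layout(a, Layout::TILE);
        b = run_transpose(b);
        return to_layout(b, Layout::ROW_MAJOR);
    }
    return run_transpose(a);
}
```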
