#16720: make generic reduce pass pad value to transpose
bbradelTT committed Jan 23, 2025
1 parent deaa872 commit 79c3a50
Showing 3 changed files with 17 additions and 17 deletions.
@@ -53,8 +53,6 @@ def test_min_max_for_dim_hw(device, use_program_cache, shape_dim, kind, layout):
     if kind == "max":
         value = x.max()
     elif kind == "min":
-        if N * C % 32 != 0:
-            pytest.skip("global min with Tensor dimension N*C not multiple of 32 is not supported at this time.")
         value = x.min()
     elif kind == "mean":
         value = x.mean()
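
This skip is no longer needed: with this change, implicit tile padding is filled with the reduction's identity element, so a global min over a tensor whose N*C is not a multiple of the 32-element tile dimension no longer picks up garbage from the padded region. A toy numpy model of the failure the skip guarded against (hypothetical sizes, not the real ttnn tile layout):

    import numpy as np

    # A 3-element logical row stored in a 32-wide padded buffer. A global
    # min over the raw buffer is only correct if the padding holds min's
    # identity, +inf.
    logical = np.array([5.0, 2.0, 9.0])

    good = np.full(32, np.inf)           # padding filled with +inf
    good[:3] = logical
    assert good.min() == logical.min()   # 2.0: padding cannot win

    bad = np.zeros(32)                   # zero padding corrupts the result
    bad[:3] = logical
    assert bad.min() == 0.0              # wrong answer: 0 < 2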
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_max.py
@@ -129,5 +129,5 @@ def test_max_dim(device, input_shape_and_dim, keepdim):
 
     output_tensor = ttnn.to_torch(output_tensor)
 
-    pcc = 0.999 if is_grayskull() else 0.9999
+    pcc = 0.9999
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
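
The Grayskull-specific tolerance is dropped in favor of a single threshold for all architectures. For reference, a minimal sketch of the Pearson-correlation check that assert_with_pcc is assumed to perform (assumed semantics, not the test-harness implementation):

    import numpy as np

    def pcc(expected: np.ndarray, actual: np.ndarray) -> float:
        # Pearson correlation coefficient over the flattened tensors;
        # 1.0 means a perfect linear match.
        return float(np.corrcoef(expected.ravel(), actual.ravel())[0, 1])

    a = np.random.rand(32, 32)
    assert pcc(a, a + 1e-6 * np.random.rand(32, 32)) > 0.9999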
30 changes: 16 additions & 14 deletions ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
@@ -50,6 +50,12 @@ ttnn::SmallVector<int> generate_reduce_dim(
     return dim;
 }
 
+float get_pad_value(ReduceType reduce_type) {
+    return reduce_type == ReduceType::Max
+               ? -std::numeric_limits<float>::infinity()
+               : (reduce_type == ReduceType::Min ? std::numeric_limits<float>::infinity() : 0);
+}
+
 template <ReduceType reduce_type>
 static Tensor reduce_impl(
     const Tensor& input_tensor_arg,
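
The new helper returns the identity element of each reduction, so a padded lane can never change the result. A Python mirror of the same logic (string tags stand in for the C++ ReduceType enum):

    import math

    # The pad value is the identity element of the reduction.
    def get_pad_value(reduce_type: str) -> float:
        if reduce_type == "max":
            return -math.inf    # max(x, -inf) == x
        if reduce_type == "min":
            return math.inf     # min(x, +inf) == x
        return 0.0              # sum-based reductions: x + 0 == x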
@@ -58,8 +64,7 @@ static Tensor reduce_impl(
     const std::optional<MemoryConfig>& memory_config_arg,
     const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
     float scalar,
-    bool reshape,
-    bool fill = true) {
+    bool reshape) {
     using ttnn::operations::experimental::auto_format::AutoFormat;
     auto input_shape = input_tensor_arg.get_logical_shape();
     auto rank = input_shape.size();
@@ -80,6 +85,7 @@ auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg);
     auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg);
 
     Tensor output_tensor;
+    float pad_value = get_pad_value(reduce_type);
     bool single_reduce_op = (dim.size() == 1 && (dim[0] == rank - 1 || dim[0] == rank - 2)) ||
                             (dim.size() == 2 && dim[1] == rank - 1 && dim[0] == rank - 2);
     if (!single_reduce_op) {
@@ -93,7 +99,7 @@ static Tensor reduce_impl(
             int adjusted_dim = offset + i_dim;
             int reduce_dim = adjusted_dim;
             if (transpose) {
-                output_tensor = ttnn::transpose(output_tensor, adjusted_dim, 2, memory_config);
+                output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config, pad_value);
                 reduce_dim = 2;
             }
             if (use_reduce_type) {
@@ -116,7 +122,7 @@
                     /*reshape=*/false);
             }
             if (transpose) {
-                output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config);
+                output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config, pad_value);
             }
         }
     }
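
Both transpose calls now forward pad_value. Transposing a tiled tensor moves implicit tile padding into the dimension that is reduced next, so the padding must be filled with the reduction's identity rather than left arbitrary. A 2D numpy toy of the idea (the real code works on padded 4D tiled tensors):

    import numpy as np

    # A 2x3 logical tensor in a 4x4 padded buffer. Transposing moves the
    # former column padding into the axis that is reduced next; the result
    # stays correct only because the padding holds max's identity, -inf.
    logical = np.array([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])
    buf = np.full((4, 4), -np.inf)       # pad with max's identity
    buf[:2, :3] = logical
    rowwise = buf.T.max(axis=1)[:3]      # reduce the transposed buffer
    assert rowwise.tolist() == logical.T.max(axis=1).tolist()   # [4.0, 5.0, 6.0]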
@@ -150,11 +156,6 @@ static Tensor reduce_impl(
         for (int axis : dim) {
             reduced_volume *= input_shape[axis];
         }
-        float pad_value = reduce_type == ReduceType::Max
-                              ? -std::numeric_limits<float>::infinity()
-                              : (reduce_type == ReduceType::Min ? std::numeric_limits<float>::infinity() : 0);
-        bool is_tiled = input_tensor.get_layout() == TILE_LAYOUT;
-        input_tensor = fill && is_tiled ? ttnn::fill_implicit_tile_padding(input_tensor, pad_value) : input_tensor;
 
         if constexpr (reduce_type == ReduceType::Sum) {
             output_tensor = tt::tt_metal::reduce(
@@ -247,11 +248,13 @@ Tensor Reduce<reduce_type>::invoke(
     const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
     float scalar) {
     ttnn::SmallVector<int> dim = generate_reduce_dim(input_tensor_arg, dim_arg);
+    float pad_value = get_pad_value(reduce_type);
+    bool is_tiled = input_tensor_arg.get_layout() == TILE_LAYOUT;
+    auto input_tensor = is_tiled ? ttnn::fill_implicit_tile_padding(input_tensor_arg, pad_value) : input_tensor_arg;
     if constexpr (reduce_type == ReduceType::Std || reduce_type == ReduceType::Var) {
-        return std_var_impl<reduce_type>(input_tensor_arg, dim, keepdim, memory_config_arg, compute_kernel_config);
+        return std_var_impl<reduce_type>(input_tensor, dim, keepdim, memory_config_arg, compute_kernel_config);
     }
-    return reduce_impl<reduce_type>(
-        input_tensor_arg, dim, keepdim, memory_config_arg, compute_kernel_config, scalar, true);
+    return reduce_impl<reduce_type>(input_tensor, dim, keepdim, memory_config_arg, compute_kernel_config, scalar, true);
 }
 
 Tensor pool_sum(
@@ -267,8 +270,7 @@ Tensor pool_sum(
         memory_config_arg,
         compute_kernel_config,
         scalar,
-        /*reshape=*/true,
-        /*fill=*/false);
+        /*reshape=*/true);
 }
 
 template class Reduce<ReduceType::Sum>;
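
The padding fill that used to sit inside reduce_impl now runs once in Reduce<reduce_type>::invoke, ahead of both the std/var path and the generic path, which is why reduce_impl's fill parameter and pool_sum's /*fill=*/false escape hatch are gone. A minimal numpy sketch of the assumed fill semantics (not the ttnn::fill_implicit_tile_padding implementation):

    import numpy as np

    # Overwrite everything outside the logical region with the reduction's
    # identity before dispatching to any reduce path.
    def fill_implicit_tile_padding(buf, logical_shape, pad_value):
        mask = np.ones(buf.shape, dtype=bool)
        mask[tuple(slice(0, s) for s in logical_shape)] = False
        out = buf.copy()
        out[mask] = pad_value
        return out

    buf = np.zeros((4, 4))
    buf[:2, :3] = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    filled = fill_implicit_tile_padding(buf, (2, 3), np.inf)   # identity for min
    assert filled.min() == 1.0                                 # padding ignored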