From e5113adf2d9a9ec316af4e6918d97b2867d7a5f0 Mon Sep 17 00:00:00 2001 From: Borys Bradel Date: Mon, 20 Jan 2025 23:40:18 +0000 Subject: [PATCH] #12662: pad generic reduce op input --- tests/ttnn/unit_tests/operations/test_max.py | 1 + .../reduction/generic/generic_reductions.cpp | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_max.py b/tests/ttnn/unit_tests/operations/test_max.py index 0a8518778532..411fbd0ab44d 100644 --- a/tests/ttnn/unit_tests/operations/test_max.py +++ b/tests/ttnn/unit_tests/operations/test_max.py @@ -98,6 +98,7 @@ def test_max_global(device, batch_size, h, w): ((32, 32, 32, 64), -4), ((2, 32, 32, 64), -3), ((32, 32, 64), -3), + ((1, 2, 3, 4), -1), ], ) @pytest.mark.parametrize("keepdim", [True, False]) diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp index 2efd56b801ae..b1611b1205ea 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/operations/reduction/generic/generic_reductions.hpp" +#include "ttnn/operations/data_movement/fill_pad/fill_pad.hpp" #include "ttnn/operations/data_movement/transpose/transpose.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/operations/eltwise/binary/binary_composite.hpp" @@ -241,11 +242,14 @@ Tensor Reduce::invoke( const std::optional& compute_kernel_config, float scalar) { ttnn::SmallVector dim = generate_reduce_dim(input_tensor_arg, dim_arg); + float pad_value = reduce_type == ReduceType::Max + ? -std::numeric_limits::infinity() + : (reduce_type == ReduceType::Min ? std::numeric_limits::infinity() : 0); + Tensor input_tensor = ttnn::fill_implicit_tile_padding(input_tensor_arg, pad_value); if constexpr (reduce_type == ReduceType::Std || reduce_type == ReduceType::Var) { - return std_var_impl(input_tensor_arg, dim, keepdim, memory_config_arg, compute_kernel_config); + return std_var_impl(input_tensor, dim, keepdim, memory_config_arg, compute_kernel_config); } - return reduce_impl( - input_tensor_arg, dim, keepdim, memory_config_arg, compute_kernel_config, scalar, true); + return reduce_impl(input_tensor, dim, keepdim, memory_config_arg, compute_kernel_config, scalar, true); } template class Reduce;