Skip to content

Commit

Permalink
#12662: pad generic reduce op input
Browse files Browse the repository at this point in the history
  • Loading branch information
bbradelTT committed Jan 20, 2025
1 parent 2ef3e06 commit e5113ad
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
1 change: 1 addition & 0 deletions tests/ttnn/unit_tests/operations/test_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def test_max_global(device, batch_size, h, w):
((32, 32, 32, 64), -4),
((2, 32, 32, 64), -3),
((32, 32, 64), -3),
((1, 2, 3, 4), -1),
],
)
@pytest.mark.parametrize("keepdim", [True, False])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/operations/data_movement/fill_pad/fill_pad.hpp"
#include "ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"
Expand Down Expand Up @@ -241,11 +242,14 @@ Tensor Reduce<reduce_type>::invoke(
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
float scalar) {
ttnn::SmallVector<int> dim = generate_reduce_dim(input_tensor_arg, dim_arg);
float pad_value = reduce_type == ReduceType::Max
? -std::numeric_limits<float>::infinity()
: (reduce_type == ReduceType::Min ? std::numeric_limits<float>::infinity() : 0);
Tensor input_tensor = ttnn::fill_implicit_tile_padding(input_tensor_arg, pad_value);
if constexpr (reduce_type == ReduceType::Std || reduce_type == ReduceType::Var) {
return std_var_impl<reduce_type>(input_tensor_arg, dim, keepdim, memory_config_arg, compute_kernel_config);
return std_var_impl<reduce_type>(input_tensor, dim, keepdim, memory_config_arg, compute_kernel_config);
}
return reduce_impl<reduce_type>(
input_tensor_arg, dim, keepdim, memory_config_arg, compute_kernel_config, scalar, true);
return reduce_impl<reduce_type>(input_tensor, dim, keepdim, memory_config_arg, compute_kernel_config, scalar, true);
}

template class Reduce<ReduceType::Sum>;
Expand Down

0 comments on commit e5113ad

Please sign in to comment.