From 7b6af76929ccfa1581eebd66175a53c7ec2ef604 Mon Sep 17 00:00:00 2001
From: Patrick Roberts
Date: Thu, 16 Jan 2025 18:32:24 +0000
Subject: [PATCH] #16646: Revert reciprocal to auto launch

---
 ttnn/cpp/ttnn/decorators.hpp | 9 +++---
 ttnn/cpp/ttnn/device_operation.hpp | 31 +++----------------
 .../operations/eltwise/complex/complex.cpp | 13 ++++++++
 .../operations/eltwise/complex/complex.hpp | 15 +++++++++
 .../eltwise/complex/complex_pybind.hpp | 2 +-
 .../ttnn/operations/eltwise/unary/unary.hpp | 7 +----
 6 files changed, 38 insertions(+), 39 deletions(-)

diff --git a/ttnn/cpp/ttnn/decorators.hpp b/ttnn/cpp/ttnn/decorators.hpp
index 9ae2c249fc1..4eac54c7443 100644
--- a/ttnn/cpp/ttnn/decorators.hpp
+++ b/ttnn/cpp/ttnn/decorators.hpp
@@ -131,11 +131,10 @@ auto map_execute_on_worker_thread_return_to_launch_op_return(const T&& value) {
     } else if constexpr (is_homogenous_tuple<T, Tensor>()) {
         Tensors output_tensors;
         output_tensors.reserve(std::tuple_size_v<T>);
-        std::apply(
-            [&output_tensors](auto&&... args) {
-                (output_tensors.emplace_back(std::forward<decltype(args)>(args)), ...);
-            },
-            value);
+        [&]<std::size_t... Ns>(std::index_sequence<Ns...>) {
+            using std::get;
+            (output_tensors.emplace_back(std::forward<decltype(get<Ns>(value))>(get<Ns>(value))), ...);
+        }(std::make_index_sequence<std::tuple_size_v<T>>{});
         return output_tensors;
     } else {
         static_assert(
diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp
index f18135658e3..bb63b9bb77f 100644
--- a/ttnn/cpp/ttnn/device_operation.hpp
+++ b/ttnn/cpp/ttnn/device_operation.hpp
@@ -436,33 +436,10 @@ typename device_operation_t::tensor_return_value_t launch_on_multi_device(
     std::vector<typename device_operation_t::tensor_return_value_t> outputs;
     outputs.reserve(num_shards);
 
-    bool launch_shards_in_parallel = false;
-    if (launch_shards_in_parallel) {
-        std::vector<std::future<typename device_operation_t::tensor_return_value_t>> shard_futures;
-        shard_futures.reserve(num_shards);
-
-        // Launch each shard
-        for (auto shard_index = 0; shard_index < num_shards; shard_index++) {
-            shard_futures.emplace_back(
-                std::async(
-                    std::launch::async,
-                    [cq_id, operation_attributes, tensor_args, shard_index, storage]() mutable {
-                        auto device = storage.get_buffer_for_device_id(shard_index)->device();
-                        auto shard_tensor_args = get_shard_tensor_args(shard_index, device, tensor_args);
-                        return launch_on_single_device<device_operation_t>(cq_id, operation_attributes, shard_tensor_args);
-                    }));
-        }
-
-        // Combine shards into a multi-device storage
-        for (auto& shard_future : shard_futures) {
-            outputs.push_back(shard_future.get());
-        }
-    } else {
-        for (auto shard_index = 0; shard_index < num_shards; shard_index++) {
-            auto device = storage.get_buffer_for_device_id(shard_index)->device();
-            auto shard_tensor_args = get_shard_tensor_args(shard_index, device, tensor_args);
-            outputs.push_back(launch_on_single_device<device_operation_t>(cq_id, operation_attributes, shard_tensor_args));
-        }
+    for (auto shard_index = 0; shard_index < num_shards; shard_index++) {
+        auto device = storage.get_buffer_for_device_id(shard_index)->device();
+        auto shard_tensor_args = get_shard_tensor_args(shard_index, device, tensor_args);
+        outputs.push_back(launch_on_single_device<device_operation_t>(cq_id, operation_attributes, shard_tensor_args));
     }
 
     return make_tensor_return_value_from_shards(storage, outputs);
diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp
index 7c1ddaf1362..95384f07b48 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp
@@ -7,6 +7,9 @@ namespace ttnn {
 
 namespace operations::complex {
 
+ComplexTensor::ComplexTensor(const std::tuple<Tensor, Tensor>& real_imag) :
+    m_real_imag{std::get<0>(real_imag), std::get<1>(real_imag)} {}
+
 const Tensor& ComplexTensor::operator[](uint32_t index) const { return m_real_imag[index]; }
 const Tensor& ComplexTensor::real() const { return m_real_imag[0]; }
 const Tensor& ComplexTensor::imag() const { return m_real_imag[1]; }
@@ -18,6 +21,16 @@ void ComplexTensor::deallocate() {
     m_real_imag[1].deallocate();
 }
 
+template <>
+const Tensor& get<0>(const ComplexTensor& complex) {
+    return complex.real();
+}
+
+template <>
+const Tensor& get<1>(const ComplexTensor& complex) {
+    return complex.imag();
+}
+
 ComplexTensor CreateComplexTensor::invoke(const Tensor& real, const Tensor& imag) {
     TT_ASSERT(real.padded_shape() == imag.padded_shape(), "Tensor shapes of real and imag should be identical");
     return ComplexTensor({real, imag});
diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp
index 839f18baca4..934c36c6b1c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp
@@ -17,12 +17,17 @@ namespace operations::complex {
 
 struct ComplexTensor {
     std::array<Tensor, 2> m_real_imag;
+    ComplexTensor(const std::tuple<Tensor, Tensor>& real_imag);
+
     const Tensor& operator[](uint32_t index) const;
     const Tensor& real() const;
     const Tensor& imag() const;
     void deallocate();
 };
 
+template <std::size_t I>
+const Tensor& get(const ComplexTensor&);
+
 struct CreateComplexTensor {
     static ComplexTensor invoke(const Tensor& input_tensor_a_arg, const Tensor& input_tensor_b_arg);
 };
@@ -35,3 +40,13 @@ constexpr auto complex_tensor =
     ttnn::register_operation<"ttnn::complex_tensor", operations::complex::CreateComplexTensor>();
 
 } // namespace ttnn
+
+template <>
+struct std::tuple_size<ttnn::operations::complex::ComplexTensor> {
+    static constexpr std::size_t value = 2;
+};
+
+template <std::size_t I>
+struct std::tuple_element<I, ttnn::operations::complex::ComplexTensor> {
+    using type = ttnn::Tensor;
+};
diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex/complex_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex/complex_pybind.hpp
index 82afc7fe614..9d251da3901 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/complex/complex_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/complex/complex_pybind.hpp
@@ -18,7 +18,7 @@ namespace detail {
 
 void bind_complex_tensor_type(py::module& m) {
     py::class_<ComplexTensor>(m, "ComplexTensor")
-        .def(py::init<std::array<Tensor, 2>>())
+        .def(py::init<std::tuple<Tensor, Tensor>>())
         .def_property_readonly("real", &ComplexTensor::real)
         .def_property_readonly("imag", &ComplexTensor::imag)
         .def("deallocate", &ComplexTensor::deallocate)
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
index 678af6f6f4e..5bffdd7e54c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
@@ -284,11 +284,6 @@ struct AsymmetricBinop {
         "ttnn::" #operation_name, \
         ttnn::operations::unary::ExecuteUnary<ttnn::operations::unary::UnaryOpType::operation_type>>();
 
-#define REGISTER_UNARY_OPERATION_OVERLOAD(operation_name, operation_type) \
-    constexpr auto operation_name = ttnn::register_operation< \
-        "ttnn::" #operation_name, \
-        ttnn::operations::unary::ExecuteUnary<ttnn::operations::unary::UnaryOpType::operation_type>>();
-
 #define REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(operation_name, operation_type) \
     constexpr auto operation_name = ttnn::register_operation_with_auto_launch_op< \
         "ttnn::" #operation_name, \
@@ -332,7 +327,7 @@ REGISTER_UNARY_OPERATION(logical_not, LOGICAL_NOT_UNARY);
 REGISTER_UNARY_OPERATION(ltz, LTZ);
 REGISTER_UNARY_OPERATION(neg, NEG);
 REGISTER_UNARY_OPERATION(nez, NEZ);
-REGISTER_UNARY_OPERATION_OVERLOAD(reciprocal, RECIP);
+REGISTER_UNARY_OPERATION(reciprocal, RECIP);
 REGISTER_UNARY_OPERATION(relu, RELU);
 REGISTER_UNARY_OPERATION(relu6, RELU6);
 REGISTER_UNARY_OPERATION(sigmoid, SIGMOID);
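
The std::tuple_size and std::tuple_element specializations, together with the free get<0>/get<1> overloads added above, opt ComplexTensor into the standard tuple protocol; that appears to be what lets the generic index_sequence fold in decorators.hpp (and structured bindings in general) take a ComplexTensor apart into its two Tensors. Below is a minimal, self-contained C++20 sketch of the same pattern, not code from this patch; Tensorish and Complexish are hypothetical stand-ins for ttnn::Tensor and ComplexTensor.

    // Sketch only: Tensorish and Complexish are made-up stand-ins, not ttnn types.
    #include <cstddef>
    #include <iostream>
    #include <tuple>
    #include <utility>

    struct Tensorish {
        int id;
    };

    struct Complexish {
        Tensorish parts[2];
        const Tensorish& real() const { return parts[0]; }
        const Tensorish& imag() const { return parts[1]; }
    };

    // Free get<I>, found by argument-dependent lookup, mirroring the patch's get<0>/get<1>.
    template <std::size_t I>
    const Tensorish& get(const Complexish& c) {
        return c.parts[I];
    }

    // Opt in to the tuple protocol, as the patch does for ComplexTensor.
    template <>
    struct std::tuple_size<Complexish> {
        static constexpr std::size_t value = 2;
    };

    template <std::size_t I>
    struct std::tuple_element<I, Complexish> {
        using type = Tensorish;
    };

    int main() {
        Complexish c{{Tensorish{1}, Tensorish{2}}};

        // Structured bindings now work through the tuple protocol.
        const auto& [re, im] = c;
        std::cout << re.id << " " << im.id << "\n";  // prints: 1 2

        // The same protocol drives an index_sequence fold like the one in decorators.hpp.
        [&]<std::size_t... Ns>(std::index_sequence<Ns...>) {
            using std::get;
            ((std::cout << get<Ns>(c).id << " "), ...);
        }(std::make_index_sequence<std::tuple_size_v<Complexish>>{});
        std::cout << "\n";
    }

Here the free get is picked up by argument-dependent lookup at the "using std::get;" call site, which is the same mechanism the fold in decorators.hpp relies on to find either std::get for ordinary tuples or the overloads this patch adds for ComplexTensor.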