#16646: Revert reciprocal to auto launch
Patrick Roberts authored and patrickroberts committed Jan 16, 2025
1 parent b97973e commit 7b6af76
Showing 6 changed files with 38 additions and 39 deletions.
9 changes: 4 additions & 5 deletions ttnn/cpp/ttnn/decorators.hpp
@@ -131,11 +131,10 @@ auto map_execute_on_worker_thread_return_to_launch_op_return(const T&& value) {
     } else if constexpr (is_homogenous_tuple<T, Tensor>()) {
         Tensors output_tensors;
         output_tensors.reserve(std::tuple_size_v<T>);
-        std::apply(
-            [&output_tensors](auto&&... args) {
-                (output_tensors.emplace_back(std::forward<decltype(args)>(args)), ...);
-            },
-            value);
+        [&]<std::size_t... Is>(std::index_sequence<Is...>) {
+            using std::get;
+            (output_tensors.emplace_back(std::forward<decltype(get<Is>(value))>(get<Is>(value))), ...);
+        }(std::make_index_sequence<std::tuple_size_v<T>>{});
         return output_tensors;
     } else {
         static_assert(
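The new body replaces std::apply with a C++20 templated lambda folded over an index_sequence, and brings in std::get via a using-declaration so that argument-dependent lookup can also find get overloads for tuple-like types that are not std::tuple (see the ComplexTensor changes below). A standalone sketch of the same pattern — to_vector and the std::string element type are illustrative stand-ins, not ttnn code:

    #include <cstddef>
    #include <string>
    #include <tuple>
    #include <type_traits>
    #include <utility>
    #include <vector>

    // Flatten any tuple-like value into a vector by folding over an index
    // sequence; `using std::get` plus an unqualified get<Is> call enables ADL,
    // so user-defined tuple-like types work too.
    template <class T>
    std::vector<std::string> to_vector(T&& value) {
        using tuple_t = std::remove_cvref_t<T>;
        std::vector<std::string> out;
        out.reserve(std::tuple_size_v<tuple_t>);
        [&]<std::size_t... Is>(std::index_sequence<Is...>) {
            using std::get;
            (out.emplace_back(get<Is>(value)), ...);  // left-to-right fold over elements
        }(std::make_index_sequence<std::tuple_size_v<tuple_t>>{});
        return out;
    }

    int main() {
        auto v = to_vector(std::make_tuple(std::string{"re"}, std::string{"im"}));
        return v.size() == 2 ? 0 : 1;  // exit 0 on success
    }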
31 changes: 4 additions & 27 deletions ttnn/cpp/ttnn/device_operation.hpp
@@ -436,33 +436,10 @@ typename device_operation_t::tensor_return_value_t launch_on_multi_device(
     std::vector<tensor_return_value_t> outputs;
     outputs.reserve(num_shards);

-    bool launch_shards_in_parallel = false;
-    if (launch_shards_in_parallel) {
-        std::vector<std::future<tensor_return_value_t>> shard_futures;
-        shard_futures.reserve(num_shards);
-
-        // Launch each shard
-        for (auto shard_index = 0; shard_index < num_shards; shard_index++) {
-            shard_futures.emplace_back(
-                std::async(
-                    std::launch::async,
-                    [cq_id, operation_attributes, tensor_args, shard_index, storage]() mutable {
-                        auto device = storage.get_buffer_for_device_id(shard_index)->device();
-                        auto shard_tensor_args = get_shard_tensor_args<device_operation_t>(shard_index, device, tensor_args);
-                        return launch_on_single_device<device_operation_t>(cq_id, operation_attributes, shard_tensor_args);
-                    }));
-        }
-
-        // Combine shards into a multi-device storage
-        for (auto& shard_future : shard_futures) {
-            outputs.push_back(shard_future.get());
-        }
-    } else {
-        for (auto shard_index = 0; shard_index < num_shards; shard_index++) {
-            auto device = storage.get_buffer_for_device_id(shard_index)->device();
-            auto shard_tensor_args = get_shard_tensor_args<device_operation_t>(shard_index, device, tensor_args);
-            outputs.push_back(launch_on_single_device<device_operation_t>(cq_id, operation_attributes, shard_tensor_args));
-        }
+    for (auto shard_index = 0; shard_index < num_shards; shard_index++) {
+        auto device = storage.get_buffer_for_device_id(shard_index)->device();
+        auto shard_tensor_args = get_shard_tensor_args<device_operation_t>(shard_index, device, tensor_args);
+        outputs.push_back(launch_on_single_device<device_operation_t>(cq_id, operation_attributes, shard_tensor_args));
     }

     return make_tensor_return_value_from_shards(storage, outputs);
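Note that launch_shards_in_parallel was hard-coded to false, so the deleted if-branch was dead code and keeping only the sequential loop is behavior-preserving. The deleted branch followed the usual fan-out/fan-in shape over std::async; a minimal self-contained sketch of that pattern, with a toy process_shard standing in for the per-shard device launch:

    #include <future>
    #include <vector>

    int process_shard(int shard_index) { return shard_index * 2; }  // stand-in for launch_on_single_device

    std::vector<int> launch_all(int num_shards, bool parallel) {
        std::vector<int> outputs;
        outputs.reserve(num_shards);
        if (parallel) {
            // Fan out: one asynchronous task per shard.
            std::vector<std::future<int>> futures;
            futures.reserve(num_shards);
            for (int i = 0; i < num_shards; ++i) {
                futures.emplace_back(std::async(std::launch::async, process_shard, i));
            }
            // Fan in: collect results in shard order.
            for (auto& f : futures) {
                outputs.push_back(f.get());
            }
        } else {
            for (int i = 0; i < num_shards; ++i) {
                outputs.push_back(process_shard(i));
            }
        }
        return outputs;
    }

    int main() {
        return launch_all(4, false) == launch_all(4, true) ? 0 : 1;  // both orders agree
    }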
13 changes: 13 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp
@@ -7,6 +7,9 @@
 namespace ttnn {
 namespace operations::complex {

+ComplexTensor::ComplexTensor(const std::tuple<const Tensor&, const Tensor&>& real_imag) :
+    m_real_imag{std::get<0>(real_imag), std::get<1>(real_imag)} {}
+
 const Tensor& ComplexTensor::operator[](uint32_t index) const { return m_real_imag[index]; }

 const Tensor& ComplexTensor::real() const { return m_real_imag[0]; }
@@ -18,6 +21,16 @@ void ComplexTensor::deallocate() {
     m_real_imag[1].deallocate();
 }

+template <>
+const Tensor& get<0>(const ComplexTensor& complex) {
+    return complex.real();
+}
+
+template <>
+const Tensor& get<1>(const ComplexTensor& complex) {
+    return complex.imag();
+}
+
 ComplexTensor CreateComplexTensor::invoke(const Tensor& real, const Tensor& imag) {
     TT_ASSERT(real.padded_shape() == imag.padded_shape(), "Tensor shapes of real and imag should be identical");
     return ComplexTensor({real, imag});
15 changes: 15 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp
@@ -17,12 +17,17 @@ namespace operations::complex {
 struct ComplexTensor {
     std::array<Tensor, 2> m_real_imag;

+    ComplexTensor(const std::tuple<const Tensor&, const Tensor&>& real_imag);
+
     const Tensor& operator[](uint32_t index) const;
     const Tensor& real() const;
     const Tensor& imag() const;
     void deallocate();
 };

+template <std::size_t I>
+const Tensor& get(const ComplexTensor&);
+
 struct CreateComplexTensor {
     static ComplexTensor invoke(const Tensor& input_tensor_a_arg, const Tensor& input_tensor_b_arg);
 };
@@ -35,3 +40,13 @@ constexpr auto complex_tensor =
     ttnn::register_operation<"ttnn::complex_tensor", operations::complex::CreateComplexTensor>();

 }  // namespace ttnn
+
+template <>
+struct std::tuple_size<ttnn::operations::complex::ComplexTensor> {
+    static constexpr std::size_t value = 2;
+};
+
+template <std::size_t I>
+struct std::tuple_element<I, ttnn::operations::complex::ComplexTensor> {
+    using type = ttnn::Tensor;
+};
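Together with the get<0>/get<1> overloads added in complex.cpp, these std::tuple_size and std::tuple_element specializations opt ComplexTensor into the C++17 tuple protocol — the same protocol the reworked fold in decorators.hpp consumes through ADL get, and the one that makes structured bindings work. A self-contained analogue, with Pair2 and double as stand-ins for ComplexTensor and Tensor:

    #include <array>
    #include <cstddef>
    #include <tuple>

    struct Pair2 {  // stand-in for ComplexTensor
        std::array<double, 2> parts;
        const double& real() const { return parts[0]; }
        const double& imag() const { return parts[1]; }
    };

    // Free get found by ADL, mirroring the overloads in complex.cpp.
    template <std::size_t I>
    const double& get(const Pair2& p) { return p.parts[I]; }

    // Tuple-protocol specializations, mirroring the ones added here.
    template <>
    struct std::tuple_size<Pair2> {
        static constexpr std::size_t value = 2;
    };

    template <std::size_t I>
    struct std::tuple_element<I, Pair2> {
        using type = double;
    };

    int main() {
        const Pair2 p{{3.0, 4.0}};
        const auto& [re, im] = p;  // structured bindings now work
        static_assert(std::tuple_size_v<Pair2> == 2);
        return (re == 3.0 && im == 4.0) ? 0 : 1;
    }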
2 changes: 1 addition & 1 deletion [file name not captured in this view]
@@ -18,7 +18,7 @@ namespace detail {

 void bind_complex_tensor_type(py::module& m) {
     py::class_<ComplexTensor>(m, "ComplexTensor")
-        .def(py::init<std::array<Tensor, 2>>())
+        .def(py::init<std::tuple<const Tensor&, const Tensor&>>())
         .def_property_readonly("real", &ComplexTensor::real)
         .def_property_readonly("imag", &ComplexTensor::imag)
         .def("deallocate", &ComplexTensor::deallocate)
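A plausible reading of this one-line change (an inference from the signatures, not anything stated in the diff): py::init<std::tuple<const Tensor&, const Tensor&>> routes Python construction through the new tuple-taking C++ constructor, so the Python-side class would accept a two-tuple of tensors, e.g. ComplexTensor((real, imag)), with pybind11's std::tuple caster performing the conversion.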
7 changes: 1 addition & 6 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
@@ -284,11 +284,6 @@ struct AsymmetricBinop {
"ttnn::" #operation_name, \
ttnn::operations::unary::ExecuteUnary<ttnn::operations::unary::UnaryOpType::operation_type>>();

#define REGISTER_UNARY_OPERATION_OVERLOAD(operation_name, operation_type) \
constexpr auto operation_name = ttnn::register_operation< \
"ttnn::" #operation_name, \
ttnn::operations::unary::ExecuteUnary<ttnn::operations::unary::UnaryOpType::operation_type>>();

#define REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(operation_name, operation_type) \
constexpr auto operation_name = ttnn::register_operation_with_auto_launch_op< \
"ttnn::" #operation_name, \
@@ -332,7 +327,7 @@ REGISTER_UNARY_OPERATION(logical_not, LOGICAL_NOT_UNARY);
 REGISTER_UNARY_OPERATION(ltz, LTZ);
 REGISTER_UNARY_OPERATION(neg, NEG);
 REGISTER_UNARY_OPERATION(nez, NEZ);
-REGISTER_UNARY_OPERATION_OVERLOAD(reciprocal, RECIP);
+REGISTER_UNARY_OPERATION(reciprocal, RECIP);
 REGISTER_UNARY_OPERATION(relu, RELU);
 REGISTER_UNARY_OPERATION(relu6, RELU6);
 REGISTER_UNARY_OPERATION(sigmoid, SIGMOID);
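The call-site change is the point of the commit: reciprocal moves from the deleted OVERLOAD macro (which, as its body above shows, registered via plain ttnn::register_operation) back to REGISTER_UNARY_OPERATION. The opening line of REGISTER_UNARY_OPERATION falls outside the visible hunk; judging from the commit title and the neighboring ..._WITH_FAST_AND_APPROXIMATE_MODE macro, it presumably uses register_operation_with_auto_launch_op, so the reciprocal line would now expand to roughly:

    // Hedged reconstruction: the head of REGISTER_UNARY_OPERATION is not visible
    // in this hunk; register_operation_with_auto_launch_op is assumed from the
    // commit title, not confirmed by the diff.
    constexpr auto reciprocal = ttnn::register_operation_with_auto_launch_op<
        "ttnn::reciprocal",
        ttnn::operations::unary::ExecuteUnary<ttnn::operations::unary::UnaryOpType::RECIP>>();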
