diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.cpp b/tt-train/sources/ttml/core/tt_tensor_utils.cpp
index 9f808e151f1..b271c39a007 100644
--- a/tt-train/sources/ttml/core/tt_tensor_utils.cpp
+++ b/tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -326,4 +326,43 @@ template tt::tt_metal::Tensor from_xtensor(
     const XTensorToMeshVariant& composer,
     Layout layout);
 
+ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
+    auto logical_shape = t.get_logical_shape();
+    auto physical_shape = t.get_padded_shape();
+    auto t_rank = logical_shape.rank();
+    TT_FATAL(t_rank <= rank, "Cannot unsqueeze to rank {} from rank {}", rank, t_rank);
+
+    tt::tt_metal::SmallVector<uint32_t> result_logical_shape(rank);
+    tt::tt_metal::SmallVector<uint32_t> result_physical_shape(rank);
+    std::fill(result_logical_shape.begin(), result_logical_shape.end(), 1);
+    std::fill(result_physical_shape.begin(), result_physical_shape.end(), 1);
+
+    auto rank_diff = rank - t_rank;
+    std::copy(logical_shape.cbegin(), logical_shape.cend(), result_logical_shape.begin() + rank_diff);
+    std::copy(physical_shape.cbegin(), physical_shape.cend(), result_physical_shape.begin() + rank_diff);
+    return ttnn::reshape(t, ttnn::Shape{result_logical_shape}, ttnn::Shape{result_physical_shape});
+}
+
+ttnn::Tensor squeeze_to_rank(const ttnn::Tensor& t, size_t rank) {
+    auto logical_shape = t.get_logical_shape();
+    auto physical_shape = t.get_padded_shape();
+    auto t_rank = logical_shape.rank();
+    TT_FATAL(t_rank >= rank, "Cannot squeeze to rank {} from rank {}", rank, t_rank);
+
+    auto rank_diff = t_rank - rank;
+    bool leading_ones =
+        std::all_of(logical_shape.cbegin(), logical_shape.cbegin() + rank_diff, [](size_t dim) { return dim == 1; });
+    TT_FATAL(leading_ones, "Cannot squeeze shape {} to rank {}", logical_shape, rank);
+
+    tt::tt_metal::SmallVector<uint32_t> result_logical_shape(rank);
+    tt::tt_metal::SmallVector<uint32_t> result_physical_shape(rank);
+    std::fill(result_logical_shape.begin(), result_logical_shape.end(), 1);
+    std::fill(result_physical_shape.begin(), result_physical_shape.end(), 1);
+
+    std::copy(logical_shape.cbegin() + rank_diff, logical_shape.cend(), result_logical_shape.begin());
+    std::copy(physical_shape.cbegin() + rank_diff, physical_shape.cend(), result_physical_shape.begin());
+
+    return ttnn::reshape(t, ttnn::Shape{result_logical_shape}, ttnn::Shape{result_physical_shape});
+}
+
 }  // namespace ttml::core
diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.hpp b/tt-train/sources/ttml/core/tt_tensor_utils.hpp
index f3a2900e080..9815c7a3c2e 100644
--- a/tt-train/sources/ttml/core/tt_tensor_utils.hpp
+++ b/tt-train/sources/ttml/core/tt_tensor_utils.hpp
@@ -83,4 +83,10 @@ tt::tt_metal::Tensor from_xtensor(
     const XTensorToMeshVariant& composer,
     Layout layout = Layout::TILE);
 
+// Unsqueeze tensor to specified rank by adding leading dimensions of size 1
+ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank);
+
+// Squeeze tensor to specified rank by removing leading dimensions of size 1
+ttnn::Tensor squeeze_to_rank(const ttnn::Tensor& t, size_t rank);
+
 }  // namespace ttml::core
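A minimal usage sketch for the new rank helpers (the tensor `t` and the shapes below are illustrative, not from this PR):

    // t has logical shape [32, 64]
    auto t4 = ttml::core::unsqueeze_to_rank(t, 4);  // -> [1, 1, 32, 64]
    auto t2 = ttml::core::squeeze_to_rank(t4, 2);   // -> [32, 64]
    // squeeze_to_rank raises TT_FATAL if any dropped leading dim is not 1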
diff --git a/tt-train/sources/ttml/core/ttnn_all_includes.hpp b/tt-train/sources/ttml/core/ttnn_all_includes.hpp
index 0dc4a096ea8..e02c988a8c2 100644
--- a/tt-train/sources/ttml/core/ttnn_all_includes.hpp
+++ b/tt-train/sources/ttml/core/ttnn_all_includes.hpp
@@ -33,11 +33,14 @@
 #include  // NOLINT
 #include  // NOLINT
 #include  // NOLINT
+#include  // NOLINT
 #include  // NOLINT
 #include  // NOLINT
 #include  // NOLINT
 #include  // NOLINT
+#include  // NOLINT
 #include  // NOLINT
+#include  // NOLINT
 #include  // NOLINT
 #include  // NOLINT
 #include  // NOLINT
diff --git a/tt-train/sources/ttml/modules/rms_norm_module.cpp b/tt-train/sources/ttml/modules/rms_norm_module.cpp
new file mode 100644
index 00000000000..04f82a28c28
--- /dev/null
+++ b/tt-train/sources/ttml/modules/rms_norm_module.cpp
@@ -0,0 +1,28 @@
+// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "rms_norm_module.hpp"
+
+#include "core/tt_tensor_utils.hpp"
+#include "ops/rmsnorm_op.hpp"
+
+namespace ttml::modules {
+
+void RMSNormLayer::initialize_tensors(uint32_t features) {
+    m_gamma =
+        autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, features}), &autograd::ctx().get_device()));
+}
+
+RMSNormLayer::RMSNormLayer(uint32_t features, float epsilon) : m_epsilon(epsilon) {
+    initialize_tensors(features);
+
+    create_name("rmsnorm");
+    register_tensor(m_gamma, "gamma");
+}
+
+autograd::TensorPtr RMSNormLayer::operator()(const autograd::TensorPtr& tensor) {
+    return ops::rmsnorm(tensor, m_gamma, m_epsilon);
+}
+
+}  // namespace ttml::modules
diff --git a/tt-train/sources/ttml/modules/rms_norm_module.hpp b/tt-train/sources/ttml/modules/rms_norm_module.hpp
new file mode 100644
index 00000000000..721b3658c07
--- /dev/null
+++ b/tt-train/sources/ttml/modules/rms_norm_module.hpp
@@ -0,0 +1,27 @@
+// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "autograd/auto_context.hpp"
+#include "autograd/graph.hpp"
+#include "autograd/module_base.hpp"
+#include "autograd/tensor.hpp"
+#include "ops/rmsnorm_op.hpp"
+
+namespace ttml::modules {
+
+class RMSNormLayer : public autograd::ModuleBase {
+private:
+    float m_epsilon = 1e-5F;
+    autograd::TensorPtr m_gamma = nullptr;
+
+public:
+    void initialize_tensors(uint32_t features);
+    explicit RMSNormLayer(uint32_t features, float epsilon = 1e-5F);
+
+    [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
+};
+
+}  // namespace ttml::modules
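A quick usage sketch for the new module (feature count and input shape are illustrative):

    // hypothetical: normalize 384-wide activations; gamma is registered as a
    // trainable parameter named "gamma" under the module name "rmsnorm"
    ttml::modules::RMSNormLayer norm(/* features */ 384);
    auto y = norm(x);  // x: autograd::TensorPtr whose last dim is 384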
diff --git a/tt-train/sources/ttml/ops/binary_ops.cpp b/tt-train/sources/ttml/ops/binary_ops.cpp
index 511f3e8a0a5..90787111a9e 100644
--- a/tt-train/sources/ttml/ops/binary_ops.cpp
+++ b/tt-train/sources/ttml/ops/binary_ops.cpp
@@ -4,6 +4,7 @@
 
 #include "binary_ops.hpp"
 
+#include 
 #include 
 #include 
 #include 
@@ -86,7 +87,7 @@ autograd::TensorPtr operator-(const autograd::TensorPtr& a, const autograd::Tens
         b->add_grad(ttnn::neg(out->get_grad()));
     };
     auto links = autograd::get_links(a, b);
-
+
     out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
 
     return out;
@@ -102,6 +103,42 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::Tens
         auto a_grad = ttnn::multiply(out->get_grad(), b->get_value());
         auto b_grad = ttnn::multiply(out->get_grad(), a->get_value());
 
+        auto clamp_to_rank = [](const ttnn::Tensor& tensor, size_t rank) {
+            auto tensor_rank = tensor.logical_shape().rank();
+            if (tensor_rank == rank) {
+                return tensor;
+            } else if (tensor_rank > rank) {
+                return ttml::core::squeeze_to_rank(tensor, rank);
+            } else {
+                return ttml::core::unsqueeze_to_rank(tensor, rank);
+            }
+        };
+
+        auto logical_suffixes_match = [](const ttnn::Tensor& a, const ttnn::Tensor& b) {
+            auto a_shape = a.get_logical_shape();
+            auto b_shape = b.get_logical_shape();
+
+            // Compare trailing dimensions backwards. Keep the bound signed so
+            // that -suffix_len does not wrap around to a huge unsigned value.
+            auto suffix_len = static_cast<int>(std::min(a_shape.size(), b_shape.size()));
+            for (int i = -1; i >= -suffix_len; i--) {
+                if (a_shape[i] != b_shape[i]) {
+                    return false;
+                }
+            }
+            return true;
+        };
+
+        // ttnn broadcasts inputs during the forward multiply, so a gradient can
+        // come back with a higher rank than the original value. When the trailing
+        // dims match, squeeze/unsqueeze it back to the value's rank before
+        // accumulating (e.g. a [1, 1, 32, 64] grad for a [32, 64] value).
+        if (a->get_value().logical_shape().rank() != a_grad.logical_shape().rank()) {
+            if (logical_suffixes_match(a->get_value(), a_grad)) {
+                a_grad = clamp_to_rank(a_grad, a->get_value().logical_shape().rank());
+            }
+        }
+
+        if (b->get_value().logical_shape().rank() != b_grad.logical_shape().rank()) {
+            if (logical_suffixes_match(b->get_value(), b_grad)) {
+                b_grad = clamp_to_rank(b_grad, b->get_value().logical_shape().rank());
+            }
+        }
+
         a->add_grad(a_grad);
         b->add_grad(b_grad);
     };
@@ -124,6 +161,14 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b) {
     return out;
 }
 
+autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b) {
+    return b * a;
+}
+
+autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b) {
+    return a * (1.F / b);
+}
+
 autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
     auto out = autograd::create_tensor();
 
@@ -155,12 +200,70 @@ autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr&
     return a * b;
 }
 
+autograd::TensorPtr mul(const autograd::TensorPtr& a, float b) {
+    return a * b;
+}
+
+autograd::TensorPtr mul(float a, const autograd::TensorPtr& b) {
+    return b * a;
+}
+
 autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
     return a / b;
 }
 
-autograd::TensorPtr mul(const autograd::TensorPtr& a, float b) {
-    return a * b;
+autograd::TensorPtr div(const autograd::TensorPtr& a, float b) {
+    return a / b;
+}
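+
+// ttnn_matmul below wraps ttnn::matmul with this project's matmul compute-kernel
+// config and forwards transpose_a/transpose_b, so callers can fold transposes
+// into the matmul call instead of materializing transposed tensors.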
+tt::tt_metal::Tensor ttnn_matmul(
+    const tt::tt_metal::Tensor& a, const tt::tt_metal::Tensor& b, bool transpose_a, bool transpose_b) {
+    return ttnn::matmul(
+        a,
+        b,
+        transpose_a,
+        transpose_b,
+        /* memory_config */ std::nullopt,
+        /* dtype */ std::nullopt,
+        /* program_config */ std::nullopt,
+        /* activation */ std::nullopt,
+        /* compute_kernel_config */ core::ComputeKernelConfig::matmul(),
+        // NOTE: I believe matmul will use the core grid for the device it ends
+        // up running on, but should confirm.
+        /* core_grid */ std::nullopt,
+        /* output_tile */ std::nullopt);
+}
+
+autograd::TensorPtr matmul(
+    const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b) {
+    auto out = autograd::create_tensor();
+    out->set_value(ttnn_matmul(a->get_value(), b->get_value(), transpose_a, transpose_b));
+
+    autograd::GradFunction grad = [a, b, out, transpose_a, transpose_b]() {
+        // For loss L and matmul C = op_a(A) * op_b(B), where op_x applies the
+        // optional transpose:
+        //   dL/d(op_a(A)) = dL/dC * op_b(B)^T
+        //   dL/d(op_b(B)) = op_a(A)^T * dL/dC
+        // When an input itself was transposed, the outer transpose is undone by
+        // swapping the operands of the gradient matmul.
+        auto grad_a = transpose_a ? ttnn_matmul(b->get_value(), out->get_grad(), transpose_b, true)
+                                  : ttnn_matmul(out->get_grad(), b->get_value(), false, !transpose_b);
+        auto grad_b = transpose_b ? ttnn_matmul(out->get_grad(), a->get_value(), true, transpose_a)
+                                  : ttnn_matmul(a->get_value(), out->get_grad(), !transpose_a, false);
+
+        a->add_grad(grad_a);
+        b->add_grad(grad_b);
+    };
+
+    auto links = autograd::get_links(a, b);
+    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
+
+    return out;
+}
 
 }  // namespace ttml::ops
diff --git a/tt-train/sources/ttml/ops/binary_ops.hpp b/tt-train/sources/ttml/ops/binary_ops.hpp
index 862e318f1a2..2a4def45b30 100644
--- a/tt-train/sources/ttml/ops/binary_ops.hpp
+++ b/tt-train/sources/ttml/ops/binary_ops.hpp
@@ -12,14 +12,21 @@ autograd::TensorPtr operator+(const autograd::TensorPtr& a, const autograd::Auto
 autograd::TensorPtr operator+(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b);
+autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b);
 autograd::TensorPtr operator-(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
+autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b);
 
 autograd::TensorPtr add(const autograd::TensorPtr& a, const autograd::AutocastTensor& b);
 autograd::TensorPtr add(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr sub(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr mul(const autograd::TensorPtr& a, float b);
+autograd::TensorPtr mul(float a, const autograd::TensorPtr& b);
 autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
+autograd::TensorPtr div(const autograd::TensorPtr& a, float b);
+
+autograd::TensorPtr matmul(
+    const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b);
 
 }  // namespace ttml::ops
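A short usage sketch for the new autograd matmul (the names `x` and `w` and their shapes are illustrative):

    // y = x * W^T, the usual linear-layer pattern; the grad lambda above routes
    // gradients to both x and w with the transpose flags taken into account
    auto y = ttml::ops::matmul(x, w, /* transpose_a */ false, /* transpose_b */ true);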
diff --git a/tt-train/sources/ttml/ops/rmsnorm_op.cpp b/tt-train/sources/ttml/ops/rmsnorm_op.cpp
new file mode 100644
index 00000000000..b83a288ca27
--- /dev/null
+++ b/tt-train/sources/ttml/ops/rmsnorm_op.cpp
@@ -0,0 +1,66 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "autograd/auto_context.hpp"
+#include "autograd/graph.hpp"
+#include "autograd/graph_utils.hpp"
+#include "autograd/tensor.hpp"
+#include "core/compute_kernel_config.hpp"
+#include "core/tt_tensor_utils.hpp"
+#include "layernorm_op.hpp"
+#include "ops/binary_ops.hpp"
+#include "ops/unary_ops.hpp"
+
+namespace ttml::ops {
+
+autograd::TensorPtr rmsnorm(const autograd::TensorPtr& tensor, const autograd::TensorPtr& gamma, float epsilon) {
+    auto device = &autograd::ctx().get_device();
+
+    ttnn::Tensor squares = ttnn::square(tensor->get_value());
+    std::array<uint32_t, 4> eps_shape = {1, 1, 1, 1};
+    ttnn::Tensor eps_tensor = core::from_vector({epsilon}, core::create_shape(eps_shape), device);
+    ttnn::Tensor mean_of_squares = ttnn::mean(squares);
+    ttnn::Tensor mean_of_squares_plus_epsilon = ttnn::experimental::add(mean_of_squares, eps_tensor);
+    ttnn::Tensor rms_eps = ttnn::sqrt(mean_of_squares_plus_epsilon);
+    ttnn::Tensor gamma_times_activations = ttnn::experimental::mul(gamma->get_value(), tensor->get_value());
+    ttnn::Tensor out_tensor = ttnn::experimental::div(gamma_times_activations, rms_eps);
+
+    auto out = autograd::create_tensor(out_tensor);
+
+    autograd::GradFunction grad = [tensor, gamma, out, eps_tensor]() {
+        auto a = tensor->get_value();
+        auto g = gamma->get_value();
+        auto dout = out->get_grad();
+
+        // Let tensor = {a_i | i = 0, 1, ..., n} and gamma = {g_i | i = 0, 1, ..., n}.
+        // Backward grads in terms of dout:
+        //   dL/da_i = dL/dout * eps * g_i / (eps + a_i^2)^(3/2)
+        //   dL/dg_i = dL/dout * a_i / sqrt(eps + a_i^2)
+        auto dtensor_divisor = ttnn::pow(ttnn::experimental::add(eps_tensor, ttnn::square(a)), 3.0F / 2.0F);
+        auto dtensor_dividend = ttnn::experimental::mul(ttnn::experimental::mul(dout, g), eps_tensor);
+        auto dtensor = ttnn::experimental::div(dtensor_dividend, dtensor_divisor);
+
+        auto dgamma_dividend = ttnn::experimental::mul(dout, a);
+        // using a^2 + eps for scalar add in ttnn
+        auto dgamma_divisor = ttnn::sqrt(ttnn::experimental::add(eps_tensor, ttnn::square(a)));
+        auto dgamma = ttnn::experimental::div(dgamma_dividend, dgamma_divisor);
+
+        tensor->add_grad(dtensor);
+        gamma->add_grad(dgamma);
+    };
+
+    auto links = autograd::get_links(tensor);
+    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
+
+    return out;
+}
+
+}  // namespace ttml::ops
diff --git a/tt-train/sources/ttml/ops/rmsnorm_op.hpp b/tt-train/sources/ttml/ops/rmsnorm_op.hpp
new file mode 100644
index 00000000000..34499b75b4b
--- /dev/null
+++ b/tt-train/sources/ttml/ops/rmsnorm_op.hpp
@@ -0,0 +1,12 @@
+// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "autograd/tensor.hpp"
+
+namespace ttml::ops {
+
+autograd::TensorPtr rmsnorm(const autograd::TensorPtr& tensor, const autograd::TensorPtr& gamma, float epsilon);
+
+}  // namespace ttml::ops
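For reference, the forward pass above computes the RMSNorm y_i = g_i * x_i / sqrt(mean(x^2) + eps). The backward comment treats each element independently (approximating mean(x^2) by x_i^2), so it differentiates x / sqrt(x^2 + eps); the derivative of that expression is eps / (x^2 + eps)^(3/2), which is where the 3/2 power in dtensor_divisor comes from.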
diff --git a/tt-train/sources/ttml/ops/unary_ops.cpp b/tt-train/sources/ttml/ops/unary_ops.cpp
index e2e76fb881c..22a0f5100a6 100644
--- a/tt-train/sources/ttml/ops/unary_ops.cpp
+++ b/tt-train/sources/ttml/ops/unary_ops.cpp
@@ -140,4 +140,36 @@ autograd::TensorPtr broadcast_batch(const autograd::TensorPtr& tensor, uint32_t
     return out;
 }
 
+autograd::TensorPtr sqrt(const autograd::TensorPtr& tensor) {
+    auto out = autograd::create_tensor();
+    auto sqrt_tensor = ttnn::sqrt(tensor->get_value());
+    out->set_value(sqrt_tensor);
+    // Capture by value: the gradient function runs after this frame returns, so
+    // capturing the locals by reference would dangle.
+    autograd::GradFunction grad = [tensor, out, sqrt_tensor]() {
+        // dL/dx = dL/d(sqrt(x)) * 1/(2*sqrt(x))
+        auto grad = ttnn::divide(out->get_grad(), ttnn::multiply(sqrt_tensor, 2.F));
+        tensor->add_grad(grad);
+    };
+    auto links = autograd::get_links(tensor);
+    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
+    return out;
+}
+
+autograd::TensorPtr sum(const autograd::TensorPtr& tensor) {
+    auto out = autograd::create_tensor();
+    out->set_value(ttml::ttnn_fixed::sum_moreh(tensor->get_value()));
+
+    autograd::GradFunction grad = [tensor, out]() {
+        // Distribute the gradient to each element in the original tensor.
+        auto in_shape = tensor->get_value().get_logical_shape();
+
+        auto unsqueezed_grad = ttml::core::unsqueeze_to_rank(out->get_grad(), in_shape.rank());
+        auto grad_broadcast = ttnn::repeat(unsqueezed_grad, in_shape);
+        tensor->add_grad(grad_broadcast);
+    };
+
+    auto links = autograd::get_links(tensor);
+    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
+    return out;
+}
+
 }  // namespace ttml::ops
diff --git a/tt-train/sources/ttml/ops/unary_ops.hpp b/tt-train/sources/ttml/ops/unary_ops.hpp
index 669ee04233b..33e964ffe3e 100644
--- a/tt-train/sources/ttml/ops/unary_ops.hpp
+++ b/tt-train/sources/ttml/ops/unary_ops.hpp
@@ -15,4 +15,6 @@ autograd::TensorPtr sum(const autograd::TensorPtr& tensor);
 autograd::TensorPtr broadcast_batch(const autograd::TensorPtr& tensor, uint32_t new_batch_dim);
 autograd::TensorPtr log_softmax(const autograd::TensorPtr& tensor, int dim);
 autograd::TensorPtr log_softmax_moreh(const autograd::TensorPtr& tensor, int dim);
+autograd::TensorPtr sqrt(const autograd::TensorPtr& tensor);
+
 }  // namespace ttml::ops
diff --git a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
index ad818f6040f..f28363f9bae 100644
--- a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
+++ b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
@@ -73,6 +73,14 @@ tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep
         /* device_compute_kernel_config */ core::ComputeKernelConfig::precise());
     return res;
 }
+
+// Overload supporting generic sum over multiple dimensions
+tt::tt_metal::Tensor sum_moreh(
+    const tt::tt_metal::Tensor& t, std::optional<ttnn::SmallVector<int64_t>> dims, bool keep_dim) {
+    tt::tt_metal::Tensor res =
+        ttnn::moreh_sum(t, dims, keep_dim, std::nullopt, std::nullopt, core::ComputeKernelConfig::precise());
+    return res;
+}
+
 tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
     return ttnn::sum(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise());
 }
diff --git a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
index c8a62d981bc..a736b304d3a 100644
--- a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
+++ b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
@@ -19,5 +19,7 @@ tt::tt_metal::Tensor mean_moreh(const tt::tt_metal::Tensor& t, int dim, bool kee
 tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
 
 tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
+tt::tt_metal::Tensor sum_moreh(
+    const tt::tt_metal::Tensor& t, std::optional<ttnn::SmallVector<int64_t>> dims = std::nullopt, bool keep_dim = false);
 tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
 }  // namespace ttml::ttnn_fixed
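A small behavioral sketch of the new autograd sum (the tensor `x` and its shape are illustrative):

    // x: autograd::TensorPtr of shape [2, 4]
    auto s = ttml::ops::sum(x);  // reduces over all dims via sum_moreh
    s->backward();
    // backward unsqueezes the incoming grad to rank 2 and repeats it to [2, 4],
    // so with dL/ds = 1, x->get_grad() is a [2, 4] tensor of ones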
diff --git a/tt-train/tests/ops/binary_ops_test.cpp b/tt-train/tests/ops/binary_ops_test.cpp
new file mode 100644
index 00000000000..0b8dd2bdf01
--- /dev/null
+++ b/tt-train/tests/ops/binary_ops_test.cpp
@@ -0,0 +1,69 @@
+// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "ops/binary_ops.hpp"
+
+#include 
+
+#include 
+
+#include "autograd/auto_context.hpp"
+#include "autograd/tensor.hpp"
+#include "core/tt_tensor_utils.hpp"
+#include "core/xtensor_utils.hpp"
+
+class BinaryOpsTest : public ::testing::Test {
+protected:
+    void SetUp() override {
+        ttml::autograd::ctx().open_device();
+    }
+
+    void TearDown() override {
+        ttml::autograd::ctx().close_device();
+    }
+};
+
+TEST_F(BinaryOpsTest, TensorAdd_Broadcasted) {
+    xt::xarray<float> a = {{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}};
+    xt::xarray<float> b = xt::xarray<float>::from_shape({1, 1, 1, 1});
+    b(0, 0, 0, 0) = 1.F;
+
+    auto a_tensor = ttml::autograd::create_tensor(ttml::core::from_xtensor(a, &ttml::autograd::ctx().get_device()));
+    auto b_tensor = ttml::autograd::create_tensor(ttml::core::from_xtensor(b, &ttml::autograd::ctx().get_device()));
+
+    auto result = ttml::ops::add(a_tensor, b_tensor);
+    auto result_xarray = ttml::core::to_xtensor(result->get_value());
+
+    auto expected_result = xt::xarray<float>{2.F, 3.F, 4.F, 5.F, 2.F, 3.F, 4.F, 5.F};
+
+    EXPECT_TRUE(xt::allclose(result_xarray, expected_result));
+}
+
+TEST_F(BinaryOpsTest, TensorMul_Eltwise) {
+    xt::xarray<float> a = {{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}};
+    xt::xarray<float> b = {{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}};
+
+    auto a_tensor = ttml::autograd::create_tensor(ttml::core::from_xtensor(a, &ttml::autograd::ctx().get_device()));
+    auto b_tensor = ttml::autograd::create_tensor(ttml::core::from_xtensor(b, &ttml::autograd::ctx().get_device()));
+
+    auto result = ttml::ops::mul(a_tensor, b_tensor);
+    auto result_xarray = ttml::core::to_xtensor(result->get_value());
+
+    auto expected_result = xt::xarray<float>{{1.F, 4.F, 9.F, 16.F, 1.F, 4.F, 9.F, 16.F}};
+
+    EXPECT_TRUE(xt::allclose(result_xarray, expected_result));
+}
+
+TEST_F(BinaryOpsTest, TensorDivByFloat) {
+    xt::xarray<float> a = {{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}};
+
+    auto a_tensor = ttml::autograd::create_tensor(ttml::core::from_xtensor(a, &ttml::autograd::ctx().get_device()));
+    float b = 2.F;
+    auto result = ttml::ops::div(a_tensor, b);
+    auto result_xarray = ttml::core::to_xtensor(result->get_value());
+
+    auto expected_result = xt::xarray<float>{{0.5F, 1.F, 1.5F, 2.F, 0.5F, 1.F, 1.5F, 2.F}};
+
+    EXPECT_TRUE(xt::allclose(result_xarray, expected_result));
+}
diff --git a/tt-train/tests/ops/rmsnorm_op_test.cpp b/tt-train/tests/ops/rmsnorm_op_test.cpp
new file mode 100644
index 00000000000..d7bd34c12b1
--- /dev/null
+++ b/tt-train/tests/ops/rmsnorm_op_test.cpp
@@ -0,0 +1,116 @@
+// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "ops/rmsnorm_op.hpp"
+
+#include 
+
+#include 
+#include 
+
+#include "autograd/auto_context.hpp"
+#include "autograd/tensor.hpp"
+#include "core/tt_tensor_utils.hpp"
+#include "ops/losses.hpp"
+#include "ops/unary_ops.hpp"
+
+class RMSNormOpTest : public ::testing::Test {
+protected:
+    void SetUp() override {
+        ttml::autograd::ctx().open_device();
+    }
+
+    void TearDown() override {
+        ttml::autograd::ctx().close_device();
+    }
+};
+
+// Forward and backward tests compare against results from PyTorch: for a test
+// tensor `x` of shape [N, C, H, W] we set x.requires_grad = True, compute
+// `x_norm_sum = torch.nn.functional.rms_norm(x, normalized_shape=x.shape[-1:]).sum()`,
+// and compute its gradient with respect to `x` as
+// `x_grad = torch.autograd.grad(x_norm_sum, x)[0]`. We then compare those values
+// with the forward output and gradient produced by the TTML RMSNorm op.
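+//
+// A hypothetical PyTorch reference snippet (not part of this repo) matching the
+// description above:
+//   x = torch.tensor(data, requires_grad=True)
+//   y = torch.nn.functional.rms_norm(x, normalized_shape=x.shape[-1:], eps=0.0)
+//   y.sum().backward()  # x.grad now holds the expected gradient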
+
+TEST_F(RMSNormOpTest, RMSNormOp_Forward) {
+    using namespace ttml;
+
+    xt::xarray<float> example_xtensor = {
+        {{{0.0037F, 1.0103F, -0.0769F, -1.0242F, 0.7413F}, {1.5342F, -1.4141F, 0.9436F, -0.3354F, 0.5814F}},
+         {{2.1460F, 0.7238F, -0.2614F, -0.0608F, 1.3787F}, {0.2094F, -1.3087F, -1.8958F, 0.6596F, -1.3471F}}},
+        {{{1.2607F, 1.7451F, -1.6049F, -0.0411F, -0.9609F}, {0.1918F, -1.2580F, -0.5534F, -0.3971F, -0.6368F}},
+         {{0.2271F, 0.0791F, 0.8026F, 0.4299F, 0.8505F}, {1.5362F, 0.9735F, 0.4186F, -1.4561F, 1.3001F}}}};
+
+    auto example_tensor = autograd::create_tensor(core::from_xtensor(example_xtensor, &autograd::ctx().get_device()));
+
+    uint32_t N = 2, C = 2, H = 2, W = 5;
+
+    auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, W}), &autograd::ctx().get_device()));
+
+    auto result = ops::rmsnorm(example_tensor, gamma, 0.0F);
+
+    // Compare result with torch
+    auto result_xtensor = core::to_xtensor(result->get_value());
+
+    auto expected_result = xt::xarray<float>(
+        {{{{0.0051F, 1.3943F, -0.1061F, -1.4135F, 1.0230F}, {1.4376F, -1.3251F, 0.8842F, -0.3143F, 0.5448F}},
+          {{1.8006F, 0.6073F, -0.2194F, -0.0510F, 1.1568F}, {0.1698F, -1.0614F, -1.5377F, 0.5350F, -1.0926F}}},
+         {{{0.9884F, 1.3681F, -1.2582F, -0.0322F, -0.7533F}, {0.2719F, -1.7830F, -0.7845F, -0.5629F, -0.9025F}},
+          {{0.4003F, 0.1393F, 1.4143F, 0.7576F, 1.4987F}, {1.2719F, 0.8061F, 0.3466F, -1.2056F, 1.0765F}}}});
+
+    EXPECT_TRUE(xt::allclose(result_xtensor, expected_result, 1e-4F));
+
+    // Take grad of sum of result with respect to example_tensor
+    auto sum_result = ttml::ops::sum(result);
+    sum_result->backward();
+    auto example_tensor_grad = core::to_xtensor(example_tensor->get_grad());
+
+    auto expected_example_tensor_grad = xt::xarray<float>(
+        {{{{1.3788, 1.0326, 1.4065, 1.7323, 1.1251}, {0.6064, 1.2418, 0.7337, 1.0093, 0.8117}},
+          {{-0.1564, 0.5033, 0.9603, 0.8673, 0.1995}, {0.8934, 0.2968, 0.0660, 1.0703, 0.2817}}},
+         {{{0.7355, 0.7169, 0.8457, 0.7855, 0.8209}, {1.7072, -0.4837, 0.5811, 0.8173, 0.4551}},
+          {{1.1684, 1.5554, -0.3364, 0.6381, -0.4617}, {0.3445, 0.5216, 0.6962, 1.2863, 0.4188}}}});
+    EXPECT_TRUE(xt::allclose(example_tensor_grad, expected_example_tensor_grad, 1e-4F));
+}
+
+TEST_F(RMSNormOpTest, RMSNormOp_Forward_Small) {
+    using namespace ttml;
+
+    xt::xarray<float> example_xtensor = {{1.F, 2.F, 3.F, 4.F, 1.F, 2.F, 3.F, 4.F}};
+    auto example_tensor = autograd::create_tensor(core::from_xtensor(example_xtensor, &autograd::ctx().get_device()));
+
+    uint32_t H = 1, W = 8;
+
+    auto gamma = autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, W}), &autograd::ctx().get_device()));
+    auto result = ops::rmsnorm(example_tensor, gamma, 0.0F);
+
+    // Compare result with torch
+    auto result_xtensor = core::to_xtensor(result->get_value());
+
+    xt::xarray<float> expected_result = {{0.3651F, 0.7303F, 1.0954F, 1.4606F, 0.3651F, 0.7303F, 1.0954F, 1.4606F}};
+
+    EXPECT_TRUE(xt::allclose(result_xtensor, expected_result, 1e-2F));
+
+    auto sum_result = ttml::ops::sum(result);
+    sum_result->backward();
+    auto example_tensor_grad = core::to_xtensor(example_tensor->get_grad());
+    auto expected_example_tensor_grad = xt::xarray<float>(
+        {{2.4343e-01F, 1.2172e-01F, 2.9802e-08F, -1.2172e-01F, 2.4343e-01F, 1.2172e-01F, 2.9802e-08F, -1.2172e-01F}});
+    EXPECT_TRUE(xt::allclose(example_tensor_grad, expected_example_tensor_grad, 3e-2F));
+}
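Sanity-checking the small case by hand: rms(x) = sqrt((1 + 4 + 9 + 16 + 1 + 4 + 9 + 16) / 8) = sqrt(7.5) ≈ 2.7386, so the first normalized value is 1 / 2.7386 ≈ 0.3651, matching expected_result above.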
diff --git a/tt-train/tests/ops/unary_ops_test.cpp b/tt-train/tests/ops/unary_ops_test.cpp
index 90c2afeac0d..3208a9a12a6 100644
--- a/tt-train/tests/ops/unary_ops_test.cpp
+++ b/tt-train/tests/ops/unary_ops_test.cpp
@@ -11,6 +11,7 @@
 #include "autograd/auto_context.hpp"
 #include "autograd/tensor.hpp"
 #include "core/tt_tensor_utils.hpp"
+#include "core/xtensor_utils.hpp"
 
 class UnaryOpsTest : public ::testing::Test {
 protected:
@@ -45,6 +46,43 @@ TEST_F(UnaryOpsTest, GlobalMean) {
     }
 }
 
+TEST_F(UnaryOpsTest, Sum) {
+    xt::xarray<float> test_vector = {{1.F, 2.F, 3.F, 4.F}, {1.F, 2.F, 3.F, 4.F}};
+    auto test_tensor_ptr =
+        ttml::autograd::create_tensor(ttml::core::from_xtensor(test_vector, &ttml::autograd::ctx().get_device()));
+
+    auto result = ttml::ops::sum(test_tensor_ptr);
+    auto result_vector = ttml::core::to_xtensor(result->get_value());
+
+    ASSERT_TRUE(xt::allclose(result_vector, xt::sum(test_vector), 1e-5F));
+
+    result->backward();
+    auto test_tensor_grad = ttml::core::to_xtensor(test_tensor_ptr->get_grad());
+
+    ASSERT_TRUE(xt::allclose(xt::ones_like(test_vector), test_tensor_grad, 1e-5F));
+}
+
+TEST_F(UnaryOpsTest, Sqrt) {
+    xt::xarray<float> test_vector = {{1.F, 2.F, 3.F, 4.F}, {1.F, 2.F, 3.F, 4.F}};
+    auto test_tensor_ptr =
+        ttml::autograd::create_tensor(ttml::core::from_xtensor(test_vector, &ttml::autograd::ctx().get_device()));
+
+    auto result = ttml::ops::sqrt(test_tensor_ptr);
+    auto result_vector = ttml::core::to_xtensor(result->get_value());
+
+    ASSERT_TRUE(xt::allclose(result_vector, xt::sqrt(test_vector), 1e-2F));
+
+    // FIXME(jaykru-tt): add grad test for sqrt
+    // result->backward();
+    // auto test_tensor_grad = ttml::core::to_xtensor(test_tensor_ptr->get_grad());
+    // ASSERT_TRUE(xt::allclose(xt::ones_like(test_vector), test_tensor_grad));
+}
+
 TEST_F(UnaryOpsTest, LogSoftmax) {
     auto* device = &ttml::autograd::ctx().get_device();
     std::vector<float> test_data = {-0.1F, -0.2F, -0.3F, -0.4F, 0.F, -0.2F, -0.3F, -0.4F};
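For the FIXME in the Sqrt test, a grad check could look like the following sketch (assuming backward seeds the output grad with ones, as the Sum test suggests; then dL/dx = 0.5 / sqrt(x)):

    result->backward();
    auto test_tensor_grad = ttml::core::to_xtensor(test_tensor_ptr->get_grad());
    ASSERT_TRUE(xt::allclose(0.5F / xt::sqrt(test_vector), test_tensor_grad, 1e-2F));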