[tt-train] Add RMSNorm module #16991

Draft · wants to merge 20 commits into base: main
39 changes: 39 additions & 0 deletions tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -329,4 +329,43 @@ template tt::tt_metal::Tensor from_xtensor<uint32_t, DataType::UINT32>(
const XTensorToMeshVariant<uint32_t>& composer,
Layout layout);

ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
auto logical_shape = t.get_logical_shape();
auto physical_shape = t.get_padded_shape();
auto t_rank = logical_shape.rank();
TT_FATAL(t_rank <= rank, "Cannot unsqueeze to rank {} from rank {}", rank, t_rank);

tt::tt_metal::SmallVector<uint32_t> result_logical_shape(rank);
tt::tt_metal::SmallVector<uint32_t> result_physical_shape(rank);
std::fill(result_logical_shape.begin(), result_logical_shape.end(), 1);
std::fill(result_physical_shape.begin(), result_physical_shape.end(), 1);

auto rank_diff = rank - t_rank;
std::copy(logical_shape.cbegin(), logical_shape.cend(), result_logical_shape.begin() + rank_diff);
std::copy(physical_shape.cbegin(), physical_shape.cend(), result_physical_shape.begin() + rank_diff);
return ttnn::reshape(t, ttnn::Shape{result_logical_shape, result_physical_shape});
}

ttnn::Tensor squeeze_to_rank(const ttnn::Tensor& t, size_t rank) {
auto logical_shape = t.get_logical_shape();
auto physical_shape = t.get_padded_shape();
auto t_rank = logical_shape.rank();
TT_FATAL(t_rank >= rank, "Cannot squeeze to rank {} from rank {}", rank, t_rank);

auto rank_diff = t_rank - rank;
bool leading_ones =
std::all_of(logical_shape.cbegin(), logical_shape.cbegin() + rank_diff, [](size_t dim) { return dim == 1; });
TT_FATAL(leading_ones, "Cannot squeeze shape {} to rank {}", logical_shape, rank);

tt::tt_metal::SmallVector<uint32_t> result_logical_shape(rank);
tt::tt_metal::SmallVector<uint32_t> result_physical_shape(rank);
std::fill(result_logical_shape.begin(), result_logical_shape.end(), 1);
std::fill(result_physical_shape.begin(), result_physical_shape.end(), 1);

std::copy(logical_shape.cbegin() + rank_diff, logical_shape.cend(), result_logical_shape.begin());
std::copy(physical_shape.cbegin() + rank_diff, physical_shape.cend(), result_physical_shape.begin());

return ttnn::reshape(t, ttnn::Shape{result_logical_shape, result_physical_shape});
}

} // namespace ttml::core
6 changes: 6 additions & 0 deletions tt-train/sources/ttml/core/tt_tensor_utils.hpp
@@ -83,4 +83,10 @@ tt::tt_metal::Tensor from_xtensor(
const XTensorToMeshVariant<T>& composer,
Layout layout = Layout::TILE);

// Unsqueeze tensor to specified rank by adding leading dimensions of size 1
ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank);

// Squeeze tensor to specified rank by removing leading dimensions of size 1
ttnn::Tensor squeeze_to_rank(const ttnn::Tensor& t, size_t rank);

} // namespace ttml::core
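For context, a minimal usage sketch of the two new helpers under the semantics documented above (the tensor t and its shapes are illustrative, not taken from this PR):

    // t is assumed to be an existing ttnn::Tensor with logical shape [32, 64] (rank 2).
    auto t4 = ttml::core::unsqueeze_to_rank(t, 4);  // leading 1s added -> [1, 1, 32, 64]
    auto t2 = ttml::core::squeeze_to_rank(t4, 2);   // leading 1s removed -> [32, 64]
    // squeeze_to_rank asserts (TT_FATAL) if any leading dimension being dropped is not 1.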
21 changes: 11 additions & 10 deletions tt-train/sources/ttml/core/ttnn_all_includes.hpp
@@ -9,18 +9,18 @@
#pragma GCC diagnostic ignored "-Wdeprecated-volatile"
#pragma GCC diagnostic ignored "-Wdeprecated-this-capture"

#include <tt-metalium/bfloat16.hpp> // NOLINT
#include <tt-metalium/mesh_device_view.hpp> // NOLINT
#include <cpp/ttnn/operations/copy.hpp> // NOLINT
#include <cpp/ttnn/operations/core/core.hpp> // NOLINT
#include <cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax.hpp> // NOLINT
#include <cpp/ttnn/operations/moreh/moreh_softmax_backward/moreh_softmax_backward.hpp> // NOLINT
#include <hostdevcommon/common_values.hpp> // NOLINT
#include <tt-metalium/base_types.hpp> // NOLINT
#include <tt-metalium/math.hpp> // NOLINT
#include <tt-metalium/host_api.hpp> // NOLINT
#include <tt-metalium/device_impl.hpp> // NOLINT
#include <tt-metalium/base_types.hpp> // NOLINT
#include <tt-metalium/bfloat16.hpp> // NOLINT
#include <tt-metalium/device_impl.hpp> // NOLINT
#include <tt-metalium/host_api.hpp> // NOLINT
#include <tt-metalium/math.hpp> // NOLINT
#include <tt-metalium/mesh_device_view.hpp> // NOLINT
#include <ttnn/core.hpp> // NOLINT
#include <cpp/ttnn/operations/copy.hpp> // NOLINT
#include <cpp/ttnn/operations/core/core.hpp> // NOLINT
#include <cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax.hpp> // NOLINT
#include <cpp/ttnn/operations/moreh/moreh_softmax_backward/moreh_softmax_backward.hpp> // NOLINT
#include <ttnn/device.hpp> // NOLINT
#include <ttnn/distributed/api.hpp> // NOLINT
#include <ttnn/distributed/types.hpp> // NOLINT
@@ -33,6 +33,7 @@
#include <ttnn/operations/data_movement/permute/permute.hpp> // NOLINT
#include <ttnn/operations/data_movement/repeat/repeat.hpp> // NOLINT
#include <ttnn/operations/data_movement/slice/slice.hpp> // NOLINT
#include <ttnn/operations/data_movement/squeeze/squeeze.hpp> // NOLINT
#include <ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp> // NOLINT
#include <ttnn/operations/data_movement/transpose/transpose.hpp> // NOLINT
#include <ttnn/operations/data_movement/untilize/untilize.hpp> // NOLINT
28 changes: 28 additions & 0 deletions tt-train/sources/ttml/modules/rms_norm_module.cpp
@@ -0,0 +1,28 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "rms_norm_module.hpp"

#include "core/tt_tensor_utils.hpp"
#include "ops/rmsnorm_op.hpp"

namespace ttml::modules {

void RMSNormLayer::initialize_tensors(uint32_t features) {
m_gamma =
autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, features}), &autograd::ctx().get_device()));
}

RMSNormLayer::RMSNormLayer(uint32_t features, float epsilon) : m_epsilon(epsilon) {
initialize_tensors(features);

create_name("rmsnorm");
register_tensor(m_gamma, "gamma");
}

autograd::TensorPtr RMSNormLayer::operator()(const autograd::TensorPtr& tensor) {
return ops::rmsnorm(tensor, m_gamma, m_epsilon);
}

} // namespace ttml::modules
27 changes: 27 additions & 0 deletions tt-train/sources/ttml/modules/rms_norm_module.hpp
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/module_base.hpp"
#include "autograd/tensor.hpp"
#include "ops/rmsnorm_op.hpp"

namespace ttml::modules {

class RMSNormLayer : public autograd::ModuleBase {
private:
float m_epsilon = 1e-5F;
autograd::TensorPtr m_gamma = nullptr;

public:
void initialize_tensors(uint32_t features);
explicit RMSNormLayer(uint32_t features, float epsilon = 1e-5F);

[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
};

} // namespace ttml::modules
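For context, a minimal usage sketch of the new module (the feature count and the input x are illustrative; x is assumed to be an existing autograd::TensorPtr whose last dimension equals features):

    ttml::modules::RMSNormLayer rmsnorm(/* features */ 768, /* epsilon */ 1e-5F);
    auto y = rmsnorm(x);  // applies ops::rmsnorm with the learned gamma of shape {1, 1, 1, 768}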
107 changes: 105 additions & 2 deletions tt-train/sources/ttml/ops/binary_ops.cpp
@@ -4,6 +4,7 @@

#include "binary_ops.hpp"

#include <core/compute_kernel_config.hpp>
#include <core/ttnn_all_includes.hpp>
#include <memory>
#include <ttnn/operations/eltwise/binary/binary.hpp>
@@ -102,6 +103,42 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::Tens
auto a_grad = ttnn::multiply(out->get_grad(), b->get_value());
auto b_grad = ttnn::multiply(out->get_grad(), a->get_value());

auto clamp_to_rank = [](const ttnn::Tensor& tensor, size_t rank) {
auto tensor_rank = tensor.shape().rank();
if (tensor_rank == rank) {
return tensor;
} else if (tensor_rank > rank) {
return ttml::core::squeeze_to_rank(tensor, rank);
} else {
return ttml::core::unsqueeze_to_rank(tensor, rank);
}
};

auto logical_suffixes_match = [](const ttnn::Tensor& a, const ttnn::Tensor& b) {
auto a_shape = a.get_logical_shape();
auto b_shape = b.get_logical_shape();

auto suffix_len = std::min(a_shape.size(), b_shape.size());
Contributor comment: Isn't suffix_len an unsigned type? If yes, -suffix_len looks suspicious. Either way, I would advise casting it to signed here (even if size() happened to return a signed value today, that could change in the future).
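A minimal sketch of the signed-cast variant suggested above (illustrative only; it assumes the shape's operator[] accepts plain non-negative indices):

    auto logical_suffixes_match = [](const ttnn::Tensor& a, const ttnn::Tensor& b) {
        auto a_shape = a.get_logical_shape();
        auto b_shape = b.get_logical_shape();

        // Cast to a signed width so negating the bound cannot wrap around.
        auto suffix_len = static_cast<int64_t>(std::min(a_shape.size(), b_shape.size()));
        auto a_rank = static_cast<int64_t>(a_shape.size());
        auto b_rank = static_cast<int64_t>(b_shape.size());
        for (int64_t i = 1; i <= suffix_len; ++i) {
            // Compare the i-th dimension counted from the end of each shape.
            if (a_shape[a_rank - i] != b_shape[b_rank - i]) {
                return false;
            }
        }
        return true;
    };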

for (auto i = -1; i >= -suffix_len; i--) {
if (a_shape[i] != b_shape[i]) {
return false;
}
}
return true;
};

if (a->get_value().shape().rank() != a_grad.shape().rank()) {
if (logical_suffixes_match(a->get_value(), a_grad)) {
a_grad = clamp_to_rank(a_grad, a->get_value().shape().rank());
}
}

if (b->get_value().shape().rank() != b_grad.shape().rank()) {
if (logical_suffixes_match(b->get_value(), b_grad)) {
b_grad = clamp_to_rank(b_grad, b->get_value().shape().rank());
}
}

a->add_grad(a_grad);
b->add_grad(b_grad);
};
@@ -124,6 +161,14 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b) {
return out;
}

autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b) {
return b * a;
}

autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b) {
return a * (1.F / b);
}

autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
auto out = autograd::create_tensor();

@@ -155,12 +200,70 @@ autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr&
return a * b;
}

autograd::TensorPtr mul(const autograd::TensorPtr& a, float b) {
return a * b;
}

autograd::TensorPtr mul(float a, const autograd::TensorPtr& b) {
return b * a;
}

autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
return a / b;
}

autograd::TensorPtr mul(const autograd::TensorPtr& a, float b) {
return a * b;
autograd::TensorPtr div(const autograd::TensorPtr& a, float b) {
return a / b;
}

tt::tt_metal::Tensor ttnn_matmul(
const tt::tt_metal::Tensor& a, const tt::tt_metal::Tensor& b, bool transpose_a, bool transpose_b) {
return ttnn::matmul(
a,
b,
transpose_a,
transpose_b,
/* memory_config */ std::nullopt,
/* dtype */ std::nullopt,
/* program_config */ std::nullopt,
/* activation */ std::nullopt,
/* compute_kernel_config */ core::ComputeKernelConfig::matmul(),
/* core_grid */ std::nullopt, // NOTE: I believe matmul will use the
Contributor comment: Better to reuse our default grid parameters for now. Also, is this a copy-paste of the same function in the linear layer? If yes, you could probably move it somewhere shared to avoid the duplication.

// core grid for the device it ends up
// running on, but should confirm.
/* output_tile */ std::nullopt);
}

autograd::TensorPtr matmul(
Contributor comment: Why do we need matmul here? :)

const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b) {
auto out = autograd::create_tensor();
out->set_value(ttnn_matmul(a->get_value(), b->get_value(), transpose_a, transpose_b));

autograd::GradFunction grad = [a, b, out, transpose_a, transpose_b]() {
// For loss function L and matmul C = AB (transpose flags aside):
//   dL/dA = dL/dC * B^T
//   dL/dB = A^T * dL/dC
auto grad_a = ttnn_matmul(
out->get_grad(),
b->get_value(),
/* transpose_a */ transpose_a,
/* transpose_b */ !transpose_b);
auto grad_b = ttnn_matmul(
a->get_value(),
out->get_grad(),
/* transpose_a */ !transpose_a,
/* transpose_b */ transpose_b);

a->add_grad(grad_a);
b->add_grad(grad_b);
};

auto links = autograd::get_links(a, b);
out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));

return out;
}

} // namespace ttml::ops
7 changes: 7 additions & 0 deletions tt-train/sources/ttml/ops/binary_ops.hpp
@@ -12,14 +12,21 @@ autograd::TensorPtr operator+(const autograd::TensorPtr& a, const autograd::Auto
autograd::TensorPtr operator+(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b);
autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b);
autograd::TensorPtr operator-(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b);

autograd::TensorPtr add(const autograd::TensorPtr& a, const autograd::AutocastTensor& b);
autograd::TensorPtr add(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr sub(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr mul(const autograd::TensorPtr& a, float b);
autograd::TensorPtr mul(float a, const autograd::TensorPtr& b);
autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr div(const autograd::TensorPtr& a, float b);

autograd::TensorPtr matmul(
const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b);

} // namespace ttml::ops
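For context, a minimal sketch of the new overloads and of matmul as declared above (a and b are assumed to be existing autograd::TensorPtrs with compatible shapes):

    using namespace ttml::ops;  // bring the operator overloads into scope

    auto scaled = 0.5F * a;   // operator*(float, TensorPtr)
    auto halved = a / 2.0F;   // operator/(TensorPtr, float), implemented as a * (1.F / b)
    auto ratio  = div(a, b);  // same as a / b
    auto c = matmul(a, b, /* transpose_a */ false, /* transpose_b */ true);  // C = A * B^T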
37 changes: 37 additions & 0 deletions tt-train/sources/ttml/ops/rmsnorm_op.cpp
@@ -0,0 +1,37 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <core/ttnn_all_includes.hpp>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <optional>

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/graph_utils.hpp"
#include "autograd/tensor.hpp"
#include "core/compute_kernel_config.hpp"
#include "core/tt_tensor_utils.hpp"
#include "layernorm_op.hpp"
#include "ops/binary_ops.hpp"
#include "ops/unary_ops.hpp"

namespace ttml::ops {

autograd::TensorPtr rmsnorm(const autograd::TensorPtr& tensor, const autograd::TensorPtr& gamma, float epsilon) {
auto squares = tensor * tensor;
std::array<uint32_t, 4> eps_shape = {1, 1, 1, 1};
auto eps_tensor = autograd::create_tensor(
core::from_vector({epsilon}, core::create_shape(eps_shape), &autograd::ctx().get_device()));
auto mean_of_squares = ttml::ops::mean(squares);
auto mean_of_squares_plus_epsilon = mean_of_squares + eps_tensor;
auto rms_eps = ttml::ops::sqrt(mean_of_squares_plus_epsilon);
auto gamma_times_activations = gamma * tensor;
float rms_eps_value = core::to_xtensor(rms_eps->get_value())[0];
auto out = gamma_times_activations / rms_eps_value;
return out;
}

} // namespace ttml::ops
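For reference, this op follows the standard RMSNorm formulation,

    y = \gamma \odot \frac{x}{\sqrt{\operatorname{mean}(x^2) + \epsilon}},

with the caveat that the denominator is read back to the host as a single float (rms_eps_value), so the backward pass appears to treat the RMS term as a constant rather than differentiating through it.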
12 changes: 12 additions & 0 deletions tt-train/sources/ttml/ops/rmsnorm_op.hpp
@@ -0,0 +1,12 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "autograd/tensor.hpp"

namespace ttml::ops {

autograd::TensorPtr rmsnorm(const autograd::TensorPtr& tensor, const autograd::TensorPtr& gamma, float epsilon);

} // namespace ttml::ops
32 changes: 32 additions & 0 deletions tt-train/sources/ttml/ops/unary_ops.cpp
@@ -140,4 +140,36 @@ autograd::TensorPtr broadcast_batch(const autograd::TensorPtr& tensor, uint32_t
return out;
}

autograd::TensorPtr sqrt(const autograd::TensorPtr& tensor) {
auto out = autograd::create_tensor();
auto sqrt_tensor = ttnn::sqrt(tensor->get_value());
out->set_value(sqrt_tensor);
// Capture by value: the backward node outlives this stack frame, so reference captures would dangle.
autograd::GradFunction grad = [tensor, out, sqrt_tensor]() {
// dL/dx = dL/d(sqrt(x)) * 1/(2*sqrt(x))
auto grad = ttnn::divide(out->get_grad(), ttnn::multiply(sqrt_tensor, 2.F));
tensor->add_grad(grad);
};
auto links = autograd::get_links(tensor);
out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
return out;
}

autograd::TensorPtr sum(const autograd::TensorPtr& tensor) {
auto out = autograd::create_tensor();
out->set_value(ttml::ttnn_fixed::sum_moreh(tensor->get_value()));

autograd::GradFunction grad = [tensor, out]() {
// Distribute the gradient to each element in the original tensor
auto in_shape = tensor->get_value().get_logical_shape();
auto grad_shape = out->get_grad().get_logical_shape();

auto unsqueezed_grad = ttml::core::unsqueeze_to_rank(out->get_grad(), in_shape.rank());
auto grad_broadcast = ttnn::repeat(unsqueezed_grad, in_shape);
tensor->add_grad(grad_broadcast);
};

auto links = autograd::get_links(tensor);
out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
return out;
}
} // namespace ttml::ops
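For reference, the backward rule implemented above for the full reduction y = \sum_i x_i is \partial L / \partial x_i = \partial L / \partial y for every element i, which is why the scalar gradient is unsqueezed to the input rank and then repeated across the input shape.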