[tt-train] Add RMSNorm module #16991
base: main
Changes from all commits
13307ed
9ecfc43
1f9179d
e29bafd
7d38670
ee3da62
8eb9654
57c1dfe
b86e42e
9c7d8ac
5d0b6e2
a2c82a8
afc754f
03a5e73
3110486
0748e52
14f97b4
53ca8b5
753f4bc
2b6df8f
New file: RMSNorm module implementation

@@ -0,0 +1,28 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "rms_norm_module.hpp"

#include "core/tt_tensor_utils.hpp"
#include "ops/rmsnorm_op.hpp"

namespace ttml::modules {

void RMSNormLayer::initialize_tensors(uint32_t features) {
    m_gamma =
        autograd::create_tensor(core::ones(core::create_shape({1, 1, 1, features}), &autograd::ctx().get_device()));
}

RMSNormLayer::RMSNormLayer(uint32_t features, float epsilon) : m_epsilon(epsilon) {
    initialize_tensors(features);

    create_name("rmsnorm");
    register_tensor(m_gamma, "gamma");
}

autograd::TensorPtr RMSNormLayer::operator()(const autograd::TensorPtr& tensor) {
    return ops::rmsnorm(tensor, m_gamma, m_epsilon);
}

}  // namespace ttml::modules
New file: RMSNorm module header (rms_norm_module.hpp)

@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/module_base.hpp"
#include "autograd/tensor.hpp"
#include "ops/rmsnorm_op.hpp"

namespace ttml::modules {

class RMSNormLayer : public autograd::ModuleBase {
private:
    float m_epsilon = 1e-5F;
    autograd::TensorPtr m_gamma = nullptr;

public:
    void initialize_tensors(uint32_t features);
    explicit RMSNormLayer(uint32_t features, float epsilon = 1e-5F);

    [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& tensor);
};

}  // namespace ttml::modules
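
For context, a minimal usage sketch of the new module (not part of this PR; the include path and the `input` tensor are assumptions for illustration):

// Hypothetical usage sketch, not part of this PR. Assumes an existing autograd
// tensor `input` whose last dimension equals the configured feature count.
#include "modules/rms_norm_module.hpp"  // include path is an assumption

ttml::autograd::TensorPtr apply_rmsnorm_example(const ttml::autograd::TensorPtr& input) {
    ttml::modules::RMSNormLayer norm(/* features */ 512);  // default epsilon = 1e-5F
    return norm(input);  // forwards to ops::rmsnorm(input, gamma, epsilon)
}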
Modified file: binary ops implementation

@@ -4,6 +4,7 @@
#include "binary_ops.hpp"

#include <core/compute_kernel_config.hpp>
#include <core/ttnn_all_includes.hpp>
#include <memory>
#include <ttnn/operations/eltwise/binary/binary.hpp>
@@ -102,6 +103,42 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
        auto a_grad = ttnn::multiply(out->get_grad(), b->get_value());
        auto b_grad = ttnn::multiply(out->get_grad(), a->get_value());

        auto clamp_to_rank = [](const ttnn::Tensor& tensor, size_t rank) {
            auto tensor_rank = tensor.shape().rank();
            if (tensor_rank == rank) {
                return tensor;
            } else if (tensor_rank > rank) {
                return ttml::core::squeeze_to_rank(tensor, rank);
            } else {
                return ttml::core::unsqueeze_to_rank(tensor, rank);
            }
        };

        auto logical_suffixes_match = [](const ttnn::Tensor& a, const ttnn::Tensor& b) {
            auto a_shape = a.get_logical_shape();
            auto b_shape = b.get_logical_shape();

            auto suffix_len = std::min(a_shape.size(), b_shape.size());
            for (auto i = -1; i >= -suffix_len; i--) {
                if (a_shape[i] != b_shape[i]) {
                    return false;
                }
            }
            return true;
        };

        if (a->get_value().shape().rank() != a_grad.shape().rank()) {
            if (logical_suffixes_match(a->get_value(), a_grad)) {
                a_grad = clamp_to_rank(a_grad, a->get_value().shape().rank());
            }
        }

        if (b->get_value().shape().rank() != b_grad.shape().rank()) {
            if (logical_suffixes_match(b->get_value(), b_grad)) {
                b_grad = clamp_to_rank(b_grad, b->get_value().shape().rank());
            }
        }

        a->add_grad(a_grad);
        b->add_grad(b_grad);
    };
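
The two guards above only adjust the rank of a gradient (and only when its trailing logical dimensions match the operand's), relying on ttml::core::squeeze_to_rank / unsqueeze_to_rank. A standalone sketch of that rank-clamping idea over plain shape vectors (an illustrative stand-in, not part of this PR, assuming those helpers drop or prepend leading unit dimensions):

// Illustrative stand-in, not part of this PR: shapes as plain vectors.
// Reduce rank by dropping leading unit dims, or raise it by prepending unit
// dims, so a broadcast gradient can be accumulated into a lower-rank operand.
#include <cstddef>
#include <vector>

std::vector<std::size_t> clamp_shape_to_rank(std::vector<std::size_t> shape, std::size_t rank) {
    while (shape.size() > rank && shape.front() == 1) {
        shape.erase(shape.begin());  // drop a leading broadcast dimension
    }
    while (shape.size() < rank) {
        shape.insert(shape.begin(), 1);  // prepend a unit dimension
    }
    return shape;
}
// e.g. clamp_shape_to_rank({1, 32, 1, 128, 64}, 4) -> {32, 1, 128, 64}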
@@ -124,6 +161,14 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b) {
    return out;
}

autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b) {
    return b * a;
}

autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b) {
    return a * (1.F / b);
}

autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
    auto out = autograd::create_tensor();
@@ -155,12 +200,70 @@ autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
    return a * b;
}

autograd::TensorPtr mul(const autograd::TensorPtr& a, float b) {
    return a * b;
}

autograd::TensorPtr mul(float a, const autograd::TensorPtr& b) {
    return b * a;
}

autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
    return a / b;
}

autograd::TensorPtr div(const autograd::TensorPtr& a, float b) {
    return a / b;
}
tt::tt_metal::Tensor ttnn_matmul(
    const tt::tt_metal::Tensor& a, const tt::tt_metal::Tensor& b, bool transpose_a, bool transpose_b) {
    return ttnn::matmul(
        a,
        b,
        transpose_a,
        transpose_b,
        /* memory_config */ std::nullopt,
        /* dtype */ std::nullopt,
        /* program_config */ std::nullopt,
        /* activation */ std::nullopt,
        /* compute_kernel_config */ core::ComputeKernelConfig::matmul(),
        /* core_grid */ std::nullopt,  // NOTE: I believe matmul will use the
                                       // core grid for the device it ends up
                                       // running on, but should confirm.
        /* output_tile */ std::nullopt);
}

Reviewer comment (on the core_grid argument): Let's better reuse our default grid parameters for now. Also, is this a copy-paste of the same function in linear? If so, you can probably put it somewhere else to avoid the duplication.
Reviewer comment (on the matmul overload below): Why do we need matmul here? :)

autograd::TensorPtr matmul(
    const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b) {
    auto out = autograd::create_tensor();
    out->set_value(ttnn_matmul(a->get_value(), b->get_value(), transpose_a, transpose_b));

    autograd::GradFunction grad = [a, b, out, transpose_a, transpose_b]() {
        // For loss function L and matmul C = AB:
        // dL/dA = dL/dC * B^T
        // dL/dB = A^T * dL/dC
        auto grad_a = ttnn_matmul(
            out->get_grad(),
            b->get_value(),
            /* transpose_a */ transpose_a,
            /* transpose_b */ !transpose_b);
        auto grad_b = ttnn_matmul(
            a->get_value(),
            out->get_grad(),
            /* transpose_a */ !transpose_a,
            /* transpose_b */ transpose_b);

        a->add_grad(grad_a);
        b->add_grad(grad_b);
    };

    auto links = autograd::get_links(a, b);
    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));

    return out;
}

}  // namespace ttml::ops
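
For reference, the gradient formulas quoted in the backward comment follow from writing the matmul element-wise (shown for the untransposed case, which is what grad_a and grad_b compute when both transpose flags are false):

    C = A·B,  C_ij = sum_k A_ik · B_kj
    dL/dA_ik = sum_j (dL/dC)_ij · B_kj   =>  dL/dA = (dL/dC) · B^T
    dL/dB_kj = sum_i A_ik · (dL/dC)_ij   =>  dL/dB = A^T · (dL/dC)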
New file: rmsnorm op implementation

@@ -0,0 +1,37 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <core/ttnn_all_includes.hpp>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <optional>

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/graph_utils.hpp"
#include "autograd/tensor.hpp"
#include "core/compute_kernel_config.hpp"
#include "core/tt_tensor_utils.hpp"
#include "layernorm_op.hpp"
#include "ops/binary_ops.hpp"
#include "ops/unary_ops.hpp"

namespace ttml::ops {

autograd::TensorPtr rmsnorm(const autograd::TensorPtr& tensor, const autograd::TensorPtr& gamma, float epsilon) {
    auto squares = tensor * tensor;
    std::array<uint32_t, 4> eps_shape = {1, 1, 1, 1};
    auto eps_tensor = autograd::create_tensor(
        core::from_vector({epsilon}, core::create_shape(eps_shape), &autograd::ctx().get_device()));
    auto mean_of_squares = ttml::ops::mean(squares);
    auto mean_of_squares_plus_epsilon = mean_of_squares + eps_tensor;
    auto rms_eps = ttml::ops::sqrt(mean_of_squares_plus_epsilon);
    auto gamma_times_activations = gamma * tensor;
    float rms_eps_value = core::to_xtensor(rms_eps->get_value())[0];
    auto out = gamma_times_activations / rms_eps_value;
    return out;
}

}  // namespace ttml::ops
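
To make the computation above concrete, here is a minimal host-side reference of the value the op produces for a single feature vector, y_i = gamma_i · x_i / sqrt(mean(x^2) + epsilon) (a sketch for illustration only, not part of this PR):

// Reference sketch, not part of this PR: scalar RMSNorm over one vector,
// mirroring gamma * x / sqrt(mean(x^2) + epsilon) as computed by the op above.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> rmsnorm_reference(const std::vector<float>& x, const std::vector<float>& gamma, float epsilon) {
    float mean_sq = 0.0F;
    for (float v : x) {
        mean_sq += v * v;
    }
    mean_sq /= static_cast<float>(x.size());
    const float inv_rms = 1.0F / std::sqrt(mean_sq + epsilon);

    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = gamma[i] * x[i] * inv_rms;
    }
    return y;
}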
New file: rmsnorm op header (ops/rmsnorm_op.hpp)

@@ -0,0 +1,12 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "autograd/tensor.hpp"

namespace ttml::ops {

autograd::TensorPtr rmsnorm(const autograd::TensorPtr& tensor, const autograd::TensorPtr& gamma, float epsilon);

}  // namespace ttml::ops
Reviewer comment (on the suffix_len loop in logical_suffixes_match): Isn't suffix_len an unsigned type? If so, -suffix_len looks suspicious. Either way, I would advise casting it to a signed type here (even if size() returns a signed value for now, that could change in the future).
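
A sketch of that adjustment (not part of the diff): keep the same negative-index walk over the shape suffix, but cast the length to a signed type before negating it so the loop bound cannot wrap around.

auto logical_suffixes_match = [](const ttnn::Tensor& a, const ttnn::Tensor& b) {
    auto a_shape = a.get_logical_shape();
    auto b_shape = b.get_logical_shape();

    // Cast to a signed type before negating; size() may return an unsigned value.
    const auto suffix_len = static_cast<int>(std::min(a_shape.size(), b_shape.size()));
    for (int i = -1; i >= -suffix_len; i--) {
        if (a_shape[i] != b_shape[i]) {
            return false;
        }
    }
    return true;
};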