
[tt-train] Add RMSNorm module #16991

Draft: jaykru-tt wants to merge 20 commits into main from Jkruer/rmsnorm
Conversation

jaykru-tt (Contributor) commented on Jan 22, 2025

Problem description

We need RMSNorm to train Llama 3 and some other exciting open-source models.

What's changed

  • Added sqrt and matmul ops with backward passes to support RMS
  • Added an RMS op, defined as a composite of existing ops plus the new ops mentioned above (see the reference sketch below this list)
  • Added an RMS module
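For reference, RMSNorm normalizes each feature vector by its root mean square: y_i = g_i * x_i / sqrt(mean(x^2) + eps). A minimal host-side sketch of the math (plain C++; this is not the tt-train op, and the 1e-5F default for eps is a placeholder):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference RMSNorm over a single feature vector:
// y_i = gamma_i * x_i / sqrt(mean(x^2) + eps)
std::vector<float> rmsnorm_reference(
    const std::vector<float>& x, const std::vector<float>& gamma, float eps = 1e-5F) {
    float mean_sq = 0.0F;
    for (float v : x) {
        mean_sq += v * v;
    }
    mean_sq /= static_cast<float>(x.size());
    const float inv_rms = 1.0F / std::sqrt(mean_sq + eps);
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = gamma[i] * x[i] * inv_rms;
    }
    return y;
}
```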

Checklist

  • Post commit CI passes
  • Blackhole Post commit (if applicable)
  • Model regression CI testing passes (if applicable)
  • Device performance regression CI testing passes (if applicable)
  • (For models and ops writers) Full new model tests pass
  • New/Existing tests provide coverage for changes

@jaykru-tt changed the title from Jkruer/rmsnorm to [tt-train] Add RMSNorm module on Jan 22, 2025

class RMSNormLayer : public autograd::ModuleBase {
private:
    float m_epsilon;
Contributor:
Please don't forget default initialization.
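A minimal sketch of the suggested fix; the 1e-5F value here is a placeholder, not taken from this PR:

```cpp
// In-class default member initializer, so a default-constructed layer
// never reads an uninitialized epsilon. Placeholder value.
float m_epsilon = 1e-5F;
```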


public:
    void initialize_tensors(uint32_t features);
    explicit RMSNormLayer(uint32_t features, std::optional<float> epsilon = std::nullopt);
Contributor:
I don't think we always need optional here. Overall, I'm not a fan of actively using std::optional when it's not really needed.
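A sketch of the alternative being suggested, assuming a plain defaulted parameter is acceptable here (the default value is a placeholder):

```cpp
// Plain defaulted parameter instead of std::optional<float>.
explicit RMSNormLayer(uint32_t features, float epsilon = 1e-5F);
```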

/* program_config */ std::nullopt,
/* activation */ std::nullopt,
/* compute_kernel_config */ core::ComputeKernelConfig::matmul(),
/* core_grid */ std::nullopt, // NOTE: I believe matmul will use the
Contributor:
Better to reuse our default grid parameters for now. Also, is this a copy-paste of the same function in linear? If so, you can probably move it somewhere shared to avoid the duplication.
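If the wrapper really is duplicated, the refactor could be a single shared declaration; the header path and exact signature below are assumptions for illustration, not the actual tt-train layout:

```cpp
// e.g. ops/common/matmul_helpers.hpp (hypothetical location)
// Shared wrapper used by both the linear and rmsnorm ops, so the
// program/compute-kernel config lives in exactly one place.
ttnn::Tensor ttnn_matmul(
    const ttnn::Tensor& a, const ttnn::Tensor& b, bool transpose_a, bool transpose_b);
```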

auto grad_a = ttnn_matmul(
    out->get_grad(),
    b->get_value(),
    /* transpose_a */ false,
Contributor:
Are you sure you don't need to change these params depending on transpose_a and transpose_b?
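For concreteness, here is the standard case analysis for C = op_a(A) * op_b(B), where each op is identity or transpose. The ttnn_matmul calls mirror the snippet above, but treat the exact flag handling as a sketch rather than the PR's code:

```cpp
// dC is the upstream gradient out->get_grad(). Sketch only.
if (!transpose_a && !transpose_b) {        // C = A * B
    grad_a = ttnn_matmul(grad_c, b, /* transpose_a */ false, /* transpose_b */ true);   // dA = dC * B^T
    grad_b = ttnn_matmul(a, grad_c, /* transpose_a */ true,  /* transpose_b */ false);  // dB = A^T * dC
} else if (transpose_a && !transpose_b) {  // C = A^T * B
    grad_a = ttnn_matmul(b, grad_c, /* transpose_a */ false, /* transpose_b */ true);   // dA = B * dC^T
    grad_b = ttnn_matmul(a, grad_c, /* transpose_a */ false, /* transpose_b */ false);  // dB = A * dC
} else if (!transpose_a && transpose_b) {  // C = A * B^T
    grad_a = ttnn_matmul(grad_c, b, /* transpose_a */ false, /* transpose_b */ false);  // dA = dC * B
    grad_b = ttnn_matmul(grad_c, a, /* transpose_a */ true,  /* transpose_b */ false);  // dB = dC^T * A
} else {                                   // C = A^T * B^T
    grad_a = ttnn_matmul(b, grad_c, /* transpose_a */ true,  /* transpose_b */ true);   // dA = B^T * dC^T
    grad_b = ttnn_matmul(grad_c, a, /* transpose_a */ true,  /* transpose_b */ true);   // dB = dC^T * A^T
}
```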

/* transpose_a */ false,
/* transpose_b */ true);
auto eps_tensor =
autograd::create_tensor(core::from_xtensor(xt::xarray<float>{epsilon}, &autograd::ctx().get_device()));
Contributor:
Better to implement an op that takes a tensor and a scalar than to create even a small tensor every step.
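Until such a tensor-scalar op exists, one interim option is to cache the epsilon tensor as module state instead of rebuilding it each forward step. The member name below is hypothetical; the calls are copied from the snippet above:

```cpp
// Hypothetical cached member, built once (e.g. in initialize_tensors)
// and reused on every forward pass instead of a per-step allocation.
m_eps_tensor = autograd::create_tensor(
    core::from_xtensor(xt::xarray<float>{m_epsilon}, &autograd::ctx().get_device()));
```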

/* output_tile */ std::nullopt);
}

autograd::TensorPtr matmul(
Contributor:
Why do we need matmul here? :)

auto a_shape = a.get_logical_shape();
auto b_shape = b.get_logical_shape();

auto suffix_len = std::min(a_shape.size(), b_shape.size());
Contributor:
Isn't this an unsigned type? If so, -suffix_len looks suspicious. Either way, I would advise casting it to signed here (even if size() returns a signed value for now, that should change in the future).
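A minimal standalone demo of the suggested fix: negating an unsigned size_t wraps around to a huge value, so cast before negating. The shapes here are made up for illustration:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

int main() {
    std::vector<int> a_shape{1, 1, 32, 64};
    std::vector<int> b_shape{32, 64};
    // size() returns an unsigned type; cast to signed before using -suffix_len.
    auto suffix_len = static_cast<std::int64_t>(std::min(a_shape.size(), b_shape.size()));
    for (std::int64_t i = -suffix_len; i < 0; ++i) {
        // compare trailing dims, indexing from the back: a_shape[a_shape.size() + i]
    }
    return 0;
}
```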
