Skip to content

Commit

Permalink
#16812: Reordering cbs in reduce_init_delta (#16981)
Browse files Browse the repository at this point in the history
### Ticket
[Link to Github
Issue](#16812)

### Problem description
In all compute kernel APIs, the ordering of the circular buffer
parameters are input cbs first, then output cbs. However, in the
reduce_init_delta function in the reduce.h kernel, the output cb
parameter comes before those of input cbs. This may cause confusion (and
it has in some files) as the expected ordering is inputs first and
outputs later, following all other kernel APIs.

### What's changed
The parameters were reordered such that the input cbs come before the
output cb. Call chains were updated accordingly.

### Checklist
- [x] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12917899606)
- [x] [Blackhole Post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/12917901314)
(if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests passes
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
atatuzunerTT authored Jan 29, 2025
1 parent dff295d commit ec23429
Show file tree
Hide file tree
Showing 24 changed files with 38 additions and 38 deletions.
4 changes: 2 additions & 2 deletions tests/tt_metal/tt_metal/test_kernels/compute/layernorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void MAIN {
*/
ACQ();
cb_reserve_back(cb_ex, 1 * onetile);
reduce_init_delta<false>(cb_ex, cb_x, cb_scaler);
reduce_init_delta<false>(cb_x, cb_scaler, cb_ex);
for (uint32_t wt = 0; wt < Wt; wt += blk) {
cb_wait_front(cb_x, wt + blk);
for (uint32_t j = 0; j < blk; j++) {
Expand Down Expand Up @@ -154,7 +154,7 @@ void MAIN {
* TODO(AP): can save space here by reusing CB
*/
cb_reserve_back(cb_ex2, 1);
reduce_init_delta<false>(cb_ex2, cb_xmm2, cb_scaler);
reduce_init_delta<false>(cb_xmm2, cb_scaler, cb_ex2);
ACQ();
cb_wait_front(cb_xmm2, Wt);
// cb_wait_front(cb_xmm, Wt);
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/max_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline void reduce_h(
uint32_t out_cb_id) {
cb_wait_front(in_cb_id, in_ntiles_hwc * out_nelems);
cb_reserve_back(out_cb_id, out_ntiles_c * out_nelems);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(out_cb_id, in_cb_id, in_scalar_cb_id);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(in_cb_id, in_scalar_cb_id, out_cb_id);
uint32_t base_tile_id = 0;
for (uint32_t c_i = 0; c_i < in_ntiles_c * out_nelems; ++c_i) {
// add to accumulator all the in_ntiles_hw in a column of tiles
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline void reduce_h(
uint32_t out_cb_id) {
cb_wait_front(in_cb_id, in_ntiles_hwc * out_nelems);
cb_reserve_back(out_cb_id, out_ntiles_c * out_nelems);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(out_cb_id, in_cb_id, in_scalar_cb_id);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(in_cb_id, in_scalar_cb_id, out_cb_id);
uint32_t base_tile_id = 0;
for (uint32_t c_i = 0; c_i < in_ntiles_c * out_nelems; ++c_i) {
// add to accumulator all the in_ntiles_hw in a column of tiles
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void MAIN {
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init_delta<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#endif

cb_wait_front(tt::CBIndex::c_2, 1); // scaler tile from the reader
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_hw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void MAIN {
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init_delta<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#endif

cb_wait_front(tt::CBIndex::c_2, 1); // scaler tile from the reader
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_w.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void MAIN {
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init_delta<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#endif

cb_wait_front(tt::CBIndex::c_2, 1); // scaler tile from the reader
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/rmsnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void MAIN {
* compute E[(x)^2]
*/
cb_reserve_back(cb_ex2, 1);
reduce_init_delta<false>(cb_ex2, cb_x2, cb_scaler);
reduce_init_delta<false>(cb_x2, cb_scaler, cb_ex2);
ACQ();
cb_wait_front(cb_x2, Wt);
// cb_wait_front(cb_xmm, Wt);
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ void MAIN {

ACQ();
cb_reserve_back(cb_recipsumexps, onetile);
reduce_init_delta<false>(cb_recipsumexps, cb_exps, cb_bcast_scaler);
reduce_init_delta<false>(cb_exps, cb_bcast_scaler, cb_recipsumexps);
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(cb_exps, wt + 1); // must be a cumulative wait for correctness
constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/include/compute_kernel_api/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ ALWI void reduce_init_short(uint32_t icb, uint32_t icb_scaler, uint32_t ocb) {
}

template <bool at_start, PoolType reduce_type = REDUCE_OP, ReduceDim reduce_dim = REDUCE_DIM>
ALWI void reduce_init_delta(uint32_t ocb, uint32_t icb0, uint32_t icb1) {
ALWI void reduce_init_delta(uint32_t icb0, uint32_t icb1, uint32_t ocb) {
// FIXME: API Update needed in compute kernel?
UNPACK((llk_unpack_AB_reduce_init<reduce_dim>(icb0, icb1)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ ALWI void reduce_init_delta_with_dt(uint32_t ocb = 16, uint32_t icb0 = 0, uint32
#if defined FP32_DEST_ACC_EN
reconfig_data_format(icb0, icb1);
#endif
reduce_init_delta<at_start, reduce_type, reduce_dim>(ocb, icb0, icb1);
reduce_init_delta<at_start, reduce_type, reduce_dim>(icb0, icb1, ocb);
}

class ArgFetcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ void MAIN {
cb_wait_front(cb_xpowadd, onetile);
cb_reserve_back(cb_y, onetile);

reduce_init_delta<false>(cb_y, cb_xpowadd, cb_one);
reduce_init_delta<false>(cb_xpowadd, cb_one, cb_y);
reduce_tile(cb_xpowadd, cb_one, 0, 0, dst0);
reduce_revert_delta(cb_y);
tile_regs_commit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void MAIN {
}

cb_wait_front(tt::CBIndex::c_24, onetile);
reduce_init_delta<false>(tt::CBIndex::c_16, tt::CBIndex::c_24, tt::CBIndex::c_2);
reduce_init_delta<false>(tt::CBIndex::c_24, tt::CBIndex::c_2, tt::CBIndex::c_16);
reduce_tile(tt::CBIndex::c_24, tt::CBIndex::c_2, 0, 0, 0);
cb_pop_front(tt::CBIndex::c_24, onetile);
reduce_revert_delta(tt::CBIndex::c_16);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_reduce, cb_scaler);
#endif
reduce_init_delta<false>(cb_out0, cb_reduce, cb_scaler);
reduce_init_delta<false>(cb_reduce, cb_scaler, cb_out0);
reduce_tile(cb_reduce, cb_scaler, 0, 0, 0);
reduce_revert_delta(cb_out0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_reduce, cb_scaler);
#endif
reduce_init_delta<false>(cb_out0, cb_reduce, cb_scaler);
reduce_init_delta<false>(cb_reduce, cb_scaler, cb_out0);
reduce_tile((do_mask) ? (cb_intermed0) : (cb_in0), cb_scaler, 0, 0, 0);
reduce_revert_delta(cb_out0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_input, cb_scaler);
#endif
reduce_init_delta<false>(cb_accum_dst, cb_input, cb_scaler);
reduce_init_delta<false>(cb_input, cb_scaler, cb_accum_dst);
reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx);
reduce_revert_delta(cb_accum_dst);

Expand Down Expand Up @@ -104,7 +104,7 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_input, cb_scaler);
#endif
reduce_init_delta<false>(cb_out, cb_input, cb_scaler);
reduce_init_delta<false>(cb_input, cb_scaler, cb_out);
reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx);
reduce_revert_delta(cb_out);
tile_regs_commit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void MAIN {

// Partial-E[x]
index_h_offset = 0;
reduce_init_delta<false>(cb_ex_partial, cb_x, cb_scaler);
reduce_init_delta<false>(cb_x, cb_scaler, cb_ex_partial);
cb_reserve_back(cb_ex_partial, 1);
tile_regs_acquire();
cb_wait_front(cb_scaler, 1);
Expand All @@ -219,7 +219,7 @@ void MAIN {
reduce_revert_delta(cb_ex_partial);

if constexpr (is_mcast_sender and num_cores_per_mcast_group > 1) {
reduce_init_delta<false>(cb_ex_global, cb_ex_external, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external, cb_scaler_global, cb_ex_global);
cb_reserve_back(cb_ex_global, 1);
cb_reserve_back(cb_ex, 1);
tile_regs_acquire();
Expand Down Expand Up @@ -316,7 +316,7 @@ void MAIN {

// Partial-Var(x)
index_h_offset = 0;
reduce_init_delta<false>(cb_ex_partial, cb_xmm, cb_scaler);
reduce_init_delta<false>(cb_xmm, cb_scaler, cb_ex_partial);
cb_reserve_back(cb_ex_partial, 1);
tile_regs_acquire();
cb_wait_front(cb_xmm, block_hw);
Expand All @@ -337,7 +337,7 @@ void MAIN {
reduce_revert_delta(cb_ex_partial);

if constexpr (is_mcast_sender and num_cores_per_mcast_group > 1) {
reduce_init_delta<false>(cb_ex_global, cb_ex_external, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external, cb_scaler_global, cb_ex_global);
cb_reserve_back(cb_ex_global, 1);
cb_reserve_back(cb_ex, 1);
tile_regs_acquire();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void MAIN {
*/
ACQ();
cb_reserve_back(cb_ex, onetile);
reduce_init_delta<false>(cb_ex, cb_x, cb_scaler);
reduce_init_delta<false>(cb_x, cb_scaler, cb_ex);
for (uint32_t wt = 0; wt < Wt; wt += blk) {
cb_wait_front(cb_x, wt + blk);
for (uint32_t j = 0; j < blk; j++) {
Expand Down Expand Up @@ -193,7 +193,7 @@ void MAIN {
reconfig_data_format(cb_xmm2, cb_scaler);
}
cb_reserve_back(cb_ex2, 1);
reduce_init_delta<false>(cb_ex2, cb_xmm2, cb_scaler);
reduce_init_delta<false>(cb_xmm2, cb_scaler, cb_ex2);
ACQ();
cb_wait_front(cb_xmm2, Wt);
// cb_wait_front(cb_xmm, Wt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void MAIN {
#ifndef RMSNORM
// E[x],
index_h_offset = 0;
reduce_init_delta<false>(cb_ex_partial, cb_in, cb_scaler);
reduce_init_delta<false>(cb_in, cb_scaler, cb_ex_partial);
cb_wait_front(cb_scaler, 1);
cb_reserve_back(cb_ex_partial, block_h);
for (uint32_t i = 0; i < block_h; i++) {
Expand All @@ -163,7 +163,7 @@ void MAIN {

// global reduce, cb_ex <-- cb_ex_external, cb_ex_partial
if constexpr (is_allgather_worker) {
reduce_init_delta<false>(cb_ex, cb_ex_external, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external, cb_scaler_global, cb_ex);
cb_reserve_back(cb_ex, num_tiles_per_allgather_worker);

for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
Expand Down Expand Up @@ -262,7 +262,7 @@ void MAIN {
cb_wait_front(cb_scaler, 1);
#endif
cb_reserve_back(cb_ex_partial2, block_h);
reduce_init_delta<false>(cb_ex_partial2, cb_xmm2, cb_scaler);
reduce_init_delta<false>(cb_xmm2, cb_scaler, cb_ex_partial2);
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
tile_regs_acquire();
Expand All @@ -281,7 +281,7 @@ void MAIN {

// global reduce, cb_ex <-- cb_ex_external, cb_ex_partial
if constexpr (is_allgather_worker) {
reduce_init_delta<false>(cb_ex2, cb_ex_external2, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external2, cb_scaler_global, cb_ex2);
cb_reserve_back(cb_ex2, num_tiles_per_allgather_worker);

for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void MAIN {
#endif

cb_wait_front(cb_scaler_global, 1);
reduce_init_delta<false>(cb_var, cb_stats, cb_scaler_global);
reduce_init_delta<false>(cb_stats, cb_scaler_global, cb_var);
tile_regs_acquire();
// striding over cb_stats, consisting [E(X), E(X^2)] from all the distributed devices in interleaved order
for (uint32_t w = 0; w < stats_tiles * num_distributed_blocks; w++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void MAIN {
#endif
// E[x],
index_h_offset = 0;
reduce_init_delta<false>(cb_ex_partial2, cb_in0, cb_scaler);
reduce_init_delta<false>(cb_in0, cb_scaler, cb_ex_partial2);

cb_reserve_back(cb_ex_partial2, block_h);
for (uint32_t i = 0; i < block_h; i++) {
Expand Down Expand Up @@ -177,7 +177,7 @@ void MAIN {

cb_reserve_back(cb_ex_partial2, block_h); // RMS E(x2) #Layernorm //E(x) and E(x^2)

reduce_init_delta<false>(cb_ex_partial2, cb_x2, cb_scaler);
reduce_init_delta<false>(cb_x2, cb_scaler, cb_ex_partial2);
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
tile_regs_acquire();
Expand All @@ -200,7 +200,7 @@ void MAIN {
cb_wait_front(cb_scaler_global, 1);
reconfig_data_format_srca(cb_x2, cb_ex_external2);
reconfig_data_format_srcb(cb_scaler, cb_scaler_global);
reduce_init_delta<false>(cb_reduction_out, cb_ex_external2, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external2, cb_scaler_global, cb_reduction_out);
cb_reserve_back(cb_reduction_out, num_tiles_per_partial_result * num_tiles_per_allgather_worker);

for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) { // loops over height
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void MAIN {
* cb_stats = [sum(x0**2), sum(x0), sum(x1**2), sum(x1), ...]
* RMSNorm packs mean(x**2) into cb_var. Layernorm just uses cb_stats_reduced.
*/
reduce_init_delta<false>(cb_stats_reduced, cb_stats, cb_reduce);
reduce_init_delta<false>(cb_stats, cb_reduce, cb_stats_reduced);
cb_wait_front(cb_stats, stats_tiles_cols);
cb_reserve_back(cb_stats_reduced, stats_tile_stride);
#ifdef RMSNORM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void MAIN {
*/
reconfig_data_format(cb_x2, cb_reduce);
pack_reconfig_data_format(cb_out);
reduce_init_delta<false>(cb_out, cb_x2, cb_reduce);
reduce_init_delta<false>(cb_x2, cb_reduce, cb_out);
cb_wait_front(cb_x2, Wt);
cb_reserve_back(cb_out, onetile);
ACQ();
Expand All @@ -89,7 +89,7 @@ void MAIN {
*/
reconfig_data_format(cb_inp, cb_reduce);
pack_reconfig_data_format(cb_out);
reduce_init_delta<false>(cb_out, cb_inp, cb_reduce);
reduce_init_delta<false>(cb_inp, cb_reduce, cb_out);
cb_reserve_back(cb_out, onetile);
ACQ();
for (uint32_t wtr = 0; wtr < Wt; wtr++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void calc_numeric_stable(
reconfig_data_format(cb_in, cb_bcast_scaler);
cb_reserve_back(cb_max, 1);
cb_wait_front(cb_bcast_scaler, 1);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_max, cb_in, cb_bcast_scaler);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_in, cb_bcast_scaler, cb_max);
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(cb_in, wt + 1);
constexpr uint32_t bcast_scaler0 = 0;
Expand Down Expand Up @@ -256,7 +256,7 @@ void MAIN {

ACQ();
cb_reserve_back(cb_recipsumexps, onetile);
reduce_init_delta<false>(cb_recipsumexps, cb_exps, cb_bcast_scaler);
reduce_init_delta<false>(cb_exps, cb_bcast_scaler, cb_recipsumexps);
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(cb_exps, wt + 1); // must be a cumulative wait for correctness
constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ALWI void calc_numeric_stable(uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t
reconfig_data_format(cb_in, cb_bcast_scaler);
pack_reconfig_data_format(cb_max);
cb_reserve_back(cb_max, 1);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_max, cb_in, cb_bcast_scaler);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_in, cb_bcast_scaler, cb_max);
cb_wait_front(cb_bcast_scaler, 1);
for (uint32_t w = 0; w < block_w; w++) {
constexpr uint32_t bcast_scaler0 = 0;
Expand Down Expand Up @@ -203,7 +203,7 @@ void MAIN {

// sum(exp(x))
ACQ();
reduce_init_delta<false>(cb_recipsumexps, cb_exps, cb_bcast_scaler);
reduce_init_delta<false>(cb_exps, cb_bcast_scaler, cb_recipsumexps);
cb_wait_front(cb_exps, block_w);
cb_wait_front(cb_bcast_scaler, 1);
cb_reserve_back(cb_recipsumexps, 1);
Expand Down

0 comments on commit ec23429

Please sign in to comment.