Skip to content

Commit

Permalink
Reordering CB arguments in reduce_init_delta (input CBs first, output CB last)
Browse files: browse the repository at this point in the history
  • Loading branch information
atatuzunerTT committed Jan 29, 2025
1 parent 776a92d commit 81a72b3
Show file tree
Hide file tree
Showing 24 changed files with 38 additions and 38 deletions.
4 changes: 2 additions & 2 deletions tests/tt_metal/tt_metal/test_kernels/compute/layernorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void MAIN {
*/
ACQ();
cb_reserve_back(cb_ex, 1 * onetile);
reduce_init_delta<false>(cb_ex, cb_x, cb_scaler);
reduce_init_delta<false>(cb_x, cb_scaler, cb_ex);
for (uint32_t wt = 0; wt < Wt; wt += blk) {
cb_wait_front(cb_x, wt + blk);
for (uint32_t j = 0; j < blk; j++) {
Expand Down Expand Up @@ -154,7 +154,7 @@ void MAIN {
* TODO(AP): can save space here by reusing CB
*/
cb_reserve_back(cb_ex2, 1);
reduce_init_delta<false>(cb_ex2, cb_xmm2, cb_scaler);
reduce_init_delta<false>(cb_xmm2, cb_scaler, cb_ex2);
ACQ();
cb_wait_front(cb_xmm2, Wt);
// cb_wait_front(cb_xmm, Wt);
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/max_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline void reduce_h(
uint32_t out_cb_id) {
cb_wait_front(in_cb_id, in_ntiles_hwc * out_nelems);
cb_reserve_back(out_cb_id, out_ntiles_c * out_nelems);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(out_cb_id, in_cb_id, in_scalar_cb_id);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(in_cb_id, in_scalar_cb_id, out_cb_id);
uint32_t base_tile_id = 0;
for (uint32_t c_i = 0; c_i < in_ntiles_c * out_nelems; ++c_i) {
// add to accumulator all the in_ntiles_hw in a column of tiles
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline void reduce_h(
uint32_t out_cb_id) {
cb_wait_front(in_cb_id, in_ntiles_hwc * out_nelems);
cb_reserve_back(out_cb_id, out_ntiles_c * out_nelems);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(out_cb_id, in_cb_id, in_scalar_cb_id);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(in_cb_id, in_scalar_cb_id, out_cb_id);
uint32_t base_tile_id = 0;
for (uint32_t c_i = 0; c_i < in_ntiles_c * out_nelems; ++c_i) {
// add to accumulator all the in_ntiles_hw in a column of tiles
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void MAIN {
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init_delta<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#endif

cb_wait_front(tt::CBIndex::c_2, 1); // scaler tile from the reader
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_hw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void MAIN {
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init_delta<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#endif

cb_wait_front(tt::CBIndex::c_2, 1); // scaler tile from the reader
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_w.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void MAIN {
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init_delta<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#endif

cb_wait_front(tt::CBIndex::c_2, 1); // scaler tile from the reader
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/rmsnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void MAIN {
* compute E[(x)^2]
*/
cb_reserve_back(cb_ex2, 1);
reduce_init_delta<false>(cb_ex2, cb_x2, cb_scaler);
reduce_init_delta<false>(cb_x2, cb_scaler, cb_ex2);
ACQ();
cb_wait_front(cb_x2, Wt);
// cb_wait_front(cb_xmm, Wt);
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ void MAIN {

ACQ();
cb_reserve_back(cb_recipsumexps, onetile);
reduce_init_delta<false>(cb_recipsumexps, cb_exps, cb_bcast_scaler);
reduce_init_delta<false>(cb_exps, cb_bcast_scaler, cb_recipsumexps);
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(cb_exps, wt + 1); // must be a cumulative wait for correctness
constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/include/compute_kernel_api/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ ALWI void reduce_init_short(uint32_t icb, uint32_t icb_scaler, uint32_t ocb) {
}

template <bool at_start, PoolType reduce_type = REDUCE_OP, ReduceDim reduce_dim = REDUCE_DIM>
ALWI void reduce_init_delta(uint32_t ocb, uint32_t icb0, uint32_t icb1) {
ALWI void reduce_init_delta(uint32_t icb0, uint32_t icb1, uint32_t ocb) {
// FIXME: API Update needed in compute kernel?
UNPACK((llk_unpack_AB_reduce_init<reduce_dim>(icb0, icb1)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ ALWI void reduce_init_delta_with_dt(uint32_t ocb = 16, uint32_t icb0 = 0, uint32
#if defined FP32_DEST_ACC_EN
reconfig_data_format(icb0, icb1);
#endif
reduce_init_delta<at_start, reduce_type, reduce_dim>(ocb, icb0, icb1);
reduce_init_delta<at_start, reduce_type, reduce_dim>(icb0, icb1, ocb);
}

class ArgFetcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ void MAIN {
cb_wait_front(cb_xpowadd, onetile);
cb_reserve_back(cb_y, onetile);

reduce_init_delta<false>(cb_y, cb_xpowadd, cb_one);
reduce_init_delta<false>(cb_xpowadd, cb_one, cb_y);
reduce_tile(cb_xpowadd, cb_one, 0, 0, dst0);
reduce_revert_delta(cb_y);
tile_regs_commit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void MAIN {
}

cb_wait_front(tt::CBIndex::c_24, onetile);
reduce_init_delta<false>(tt::CBIndex::c_16, tt::CBIndex::c_24, tt::CBIndex::c_2);
reduce_init_delta<false>(tt::CBIndex::c_24, tt::CBIndex::c_2, tt::CBIndex::c_16);
reduce_tile(tt::CBIndex::c_24, tt::CBIndex::c_2, 0, 0, 0);
cb_pop_front(tt::CBIndex::c_24, onetile);
reduce_revert_delta(tt::CBIndex::c_16);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_reduce, cb_scaler);
#endif
reduce_init_delta<false>(cb_out0, cb_reduce, cb_scaler);
reduce_init_delta<false>(cb_reduce, cb_scaler, cb_out0);
reduce_tile(cb_reduce, cb_scaler, 0, 0, 0);
reduce_revert_delta(cb_out0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_reduce, cb_scaler);
#endif
reduce_init_delta<false>(cb_out0, cb_reduce, cb_scaler);
reduce_init_delta<false>(cb_reduce, cb_scaler, cb_out0);
reduce_tile((do_mask) ? (cb_intermed0) : (cb_in0), cb_scaler, 0, 0, 0);
reduce_revert_delta(cb_out0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_input, cb_scaler);
#endif
reduce_init_delta<false>(cb_accum_dst, cb_input, cb_scaler);
reduce_init_delta<false>(cb_input, cb_scaler, cb_accum_dst);
reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx);
reduce_revert_delta(cb_accum_dst);

Expand Down Expand Up @@ -104,7 +104,7 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_input, cb_scaler);
#endif
reduce_init_delta<false>(cb_out, cb_input, cb_scaler);
reduce_init_delta<false>(cb_input, cb_scaler, cb_out);
reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx);
reduce_revert_delta(cb_out);
tile_regs_commit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void MAIN {

// Partial-E[x]
index_h_offset = 0;
reduce_init_delta<false>(cb_ex_partial, cb_x, cb_scaler);
reduce_init_delta<false>(cb_x, cb_scaler, cb_ex_partial);
cb_reserve_back(cb_ex_partial, 1);
tile_regs_acquire();
cb_wait_front(cb_scaler, 1);
Expand All @@ -219,7 +219,7 @@ void MAIN {
reduce_revert_delta(cb_ex_partial);

if constexpr (is_mcast_sender and num_cores_per_mcast_group > 1) {
reduce_init_delta<false>(cb_ex_global, cb_ex_external, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external, cb_scaler_global, cb_ex_global);
cb_reserve_back(cb_ex_global, 1);
cb_reserve_back(cb_ex, 1);
tile_regs_acquire();
Expand Down Expand Up @@ -316,7 +316,7 @@ void MAIN {

// Partial-Var(x)
index_h_offset = 0;
reduce_init_delta<false>(cb_ex_partial, cb_xmm, cb_scaler);
reduce_init_delta<false>(cb_xmm, cb_scaler, cb_ex_partial);
cb_reserve_back(cb_ex_partial, 1);
tile_regs_acquire();
cb_wait_front(cb_xmm, block_hw);
Expand All @@ -337,7 +337,7 @@ void MAIN {
reduce_revert_delta(cb_ex_partial);

if constexpr (is_mcast_sender and num_cores_per_mcast_group > 1) {
reduce_init_delta<false>(cb_ex_global, cb_ex_external, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external, cb_scaler_global, cb_ex_global);
cb_reserve_back(cb_ex_global, 1);
cb_reserve_back(cb_ex, 1);
tile_regs_acquire();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void MAIN {
*/
ACQ();
cb_reserve_back(cb_ex, onetile);
reduce_init_delta<false>(cb_ex, cb_x, cb_scaler);
reduce_init_delta<false>(cb_x, cb_scaler, cb_ex);
for (uint32_t wt = 0; wt < Wt; wt += blk) {
cb_wait_front(cb_x, wt + blk);
for (uint32_t j = 0; j < blk; j++) {
Expand Down Expand Up @@ -193,7 +193,7 @@ void MAIN {
reconfig_data_format(cb_xmm2, cb_scaler);
}
cb_reserve_back(cb_ex2, 1);
reduce_init_delta<false>(cb_ex2, cb_xmm2, cb_scaler);
reduce_init_delta<false>(cb_xmm2, cb_scaler, cb_ex2);
ACQ();
cb_wait_front(cb_xmm2, Wt);
// cb_wait_front(cb_xmm, Wt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void MAIN {
#ifndef RMSNORM
// E[x],
index_h_offset = 0;
reduce_init_delta<false>(cb_ex_partial, cb_in, cb_scaler);
reduce_init_delta<false>(cb_in, cb_scaler, cb_ex_partial);
cb_wait_front(cb_scaler, 1);
cb_reserve_back(cb_ex_partial, block_h);
for (uint32_t i = 0; i < block_h; i++) {
Expand All @@ -163,7 +163,7 @@ void MAIN {

// global reduce, cb_ex <-- cb_ex_external, cb_ex_partial
if constexpr (is_allgather_worker) {
reduce_init_delta<false>(cb_ex, cb_ex_external, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external, cb_scaler_global, cb_ex);
cb_reserve_back(cb_ex, num_tiles_per_allgather_worker);

for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
Expand Down Expand Up @@ -262,7 +262,7 @@ void MAIN {
cb_wait_front(cb_scaler, 1);
#endif
cb_reserve_back(cb_ex_partial2, block_h);
reduce_init_delta<false>(cb_ex_partial2, cb_xmm2, cb_scaler);
reduce_init_delta<false>(cb_xmm2, cb_scaler, cb_ex_partial2);
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
tile_regs_acquire();
Expand All @@ -281,7 +281,7 @@ void MAIN {

// global reduce, cb_ex2 <-- cb_ex_external2, cb_ex_partial2
if constexpr (is_allgather_worker) {
reduce_init_delta<false>(cb_ex2, cb_ex_external2, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external2, cb_scaler_global, cb_ex2);
cb_reserve_back(cb_ex2, num_tiles_per_allgather_worker);

for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void MAIN {
#endif

cb_wait_front(cb_scaler_global, 1);
reduce_init_delta<false>(cb_var, cb_stats, cb_scaler_global);
reduce_init_delta<false>(cb_stats, cb_scaler_global, cb_var);
tile_regs_acquire();
// stride over cb_stats, which consists of [E(X), E(X^2)] from all the distributed devices in interleaved order
for (uint32_t w = 0; w < stats_tiles * num_distributed_blocks; w++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void MAIN {
#endif
// E[x],
index_h_offset = 0;
reduce_init_delta<false>(cb_ex_partial2, cb_in0, cb_scaler);
reduce_init_delta<false>(cb_in0, cb_scaler, cb_ex_partial2);

cb_reserve_back(cb_ex_partial2, block_h);
for (uint32_t i = 0; i < block_h; i++) {
Expand Down Expand Up @@ -177,7 +177,7 @@ void MAIN {

cb_reserve_back(cb_ex_partial2, block_h); // RMS E(x2) #Layernorm //E(x) and E(x^2)

reduce_init_delta<false>(cb_ex_partial2, cb_x2, cb_scaler);
reduce_init_delta<false>(cb_x2, cb_scaler, cb_ex_partial2);
index_h_offset = 0;
for (uint32_t i = 0; i < block_h; i++) {
tile_regs_acquire();
Expand All @@ -200,7 +200,7 @@ void MAIN {
cb_wait_front(cb_scaler_global, 1);
reconfig_data_format_srca(cb_x2, cb_ex_external2);
reconfig_data_format_srcb(cb_scaler, cb_scaler_global);
reduce_init_delta<false>(cb_reduction_out, cb_ex_external2, cb_scaler_global);
reduce_init_delta<false>(cb_ex_external2, cb_scaler_global, cb_reduction_out);
cb_reserve_back(cb_reduction_out, num_tiles_per_partial_result * num_tiles_per_allgather_worker);

for (uint32_t i = 0; i < num_tiles_per_allgather_worker; i++) { // loops over height
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void MAIN {
* cb_stats = [sum(x0**2), sum(x0), sum(x1**2), sum(x1), ...]
* RMSNorm packs mean(x**2) into cb_var. Layernorm just uses cb_stats_reduced.
*/
reduce_init_delta<false>(cb_stats_reduced, cb_stats, cb_reduce);
reduce_init_delta<false>(cb_stats, cb_reduce, cb_stats_reduced);
cb_wait_front(cb_stats, stats_tiles_cols);
cb_reserve_back(cb_stats_reduced, stats_tile_stride);
#ifdef RMSNORM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void MAIN {
*/
reconfig_data_format(cb_x2, cb_reduce);
pack_reconfig_data_format(cb_out);
reduce_init_delta<false>(cb_out, cb_x2, cb_reduce);
reduce_init_delta<false>(cb_x2, cb_reduce, cb_out);
cb_wait_front(cb_x2, Wt);
cb_reserve_back(cb_out, onetile);
ACQ();
Expand All @@ -89,7 +89,7 @@ void MAIN {
*/
reconfig_data_format(cb_inp, cb_reduce);
pack_reconfig_data_format(cb_out);
reduce_init_delta<false>(cb_out, cb_inp, cb_reduce);
reduce_init_delta<false>(cb_inp, cb_reduce, cb_out);
cb_reserve_back(cb_out, onetile);
ACQ();
for (uint32_t wtr = 0; wtr < Wt; wtr++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void calc_numeric_stable(
reconfig_data_format(cb_in, cb_bcast_scaler);
cb_reserve_back(cb_max, 1);
cb_wait_front(cb_bcast_scaler, 1);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_max, cb_in, cb_bcast_scaler);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_in, cb_bcast_scaler, cb_max);
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(cb_in, wt + 1);
constexpr uint32_t bcast_scaler0 = 0;
Expand Down Expand Up @@ -256,7 +256,7 @@ void MAIN {

ACQ();
cb_reserve_back(cb_recipsumexps, onetile);
reduce_init_delta<false>(cb_recipsumexps, cb_exps, cb_bcast_scaler);
reduce_init_delta<false>(cb_exps, cb_bcast_scaler, cb_recipsumexps);
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(cb_exps, wt + 1); // must be a cumulative wait for correctness
constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ALWI void calc_numeric_stable(uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t
reconfig_data_format(cb_in, cb_bcast_scaler);
pack_reconfig_data_format(cb_max);
cb_reserve_back(cb_max, 1);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_max, cb_in, cb_bcast_scaler);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_in, cb_bcast_scaler, cb_max);
cb_wait_front(cb_bcast_scaler, 1);
for (uint32_t w = 0; w < block_w; w++) {
constexpr uint32_t bcast_scaler0 = 0;
Expand Down Expand Up @@ -203,7 +203,7 @@ void MAIN {

// sum(exp(x))
ACQ();
reduce_init_delta<false>(cb_recipsumexps, cb_exps, cb_bcast_scaler);
reduce_init_delta<false>(cb_exps, cb_bcast_scaler, cb_recipsumexps);
cb_wait_front(cb_exps, block_w);
cb_wait_front(cb_bcast_scaler, 1);
cb_reserve_back(cb_recipsumexps, 1);
Expand Down

0 comments on commit 81a72b3

Please sign in to comment.