Skip to content

Commit

Permalink
#15450: Remove default values from circular buffer parameters in LLK …
Browse files Browse the repository at this point in the history
…compute APIs: Transpose and Reduce (#16427)

### Ticket
[Link to Github
Issue](#15450)

### Problem description
Default values for circular buffer arguments in the LLK API can cause
errors. Forgetting to set these arguments explicitly may lead to errors
due to wrong cb usage. This PR is specific to the changes in the
transpose_wh and reduce kernel APIs:
- ./tt_metal/include/compute_kernel_api/transpose_wh.h
- ./tt_metal/include/compute_kernel_api/reduce.h

### What's changed
Default values for the circular buffer parameters have been removed from
functions within these files. The call chains invoking these functions
have been updated to contain explicit arguments for these parameters.

### Checklist
- [x] Post commit CI passes [All post-commit
tests](https://github.com/tenstorrent/tt-metal/actions/runs/12653818878)
- [x] Blackhole Post commit (if applicable) [Blackhole post-commit
test](https://github.com/tenstorrent/tt-metal/actions/runs/12653822690)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests passes
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
atatuzunerTT authored Jan 16, 2025
1 parent 7b6af76 commit f22e28f
Show file tree
Hide file tree
Showing 38 changed files with 88 additions and 88 deletions.
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/cumsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void MAIN {
#ifndef ROWWISE
init_sfpu(tt::CBIndex::c_0, tt::CBIndex::c_16);
#else
transpose_wh_init(tt::CBIndex::c_0);
transpose_wh_init(tt::CBIndex::c_0, tt::CBIndex::c_16);
#endif
cumsum_tile_init();

Expand Down
8 changes: 4 additions & 4 deletions tests/tt_metal/tt_metal/test_kernels/compute/layernorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void MAIN {
*/
ACQ();
cb_reserve_back(cb_ex, 1 * onetile);
reduce_init_delta<false>();
reduce_init_delta<false>(cb_ex, cb_x, cb_scaler);
for (uint32_t wt = 0; wt < Wt; wt += blk) {
cb_wait_front(cb_x, wt + blk);
for (uint32_t j = 0; j < blk; j++) {
Expand All @@ -107,7 +107,7 @@ void MAIN {
// we don't pop cb_x until we compute Ex
}
pack_tile(dst0, cb_ex);
reduce_revert_delta();
reduce_revert_delta(cb_ex);
REL();

cb_push_back(cb_ex, 1);
Expand Down Expand Up @@ -154,7 +154,7 @@ void MAIN {
* TODO(AP): can save space here by reusing CB
*/
cb_reserve_back(cb_ex2, 1);
reduce_init_delta<false>();
reduce_init_delta<false>(cb_ex2, cb_xmm2, cb_scaler);
ACQ();
cb_wait_front(cb_xmm2, Wt);
// cb_wait_front(cb_xmm, Wt);
Expand All @@ -167,7 +167,7 @@ void MAIN {
}
cb_pop_front(cb_xmm2, Wt);
pack_tile(dst0, cb_ex2);
reduce_revert_delta();
reduce_revert_delta(cb_ex2);
REL();

cb_push_back(cb_ex2, 1);
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/max_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline void reduce_h(
uint32_t out_cb_id) {
cb_wait_front(in_cb_id, in_ntiles_hwc * out_nelems);
cb_reserve_back(out_cb_id, out_ntiles_c * out_nelems);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(out_cb_id);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(out_cb_id, in_cb_id, in_scalar_cb_id);
uint32_t base_tile_id = 0;
for (uint32_t c_i = 0; c_i < in_ntiles_c * out_nelems; ++c_i) {
// add to accumulator all the in_ntiles_hw in a column of tiles
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline void reduce_h(
uint32_t out_cb_id) {
cb_wait_front(in_cb_id, in_ntiles_hwc * out_nelems);
cb_reserve_back(out_cb_id, out_ntiles_c * out_nelems);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(out_cb_id);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_COL>(out_cb_id, in_cb_id, in_scalar_cb_id);
uint32_t base_tile_id = 0;
for (uint32_t c_i = 0; c_i < in_ntiles_c * out_nelems; ++c_i) {
// add to accumulator all the in_ntiles_hw in a column of tiles
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void MAIN {
constexpr bool at_start = get_compile_time_arg_val(3);
dummy_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2);
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
#endif
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_hw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void MAIN {
constexpr bool at_start = get_compile_time_arg_val(3);
dummy_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2);
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
#endif
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/test_kernels/compute/reduce_w.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void MAIN {
constexpr bool at_start = get_compile_time_arg_val(3);
dummy_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2);
#ifndef SHORT_INIT
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2);
reduce_init<at_start>(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16);
#else
reduce_init_delta<at_start>(tt::CBIndex::c_16, tt::CBIndex::c_0, tt::CBIndex::c_2);
#endif
Expand Down
4 changes: 2 additions & 2 deletions tests/tt_metal/tt_metal/test_kernels/compute/rmsnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void MAIN {
* compute E[(x)^2]
*/
cb_reserve_back(cb_ex2, 1);
reduce_init_delta<false>();
reduce_init_delta<false>(cb_ex2, cb_x2, cb_scaler);
ACQ();
cb_wait_front(cb_x2, Wt);
// cb_wait_front(cb_xmm, Wt);
Expand All @@ -123,7 +123,7 @@ void MAIN {
// reduce_tile(cb_xmm, cb_scaler, wt+wtr, scaler0, dst0);
}
cb_pop_front(cb_x2, Wt);
reduce_revert_delta();
reduce_revert_delta(cb_ex2);
pack_tile(dst0, cb_ex2);
REL();

Expand Down
4 changes: 2 additions & 2 deletions tests/tt_metal/tt_metal/test_kernels/compute/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ void MAIN {

ACQ();
cb_reserve_back(cb_recipsumexps, onetile);
reduce_init_delta<false>();
reduce_init_delta<false>(cb_recipsumexps, cb_exps, cb_bcast_scaler);
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(cb_exps, wt + 1); // must be a cumulative wait for correctness
constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB
reduce_tile(cb_exps, cb_bcast_scaler, wt, bcast_scaler0, dst0);
}
reduce_revert_delta();
reduce_revert_delta(cb_recipsumexps);
recip_tile_init();
recip_tile(dst0); // DST[0] = 1/sum(exp(x))
pack_tile(dst0, cb_recipsumexps);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace NAMESPACE {
void MAIN {
uint32_t NHtWt = get_compile_time_arg_val(0);
#ifndef SHORT_INIT
transpose_wh_init(tt::CBIndex::c_0);
transpose_wh_init(tt::CBIndex::c_0, tt::CBIndex::c_16);
#else
unary_op_init_common(tt::CBIndex::c_0);
transpose_wh_init_short(tt::CBIndex::c_0);
Expand Down
10 changes: 5 additions & 5 deletions tt_metal/include/compute_kernel_api/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
namespace ckernel {

template <bool at_start, PoolType reduce_type = REDUCE_OP, ReduceDim reduce_dim = REDUCE_DIM>
ALWI void reduce_init(uint32_t icb, uint32_t icb_scaler, uint32_t ocb = 16) {
ALWI void reduce_init(uint32_t icb, uint32_t icb_scaler, uint32_t ocb) {
UNPACK((llk_unpack_AB_hw_configure_disaggregated<DST_ACCUM_MODE>(icb, icb_scaler)));
UNPACK((llk_unpack_AB_reduce_init<reduce_dim>(icb, icb_scaler)));

Expand All @@ -30,14 +30,14 @@ ALWI void reduce_init(uint32_t icb, uint32_t icb_scaler, uint32_t ocb = 16) {
}

template <PoolType reduce_type = REDUCE_OP, ReduceDim reduce_dim = REDUCE_DIM>
ALWI void reduce_init_short(uint32_t icb, uint32_t icb_scaler, uint32_t ocb = 16) {
ALWI void reduce_init_short(uint32_t icb, uint32_t icb_scaler, uint32_t ocb) {
UNPACK((llk_unpack_AB_reduce_init<reduce_dim>(icb, icb_scaler)));
MATH((llk_math_reduce_init<reduce_type, reduce_dim, MATH_FIDELITY>()));
PACK((llk_pack_reduce_config_v2<reduce_dim, false, false, DST_ACCUM_MODE>(ocb)));
}

template <bool at_start, PoolType reduce_type = REDUCE_OP, ReduceDim reduce_dim = REDUCE_DIM>
ALWI void reduce_init_delta(uint32_t ocb = 16, uint32_t icb0 = 0, uint32_t icb1 = 1) {
ALWI void reduce_init_delta(uint32_t ocb, uint32_t icb0, uint32_t icb1) {
// FIXME: API Update needed in compute kernel?
UNPACK((llk_unpack_AB_reduce_init<reduce_dim>(icb0, icb1)));

Expand All @@ -47,7 +47,7 @@ ALWI void reduce_init_delta(uint32_t ocb = 16, uint32_t icb0 = 0, uint32_t icb1
}

template <PoolType reduce_type = REDUCE_OP, ReduceDim reduce_dim = REDUCE_DIM>
ALWI void reduce_init_delta_no_pack(uint32_t icb0 = 0, uint32_t icb1 = 1) {
ALWI void reduce_init_delta_no_pack(uint32_t icb0, uint32_t icb1) {
// FIXME: API Update needed in compute kernel?
UNPACK((llk_unpack_AB_init<>(icb0, icb1)));

Expand All @@ -60,7 +60,7 @@ ALWI void reduce_init_delta_math() {
}

template <ReduceDim reduce_dim = REDUCE_DIM>
ALWI void reduce_revert_delta(uint32_t ocb = 16) {
ALWI void reduce_revert_delta(uint32_t ocb) {
PACK((llk_pack_reduce_config_v2<reduce_dim, false, true, DST_ACCUM_MODE>(ocb)));
}

Expand Down
2 changes: 1 addition & 1 deletion tt_metal/include/compute_kernel_api/transpose_wh.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace ckernel {
* |----------------|-------------------------------------------------------------|----------|------------------------------------------------|----------|
* | icb | The identifier of the circular buffer (CB) containing input | uint32_t | 0 to 31 | True |
*/
ALWI void transpose_wh_init(uint32_t icb, uint32_t ocb = 16) {
ALWI void transpose_wh_init(uint32_t icb, uint32_t ocb) {
MATH((llk_math_eltwise_unary_datacopy_init<A2D, BroadcastType::NONE, DST_ACCUM_MODE>(true, true, icb)));
MATH((llk_math_pack_sync_init<DST_ACCUM_MODE>()));
MATH((llk_math_hw_configure_disaggregated(icb, icb)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ ALWI void reduce_and_recip_tile_to_cb(
cb_pop_front(icb1, pop1);
}

reduce_revert_delta();
reduce_revert_delta(ocb);

recip_tile_init();
recip_tile(dst0);
Expand Down Expand Up @@ -992,7 +992,7 @@ ALWI void reduce_and_log_tile_to_cb(
cb_pop_front(icb1, pop1);
}

reduce_revert_delta();
reduce_revert_delta(ocb);

log_tile_init();
log_tile(dst0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace NAMESPACE {
void MAIN {
uint32_t NHtWt = get_compile_time_arg_val(0);
transpose_wh_init(tt::CBIndex::c_0);
transpose_wh_init(tt::CBIndex::c_0, tt::CBIndex::c_16);

// transpose a row-major block:
// - assumes the tiles come in in column major order from reader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ FORCE_INLINE void matmul_with_transpose_and_mask(
// TODO: checking required when the input cb format and intermediate cb format are different.
mm_init(cb_in0, cb_in1, cb_out0);
if (transpose_input || transpose_other) {
transpose_wh_init(cb_in0);
transpose_wh_init(cb_in0, cb_out0);
}

if (need_input_mask_h || need_input_mask_w) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace NAMESPACE {
void MAIN {
uint32_t NHtWt = get_arg_val<uint32_t>(0);

transpose_wh_init(tt::CBIndex::c_0);
transpose_wh_init(tt::CBIndex::c_0, tt::CBIndex::c_16);

// transpose a row-major block:
// - assumes the tiles come in in column major order from reader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ void MAIN {
constexpr uint32_t cb_id_in = get_compile_time_arg_val(0);
constexpr uint32_t cb_id_out = get_compile_time_arg_val(1);

transpose_wh_init(cb_id_in);
transpose_wh_init(cb_id_in, cb_id_out);

// transpose a row-major block:
// - uses reader_unary_transpose_wh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void MAIN {
constexpr uint32_t cb_transpose_in = get_compile_time_arg_val(1);
constexpr uint32_t cb_out = get_compile_time_arg_val(2);

transpose_wh_init(cb_in);
transpose_wh_init(cb_in, cb_transpose_in);
pack_untilize_init(cb_in, cb_transpose_in);

for (uint32_t idx = 0; idx < total_tiles; idx++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void MAIN {
constexpr uint32_t intermed_cb_id2 = get_compile_time_arg_val(4);
constexpr uint32_t output_cb_id = get_compile_time_arg_val(5);

reduce_init<true>(input_cb_id, scalar_cb_id);
reduce_init<true>(input_cb_id, scalar_cb_id, intermed_cb_id1);
reduce_revert_delta<REDUCE_DIM>(intermed_cb_id1); // Required or else the first tile is wrong

for (uint32_t block_h_id = 0; block_h_id < input_num_blocks_h; block_h_id++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace NAMESPACE {
void MAIN {
uint32_t num_tiles = get_compile_time_arg_val(0);

transpose_wh_init(tt::CBIndex::c_24);
transpose_wh_init(tt::CBIndex::c_24, tt::CBIndex::c_17);

constexpr uint32_t cb_im0 = tt::CBIndex::c_24;
constexpr uint32_t cb_out1 = tt::CBIndex::c_17;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ void MAIN {
cb_wait_front(cb_xpowadd, onetile);
cb_reserve_back(cb_y, onetile);

reduce_init_delta<false>();
reduce_init_delta<false>(cb_y, cb_xpowadd, cb_one);
reduce_tile(cb_xpowadd, cb_one, 0, 0, dst0);
reduce_revert_delta();
reduce_revert_delta(cb_y);
tile_regs_commit();

tile_regs_wait();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ void MAIN {
}

cb_wait_front(tt::CBIndex::c_24, onetile);
reduce_init_delta<false>();
reduce_init_delta<false>(tt::CBIndex::c_16, tt::CBIndex::c_24, tt::CBIndex::c_2);
reduce_tile(tt::CBIndex::c_24, tt::CBIndex::c_2, 0, 0, 0);
cb_pop_front(tt::CBIndex::c_24, onetile);
reduce_revert_delta();
reduce_revert_delta(tt::CBIndex::c_16);

if (last_out) {
cb_reserve_back(tt::CBIndex::c_16, onetile);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_reduce, cb_scaler);
#endif
reduce_init_delta<false>();
reduce_init_delta<false>(cb_out0, cb_reduce, cb_scaler);
reduce_tile(cb_reduce, cb_scaler, 0, 0, 0);
reduce_revert_delta();
reduce_revert_delta(cb_out0);

if (do_mask) {
cb_pop_front(cb_intermed0, onetile);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_reduce, cb_scaler);
#endif
reduce_init_delta<false>();
reduce_init_delta<false>(cb_out0, cb_reduce, cb_scaler);
reduce_tile((do_mask) ? (cb_intermed0) : (cb_in0), cb_scaler, 0, 0, 0);
reduce_revert_delta();
reduce_revert_delta(cb_out0);

if (do_mask) {
cb_pop_front(cb_intermed0, onetile);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ FORCE_INLINE void matmul_with_transpose_and_mask(
// TODO: checking required when the input cb format and intermediate cb format are different.
mm_init(cb_in0, cb_in1, cb_out0);
if (transpose_input || transpose_other) {
transpose_wh_init(cb_in0);
transpose_wh_init(cb_in0, cb_out0);
}

if (need_input_mask_h || need_input_mask_w) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_input, cb_scaler);
#endif
reduce_init_delta<false>();
reduce_init_delta<false>(cb_accum_dst, cb_input, cb_scaler);
reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx);
reduce_revert_delta();
reduce_revert_delta(cb_accum_dst);

cb_pop_front(cb_input, onetile);
}
Expand Down Expand Up @@ -104,9 +104,9 @@ void MAIN {
#if defined FP32_DEST_ACC_EN
reconfig_data_format(cb_input, cb_scaler);
#endif
reduce_init_delta<false>();
reduce_init_delta<false>(cb_out, cb_input, cb_scaler);
reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx);
reduce_revert_delta();
reduce_revert_delta(cb_out);
tile_regs_commit();

cb_reserve_back(cb_out, onetile);
Expand Down
Loading

0 comments on commit f22e28f

Please sign in to comment.