Use mbarrier for WAR for circular buffering (#3446)
Currently, in the TMA circular buffering pass, we generate code like
below
```C++
// main loop
for i:
  if (elect_sync) {
    arrive-load;
    TMA;
  }
  wait-load;
  compute;
  __syncthreads(); // for avoiding WAR hazard
}
```

This PR adds an option to change the generated code into:
```C++
// main loop
for i:
  if (elect_sync) {
    wait-compute;
    arrive-load;
    TMA;
  }
  wait-load;
  compute;
  arrive-compute;
}
```

That is, the plain old `__syncthreads()` is replaced with an arrive/wait on an mbarrier. With this change, each circular buffer stage uses two mbarriers: one signaling that the corresponding data has been loaded to smem and is ready to read (RAW hazard), and another signaling that the corresponding data has been fully read and will not be read again, so the space is free to be reused for new data (WAR hazard).
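
In terms of indexing, the allocation pass (see the diff below) carves one array of `2 * stages` mbarriers into `mbarrier[0:stages]` for RAW and `mbarrier[stages:2*stages]` for WAR, so the main loop pairs them roughly as follows. This is illustrative pseudocode only: `wait`, `arrive`, `arrive_expect_tx`, `bytes`, and the `i >= stages` guard stand in for the actual codegen details.
```C++
// main loop (index sketch)
for i:
  stage = i % stages;
  if (elect_sync) {
    if (i >= stages)
      wait(mbarrier[stage + stages]);          // wait-compute (WAR): slot reusable
    arrive_expect_tx(mbarrier[stage], bytes);  // arrive-load (RAW)
    TMA;                                       // writes smem slot `stage`
  }
  wait(mbarrier[stage]);                       // wait-load (RAW): data ready
  compute;                                     // reads smem slot `stage`
  arrive(mbarrier[stage + stages]);            // arrive-compute (WAR): done reading
}
```
Note that the wait-compute in iteration `i` pairs with the arrive-compute issued `stages` iterations earlier on the same slot, not with anything in the current iteration.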

In theory, I expect better performance when this feature is enabled, because `__syncthreads()` is a hard sync that requires all warp groups to reach the same point, while the arrive/wait barrier is a much softer sync that only requires all warp groups to have passed an earlier point. Unfortunately, the perf is worse for my matmul kernel, which is why I made this feature default-off.
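
For readers who want to see the arrive/wait pattern outside of codegen, below is a minimal, self-contained CUDA sketch of the same idea, using libcu++ `cuda::barrier` as a software stand-in for the hardware mbarrier. This is not nvFuser's generated code and makes no performance claim; `STAGES`, `TILE`, `buf`, the copy/compute bodies, and the single-block, `blockDim.x == TILE` launch are all assumptions for illustration.
```C++
// Minimal sketch: replace the per-iteration __syncthreads() with per-stage
// arrive/wait barriers. NOT nvFuser's generated code; cuda::barrier stands
// in for the hardware mbarrier, and the copy/compute bodies are placeholders.
#include <cuda/barrier>
#include <cuda/std/utility>

constexpr int STAGES = 3;  // circular buffer depth (assumed)
constexpr int TILE = 128;  // elements per stage; assumes blockDim.x == TILE

// in/out are assumed to hold num_tiles * TILE floats.
__global__ void circular_buffer_sketch(const float* in, float* out, int num_tiles) {
  using barrier_t = cuda::barrier<cuda::thread_scope_block>;
  __shared__ float buf[STAGES][TILE];
  // Suppress the "dynamic initialization" diagnostic for __shared__ barriers.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier_t raw[STAGES];  // stage written and ready to read (RAW)
  __shared__ barrier_t war[STAGES];  // stage fully read and reusable (WAR)

  if (threadIdx.x == 0) {
    for (int s = 0; s < STAGES; ++s) {
      init(&raw[s], blockDim.x);
      init(&war[s], blockDim.x);
    }
  }
  __syncthreads();  // only to publish the initialized barriers

  barrier_t::arrival_token war_token[STAGES];

  for (int i = 0; i < num_tiles; ++i) {
    int stage = i % STAGES;
    if (i >= STAGES) {
      // wait-compute: every thread finished reading this slot in iteration
      // i - STAGES, so it may be overwritten (this replaces __syncthreads()).
      war[stage].wait(cuda::std::move(war_token[stage]));
    }
    // Stand-in for the TMA load: a plain cooperative copy into the slot.
    buf[stage][threadIdx.x] = in[i * TILE + threadIdx.x];
    // arrive-load + wait-load: the slot's data is now ready to read.
    raw[stage].arrive_and_wait();
    // compute: read the slot.
    out[i * TILE + threadIdx.x] = 2.0f * buf[stage][threadIdx.x];
    // arrive-compute: done reading; keep the token so that iteration
    // i + STAGES can wait on it.
    war_token[stage] = war[stage].arrive();
  }
}
```
A launch like `circular_buffer_sketch<<<1, TILE>>>(in, out, n)` is assumed; the point is only that the WAR wait in iteration `i` pairs with the WAR arrives from iteration `i - STAGES`, rather than with a full-block `__syncthreads()`.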

The main purpose of this PR is not to support a slower way to wait. The reason I did this work is that this scheme of two mbarriers per stage, one for RAW and another for WAR, is very close to the warp specialization code we want to generate, and the work in this PR can be largely reused by warp specialization. So the main purpose of this PR is to serve as an incremental step towards warp specialization; providing a second option for people to try is only a side benefit.

Besides, the above code is a very good illustration of why we should go for warp specialization. On Hopper, both TMA and MMA are async; however, putting load and compute into the same warp makes it impossible for us to truly pipeline both TMA and MMA. From the above code, we can easily see that each MMA not only needs to wait for its data to be ready (RAW), but also needs to wait for the buffer of the next load to be freed, simply because that wait sits a few lines above it in the code. The MMA in this iteration has no real dependency on that buffer being freed, because the MMA in this iteration will not touch it; the dependency is fake, created only by program order. What if the data for this iteration's MMA is ready, but the buffer for this iteration's load is not freed yet? Can I just start doing MMA while waiting for that buffer to be freed? No, I cannot, because load and compute are in the same warp: both must be ready before this iteration can move on. With this observation, it is very natural to separate loading and computation into different warps, as sketched below.
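
For concreteness, the warp-specialized structure this argues for looks roughly like the pseudocode below. This is a sketch, not code generated by this PR; `warp_is_producer` and the `wait`/`arrive`/`arrive_expect_tx` helpers are placeholders. The same per-stage RAW/WAR mbarrier pairs decouple the two loops, so the compute warps never inherit a fake program-order dependency on the next load:
```C++
if (warp_is_producer) {
  for i:                                     // load loop
    stage = i % stages;
    wait(mbarrier[stage + stages]);          // WAR: slot free to overwrite
    arrive_expect_tx(mbarrier[stage], bytes);
    TMA;                                     // writes smem slot `stage`
} else {
  for i:                                     // compute loop
    stage = i % stages;
    wait(mbarrier[stage]);                   // RAW: data ready
    MMA;                                     // reads smem slot `stage`
    arrive(mbarrier[stage + stages]);        // WAR: done reading
}
```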
zasdfgbnm authored Nov 22, 2024
1 parent b5e5182 commit 769a4d2
Showing 7 changed files with 577 additions and 97 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
@@ -26,9 +26,9 @@ cppcoreguidelines-*,
-cppcoreguidelines-pro-type-vararg,
-cppcoreguidelines-special-member-functions,
-cppcoreguidelines-non-private-member-variables-in-classes,
-cppcoreguidelines-avoid-goto,
-facebook-hte-RelativeInclude,
hicpp-exception-baseclass,
hicpp-avoid-goto,
misc-unused-alias-decls,
misc-unused-using-decls,
modernize-*,
123 changes: 95 additions & 28 deletions csrc/device_lower/pass/allocation.cpp
Expand Up @@ -22,6 +22,8 @@ namespace nvfuser {

namespace {

enum class CircularBufferWaitType { ReadAfterWrite, WriteAfterRead };

// This function creates kir::Loop with range based on stage depth. It is
// used for mbarrier initialization and invalidation.
ForLoop* createStageDepthForLoop(ForLoop* circular_buffer_loop) {
@@ -43,25 +45,52 @@ ForLoop* createStageDepthForLoop(ForLoop* circular_buffer_loop) {
// }
Expr* initializeMbarrier(
ForLoop* circular_buffer_loop,
TensorView* all_mbarriers) {
TensorView* all_mbarriers,
CircularBufferWaitType wait_type) {
NVF_ERROR(circular_buffer_loop != nullptr);
ForLoop* loop = createStageDepthForLoop(circular_buffer_loop);

int64_t stage_depth =
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
.stage;

// We use mbarrier[0:stage_depth] for RAW, and
// mbarrier[stage_depth:2*stage_depth] for WAR.
Val* mbarrier_index = wait_type == CircularBufferWaitType::ReadAfterWrite
? loop->index()
: SimplifyingIrBuilder::addExpr(loop->index(), stage_depth);

// Get mbarrier for this circular buffer stage.
kir::TensorIndex* stage_mbarrier =
IrBuilder::create<kir::TensorIndex>(all_mbarriers, loop->index());
IrBuilder::create<kir::TensorIndex>(all_mbarriers, mbarrier_index);

auto circular_buffered_tvs =
GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
circular_buffer_loop);
int64_t num_of_tvs_loaded_by_tma = std::count_if(
circular_buffered_tvs.begin(),
circular_buffered_tvs.end(),
[](const TensorView* tv) {
return ir_utils::isCpAsyncBulkLoad(tv->definition());
});
Val* num_of_arrives =
IrBuilder::create<Val>(num_of_tvs_loaded_by_tma, DataType::UInt32);

Val* num_of_arrives = nullptr;
if (wait_type == CircularBufferWaitType::ReadAfterWrite) {
// The mbarrier of RAW is used to wait for the completion of the TMA
// load of the circular buffer tensor. The number of arrives is the
// number of TMA issued for the circular buffer tensor.
int64_t num_of_tvs_loaded_by_tma = std::count_if(
circular_buffered_tvs.begin(),
circular_buffered_tvs.end(),
[](const TensorView* tv) {
return ir_utils::isCpAsyncBulkLoad(tv->definition());
});
num_of_arrives =
IrBuilder::create<Val>(num_of_tvs_loaded_by_tma, DataType::UInt32);
} else {
// The mbarrier of WAR is used to wait for the completion of the reading
// of the circular buffer tensor. The number of arrives is the number of
// threads in the CTA.
num_of_arrives = SimplifyingIrBuilder::maybeCastExpr(
DataType::UInt32,
GpuLower::current()->parallelDimensionMap().getNumThreadsEachBlock());
}

// Initialize mbarrier for each circular buffer stage. Use the thread
// count from the MBarrierInit created in the allocation pass. The wait
Expand All @@ -87,13 +116,26 @@ Expr* initializeMbarrier(
// }
Expr* invalidateMbarrier(
ForLoop* circular_buffer_loop,
TensorView* all_mbarriers) {
TensorView* all_mbarriers,
CircularBufferWaitType wait_type) {
NVF_ERROR(circular_buffer_loop != nullptr);
ForLoop* loop = createStageDepthForLoop(circular_buffer_loop);

int64_t stage_depth =
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
.stage;

// We use mbarrier[0:stage_depth] for RAW, and
// mbarrier[stage_depth:2*stage_depth] for WAR.
Val* mbarrier_index = wait_type == CircularBufferWaitType::ReadAfterWrite
? loop->index()
: SimplifyingIrBuilder::addExpr(loop->index(), stage_depth);

// Get mbarrier for this circular buffer stage.
kir::TensorIndex* stage_mbarrier =
IrBuilder::create<kir::TensorIndex>(all_mbarriers, loop->index());
IrBuilder::create<kir::TensorIndex>(all_mbarriers, mbarrier_index);

// Invalidate the mbarrier for each circular buffer stage.
kir::MBarrierInvalidate* mbarrier_inval =
@@ -640,27 +682,33 @@ class AllocationInserter : public kir::ExprMutator {
// then allocate an array of mbarrier objects. mbarrier::init and
// mbarrier::inval will be updated in circular buffering pass, but we
// add them here to handle shared memory correctly in alias memory pass.
int64_t circular_buffer_depth =
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(fl->iter_domain())
.stage;

TensorView* mbarrier =
TensorViewBuilder()
.shape(std::vector<int64_t>{circular_buffer_depth})
.dtype(DataType::UInt)
.contiguity(true)
.build();
const auto& opt =
GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
fl->iter_domain());

// We use mbarrier[0:stage] for RAW, that is, to wait for the completion
// of the TMA load of the circular buffer tensor, and
// mbarrier[stage:2*stage] for WAR, that is, to wait for the completion of
// the reading of the circular buffer tensor.
int64_t num_mbarriers =
opt.usesMBarrierForWAR() ? opt.stage * 2 : opt.stage;

TensorView* mbarrier = TensorViewBuilder()
.shape(std::vector<int64_t>{num_mbarriers})
.dtype(DataType::UInt)
.contiguity(true)
.build();
mbarrier->setMemoryType(MemoryType::Shared);

kir::Allocate* mbarrier_alloc =
IrBuilder::create<kir::Allocate>(mbarrier, MemoryType::Shared);

// Initialize and invalidate mbarriers that are used to notify that
// the load of the circular buffer is complete.
auto mbarrier_init_filled = initializeMbarrier(fl, mbarrier);
auto mbarrier_inval_filled = invalidateMbarrier(fl, mbarrier);
auto mbarrier_init_raw = initializeMbarrier(
fl, mbarrier, CircularBufferWaitType::ReadAfterWrite);
auto mbarrier_inval_raw = invalidateMbarrier(
fl, mbarrier, CircularBufferWaitType::ReadAfterWrite);

// Block sync is necessary to finish mbarrier initialization.
kir::BlockSync* sync = IrBuilder::create<kir::BlockSync>(false);
Expand All @@ -670,6 +718,11 @@ class AllocationInserter : public kir::ExprMutator {
//
// __shared__ mbarrier[num_stages];
// for (circular_buffer_stage) {
// // initialize mbarrier for RAW
// init(mbarrier[stage]);
// }
// for (circular_buffer_stage) {
// // initialize mbarrier for WAR
// init(mbarrier[stage]);
// }
// block_sync();
Expand All @@ -679,13 +732,27 @@ class AllocationInserter : public kir::ExprMutator {
// }
//
// for (circular_buffer_stage) {
// // invalidate mbarrier for WAR
// inval(mbarrier[stage]);
// }
// for (circular_buffer_stage) {
// // invalidate mbarrier for RAW
// inval(mbarrier[stage]);
// }
//
Scope* current_scope = scope_.empty() ? nullptr : scope_.back();
registerInsertBefore(fl, mbarrier_alloc, current_scope);
registerInsertBefore(fl, mbarrier_init_filled, current_scope);
registerInsertAfter(fl, mbarrier_inval_filled, current_scope);
registerInsertBefore(fl, mbarrier_init_raw, current_scope);
registerInsertAfter(fl, mbarrier_inval_raw, current_scope);

if (opt.usesMBarrierForWAR()) {
auto mbarrier_init_war = initializeMbarrier(
fl, mbarrier, CircularBufferWaitType::WriteAfterRead);
auto mbarrier_inval_war = invalidateMbarrier(
fl, mbarrier, CircularBufferWaitType::WriteAfterRead);
registerInsertBefore(fl, mbarrier_init_war, current_scope);
registerInsertAfter(fl, mbarrier_inval_war, current_scope);
}
registerInsertBefore(fl, sync, current_scope);

for (auto tv : circular_buffer_tvs) {