Commit
Add options for persistent kernels/warp specialization in MatmulParams (#3792)

This adds the ability to request warp specialization and persistent
kernels in MatmulParams, including in the python frontend. Note that the
functionality of persistent CTA scheduling will be added in a follow-up
PR. For now, we throw an error when persistent kernels are requested.

For convenience, this PR also parametrizes the three MLP benchmark tests used
in our recent perf dive.
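
As a rough illustration, the new fields can be set on a MatmulParams instance like so; the field and enum names come from csrc/scheduler/matmul_heuristic.h, while the surrounding setup is assumed:

  // Sketch only: `params` is assumed to be a MatmulParams instance obtained
  // from the matmul scheduler (construction elided).
  params.tiling_strategy = MatmulParams::TilingStrategy::OneTilePerCTA;
  params.buffering_loop_level = MatmulParams::BufferingLoopLevel::CTATiles;
  params.circular_buffering_strategy =
      MatmulParams::CircularBufferingStrategy::WarpSpecialized;

On the Python side, the same options are exposed through the MatmulTilingStrategy, MatmulBufferingLoopLevel, and MatmulCircularBufferingStrategy enums bound in python_bindings.cpp.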

---------

Co-authored-by: Ryan Spring <[email protected]>
jacobhinkle and rdspring1 authored Feb 1, 2025
1 parent 93b68e0 commit 99b5f96
Showing 8 changed files with 311 additions and 17 deletions.
21 changes: 21 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
@@ -678,6 +678,24 @@ void defineHeuristicParamBindings(py::module& nvfuser) {
.MMAMACROPROP(k, uint16_t)
.TOSTRINGTOPLEVEL(MmaMacro);
#undef MMAMACROPROP
py::enum_<MatmulParams::TilingStrategy>(nvfuser, "MatmulTilingStrategy")
.value("one_tile_per_cta", MatmulParams::TilingStrategy::OneTilePerCTA)
.value(
"distribute_tiles_across_sms",
MatmulParams::TilingStrategy::DistributeTilesAcrossSMs)
.value(
"distribute_stages_across_sms",
MatmulParams::TilingStrategy::DistributeStagesAcrossSMs);
py::enum_<MatmulParams::BufferingLoopLevel>(
nvfuser, "MatmulBufferingLoopLevel")
.value("cta_tiles", MatmulParams::BufferingLoopLevel::CTATiles)
.value("warp_tiles", MatmulParams::BufferingLoopLevel::WarpTiles);
py::enum_<MatmulParams::CircularBufferingStrategy>(
nvfuser, "MatmulCircularBufferingStrategy")
.value("pipelined", MatmulParams::CircularBufferingStrategy::Pipelined)
.value(
"warp_specialized",
MatmulParams::CircularBufferingStrategy::WarpSpecialized);

// Base class for scheduler parameters
DEFINECLASS(HeuristicParams)
@@ -753,6 +771,9 @@ void defineHeuristicParamBindings(py::module& nvfuser) {
.PARAM(MatmulParams, use_smem_epilogue)
.PARAM(MatmulParams, promote_prologue_smem_reuse)
.PARAM(MatmulParams, splitk_factor)
.PARAM(MatmulParams, tiling_strategy)
.PARAM(MatmulParams, buffering_loop_level)
.PARAM(MatmulParams, circular_buffering_strategy)
.PARAM(MatmulParams, cta_order)
.PARAM(MatmulParams, cluster_dims)
.PARAM(MatmulParams, mma_macro);
22 changes: 22 additions & 0 deletions csrc/scheduler/ampere_multi_matmul.cpp
@@ -444,6 +444,28 @@ AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) {

} // namespace

void AmpereMultipleMatmulScheduler::validate() const {
const auto device_prop = at::cuda::getCurrentDeviceProperties();
const int cc = device_prop->major * 10 + device_prop->minor;
NVF_ERROR(
cc >= 75 && cc < 90,
"This matmul scheduler is restricted to Ampere and Turing.");

NVF_CHECK(
params_->tiling_strategy == MatmulParams::TilingStrategy::OneTilePerCTA,
"Ampere matmul scheduler does not support scheduling persistent CTAs");

NVF_CHECK(
params_->buffering_loop_level ==
MatmulParams::BufferingLoopLevel::CTATiles,
"Ampere matmul scheduler only supports cooperatively buffering at the CTA level (no ping-pong)");

NVF_CHECK(
params_->circular_buffering_strategy ==
MatmulParams::CircularBufferingStrategy::Pipelined,
"Ampere matmul scheduler does not support warp specialization");
}

void AmpereMultipleMatmulScheduler::run() {
// Clears memory spaces on intermediate tensors, calls
// cache{After,Before,Fork} on inputs and outputs
8 changes: 3 additions & 5 deletions csrc/scheduler/ampere_multi_matmul.h
@@ -69,16 +69,14 @@ class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler {
public:
AmpereMultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params)
: MultipleMatmulScheduler(fusion, params) {
const auto device_prop = at::cuda::getCurrentDeviceProperties();
const int cc = device_prop->major * 10 + device_prop->minor;
NVF_ERROR(
cc >= 75 && cc < 90,
"This matmul scheduler is restricted to Ampere and Turing.");
validate();
}

void run() final;

private:
void validate() const;

void cacheInputsAndOutputs();

// Including current tensor naming convention for reference,
42 changes: 40 additions & 2 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -57,6 +57,34 @@ MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
return it->second;
}

void HopperMultipleMatmulScheduler::validate() const {
const auto device_prop = at::cuda::getCurrentDeviceProperties();
const int cc = device_prop->major * 10 + device_prop->minor;
NVF_ERROR(
cc >= 90 && cc < 100, "This matmul scheduler is restricted to Hopper.");

if (params_->tiling_strategy != MatmulParams::TilingStrategy::OneTilePerCTA) {
NVF_CHECK(
params_->splitk_factor == 1,
"Hopper matmul scheduler does not support scheduling persistent split-K kernels");
}

NVF_CHECK(
params_->tiling_strategy !=
MatmulParams::TilingStrategy::DistributeTilesAcrossSMs,
"Hopper matmul scheduler TEMPORARILY does not support persistent scheduling of tiles yet");

NVF_CHECK(
params_->tiling_strategy !=
MatmulParams::TilingStrategy::DistributeStagesAcrossSMs,
"Hopper matmul scheduler does not support distributing stages across SMs a la stream-K");

NVF_CHECK(
params_->buffering_loop_level ==
MatmulParams::BufferingLoopLevel::CTATiles,
"Hopper matmul scheduler only supports cooperatively buffering at the CTA level (no ping-pong)");
}

void HopperMultipleMatmulScheduler::run() {
// Clears memory spaces on intermediate tensors, calls
// cache{After,Before,Fork} on inputs and outputs
@@ -640,21 +668,31 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() {
" but is expected to be positive and not greater than number of stages: ",
params_->circular_buffer_options.smem_circular_buffer_stage);

CircularBufferType cb_type;
switch (params_->circular_buffering_strategy) {
case MatmulParams::CircularBufferingStrategy::Pipelined:
cb_type = (CircularBufferType)Pipelined(false);
break;
case MatmulParams::CircularBufferingStrategy::WarpSpecialized:
cb_type = (CircularBufferType)WarpSpecialized(ParallelType::TIDy);
}
for (TensorView* acw_smem : acw_smems_) {
acw_smem->circularBuffer(
params_->circular_buffer_options.smem_circular_buffer_stage,
/*prefetch_distance=*/
params_->circular_buffer_options.smem_circular_buffer_stage -
params_->circular_buffer_options
.smem_circular_buffer_prefetch_gap);
.smem_circular_buffer_prefetch_gap,
/*type=*/cb_type);
}
for (TensorView* bcw_smem : bcw_smems_) {
bcw_smem->circularBuffer(
params_->circular_buffer_options.smem_circular_buffer_stage,
/*prefetch_distance=*/
params_->circular_buffer_options.smem_circular_buffer_stage -
params_->circular_buffer_options
.smem_circular_buffer_prefetch_gap);
.smem_circular_buffer_prefetch_gap,
/*type=*/cb_type);
}
}

7 changes: 3 additions & 4 deletions csrc/scheduler/hopper_multi_matmul.h
@@ -69,15 +69,14 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
public:
HopperMultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params)
: MultipleMatmulScheduler(fusion, params) {
const auto device_prop = at::cuda::getCurrentDeviceProperties();
const int cc = device_prop->major * 10 + device_prop->minor;
NVF_ERROR(
cc >= 90 && cc < 100, "This matmul scheduler is restricted to Hopper.");
validate();
}

void run() final;

private:
void validate() const;

void cacheInputsAndOutputs();

// Including current tensor naming convention for reference,
160 changes: 158 additions & 2 deletions csrc/scheduler/matmul_heuristic.h
@@ -159,6 +159,129 @@ class MatmulParams : public HeuristicParams {
//! Specify the type of MMA op to be used in generated kernel.
MmaMacro mma_macro = MmaMacro::NoMMA;

// [Basic Matmul Configuration]
// We compute matrix products by decomposing the output into tiles. The size
// of these tiles is specified by the M and N dimensions of the CTA tile. The
// K dimension of the CTA tile must equal that of the warp tile, and will be
// discussed later.
//
// Normally, each output tile is processed by a single CTA (for exceptions,
// see notes about split-K and stream-K below). That CTA contains threads
// organized into warps of 32 threads and (on Hopper+) warpgroups consisting
// of 4 warps. The warp tile's M and N dimensions indicate the subtile of the
// CTA tile that each warp or warpgroup is responsible for computing.
//
// The K dimension of the warp tile, which must match that of the CTA tile,
// indicates the K dimension of the operand tiles that are processed in the
// "K loop", a serial loop in the generated kernel used to accumulate
// contributions to the mma result via summation in a register buffer in each
// thread.
//
// The MmaMacro determines the actual PTX instruction used to compute a small
// matrix-matrix product on the device's tensor cores. These macros determine
// an "instruction tile" which can be computed in a single instruction. The
// number of instruction tiles that make up a single warp tile translates to
// loops in the generated kernel inside the K loop, allowing each thread
// to compute a warp tile result that is larger than a single instruction
// tile. Importantly, the warp tile determines the amount of data that
// must be loaded before performing the loop to issue mma instructions, so the
// warp tile provides a lower bound on the size of each loading or circular
// buffering stage.
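//
// As an illustration (tile sizes chosen only for this example): with a
// 128x128x64 CTA tile, a 64x128x64 warp tile, and a 64x128x16 instruction
// tile, each CTA contains two warp groups, each warp group accumulates a
// 64x128 subtile of the CTA tile, the K loop advances 64 elements of K per
// iteration, and each iteration issues 64 / 16 = 4 mma instructions per warp
// group.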
//
// [Detailed Matmul Configuration]
// One simple way to compute the output tiles is to assign each CTA tile to
// an individual CTA, launching a 2D grid that matches the tiling of the
// output matrix. Each of those CTAs can then compute a single K loop,
// loading one CTA tile at each iteration, in order to accumulate the result
// for a single output tile. This might require multiple waves of CTAs to be
// launched, and each one will need to compute a prologue consisting of some
// indexing expressions. Furthermore, the epilogue computation must complete
// before each SM can launch the next CTA to which it is assigned.
//
// Alternatively, we could launch exactly one CTA per SM on the device. This
// allows us to compute some of the prologue once, then loop over a set of
// output tiles. For each output tile we then compute a K loop and epilogue.
// However, along with warp specialization and other approaches, we can
// sometimes begin loading data for the next tile before the epilogue is
// complete (see below). We call such an approach a "persistent kernel".
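//
// Schematically, a persistent kernel loops over output tiles with a stride of
// the grid size (illustrative pseudocode only):
//
//   for (tile = blockIdx.x; tile < num_output_tiles; tile += gridDim.x) {
//     /* K loop accumulating this output tile */
//     /* epilogue writing this output tile */
//   }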
//
// Within each iteration of the K loop, two distinct things need to happen.
// First, we need to load data from the operands to the SM in either shared
// memory or registers. Then we need to perform a set of mma instructions to
// compute the contribution of a warp tile to the final result. Waiting for
// the data to load before computing the mma instructions would mean leaving
// the tensor cores idle, hurting performance. Instead, we commonly employ
// circular buffering, wherein at each iteration of the K loop we launch an
// asynchronous load of data for a future iteration. This way each thread
// only needs to launch an asynchronous load, then wait for a previous load
// to complete before computing mma instructions. This is called the
// "pipelined" strategy wherein we leave a number of asynchronous transfers
// in flight at all points of the K loop.
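//
// Schematically, a pipelined K loop with S circular buffering stages looks
// like the following (illustrative pseudocode only):
//
//   prefetch stages 0 .. S-2 asynchronously;
//   for (k = 0; k < num_k_iters; ++k) {
//     issue async load for iteration k + S - 1;  // keep S-1 loads in flight
//     wait for the data of iteration k to arrive;
//     issue mma instructions on the data of iteration k;
//   }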
//
// The load instructions inside each K loop iteration can also be avoided by
// moving them to a separate thread. This is done via "warp specialization":
// we launch one additional warp group called the "dma warp group" whose only
// responsibility is to monitor the circular buffer and issue asynchronous
// load instructions. The mma instructions are left to the other warp groups,
// which we call "math warp groups".
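//
// Schematically, warp specialization splits the same work as follows
// (illustrative pseudocode only):
//
//   if (dma warp group) {
//     for each K iteration: wait for a free buffer slot; issue async load;
//   } else {  // math warp groups
//     for each K iteration: wait for a full slot; issue mma; release the slot;
//   }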
//
// [Split-K and Stream-K]
// When the M and N dimensions are much smaller than the K dimension,
// distributing separate output tiles across the grid will not fully occupy
// all compute resources on the GPU. An alternative is to parallelize work
// along the K dimension and
// then have a single CTA aggregate results for an output tile.
//
// Split-K divides the K dimension by a constant factor. For example, when the
// split-K factor is 4, the K dimension is split across 4 CTAs. Each CTA
// accumulates a (CTA-M, CTA-N, K/4) output tile. A grid reduction is then
// performed on the K dimension to get the complete (CTA-M, CTA-N) tile. When
// the split-K factor is 1, this is equivalent to the data-parallel approach.
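//
// As a concrete instance of the example above: with K = 4096 and a split-K
// factor of 4, each of the 4 CTAs assigned to an output tile runs its K loop
// over 4096 / 4 = 1024 elements, and the grid reduction sums the 4 partial
// (CTA-M, CTA-N) results into the final tile.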
//
// The Stream-K approach combines the persistent grid strategy, which launches
// a single wave of CTAs, with K dimension parallelization. The core idea is
// to have each SM complete a fixed unit of work per stage, utilizing M, N,
// and K dimension parallelization. Each CTA computes a fixed
// (CTA-M, CTA-N, CTA-K) tile per stage, and the CTA-K stages covering one
// (CTA-M, CTA-N) output tile may be split across multiple CTAs. Once all
// partial tiles are completed, a grid sum accumulates them into the final
// result. The advantage of stream-K over split-K is that it avoids having to
// choose a split-K factor; finding the optimal factor that avoids wave
// quantization is non-trivial.
//
// When CTA-K == K, stream-K is equivalent to the persistent data-parallel
// strategy. When the K dimension is evenly divided among CTAs
// (K % CTA-K == 0), stream-K is equivalent to the persistent split-K
// strategy.

//! Specify whether to use a 1-1 mapping from output tile to CTA or to launch
//! one CTA per SM then loop over a subset of output tiles within the kernel
//! (persistent).
enum class TilingStrategy {
OneTilePerCTA, // Map each output tile to a single CTA and launch as many as
// are needed to cover the tile grid. This is also commonly
// referred to as the data-parallel strategy.
DistributeTilesAcrossSMs, // Use persistent kernels to compute entire output
// tiles
DistributeStagesAcrossSMs // Use persistent kernels to compute whole and
// partial output tiles (stream-K)
} tiling_strategy = TilingStrategy::OneTilePerCTA;

//! Configure circular buffering loops
enum class BufferingLoopLevel {
CTATiles, // Warp groups cooperatively compute whole CTA tiles in each
// K iteration. If splitk_factor > 1, all math warp groups
// cooperate, but only for a portion of the whole K loop.
// splitk_factor > 1 requires a grid reduction to combine the
// contributions from each portion. Also called split-K.
WarpTiles // All warp tiles in a K loop for each math warp group are
// iterated over, then the next math warp group's warp tile is
// processed. Also called the ping-pong or alternating strategy.
} buffering_loop_level = BufferingLoopLevel::CTATiles;

//! Whether to do regular circular buffering (pipelined) or warp
//! specialization using an additional dma warp group
enum class CircularBufferingStrategy {
Pipelined,
WarpSpecialized
} circular_buffering_strategy = CircularBufferingStrategy::Pipelined;

//! Specify CTA rasterization order.
TileRasterizationOrder cta_order = TileRasterizationOrder::RowMajor;

@@ -240,8 +363,41 @@
<< ((cta_order == TileRasterizationOrder::RowMajor) ? "row-major"
: "column-major")
<< "\n"
<< "Grid swizzle factor: " << grid_swizzle_factor << "\n"
<< cluster_dims.toString() << "\n"
<< "Grid swizzle factor: " << grid_swizzle_factor << "\n";
ss << "Tiling strategy: ";
switch (tiling_strategy) {
case TilingStrategy::OneTilePerCTA:
ss << "OneTilePerCTA";
break;
case TilingStrategy::DistributeTilesAcrossSMs:
ss << "DistributeTilesAcrossSMs";
break;
case TilingStrategy::DistributeStagesAcrossSMs:
ss << "DistributeStagesAcrossSMs";
break;
}
ss << "\n";
ss << "Buffering loop level: ";
switch (buffering_loop_level) {
case BufferingLoopLevel::CTATiles:
ss << "CTATiles";
break;
case BufferingLoopLevel::WarpTiles:
ss << "WarpTiles";
break;
}
ss << "\n";
ss << "Circular buffering strategy: ";
switch (circular_buffering_strategy) {
case CircularBufferingStrategy::Pipelined:
ss << "Pipelined";
break;
case CircularBufferingStrategy::WarpSpecialized:
ss << "WarpSpecialized";
break;
}
ss << "\n";
ss << cluster_dims.toString() << "\n"
<< "Use shared memory epilogue: " << use_smem_epilogue << "\n"
<< "Promote re-use of prologue shared memory: "
<< promote_prologue_smem_reuse << "\n"
