extend circular buffer tests to test 1D TMA and fix index for 1D TMA (#3859)

**Two changes in this PR:**
(1) Circular buffer tests are extended to cover both
`LoadStoreOpType::CpAsyncBulkTensorTile` and `LoadStoreOpType::CpAsyncBulk`.
(2) 1D TMA now uses IdModel indexing, which avoids an offset bug when warp
specialization is used with prefetch.
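
For context on change (1), the tests are parameterized over the TMA load type and skip configurations where a 1D `cp.async.bulk` copy would read past the end of the source tensor. Below is a minimal sketch of that pattern, condensed from the diff; it assumes the nvFuser headers that define `CircularBufferType` and `LoadStoreOpType` are available, and the free-function form of `tma1dSrcAddressOverflow` (taking the dimensions and load type as arguments instead of fixture members) is an illustration, not the exact fixture code.

```cpp
#include <tuple>
// Sketch only: CircularBufferType and LoadStoreOpType are nvFuser types.

// The test fixture now carries the TMA load type as a sixth parameter.
using TmaCircularBufferingParams = std::tuple<
    int64_t, int64_t, int64_t, int64_t, CircularBufferType, LoadStoreOpType>;

// Per the PTX spec, cp.async.bulk requires [srcMem, srcMem + size - 1] to stay
// within the source memory space, so 1D TMA tests skip tile sizes that do not
// evenly divide the inner dimension.
bool tma1dSrcAddressOverflow(
    int64_t tensor_inner_dim,
    int64_t bulk_inner_dim,
    LoadStoreOpType tma_load_type) {
  return tensor_inner_dim % bulk_inner_dim != 0 &&
      tma_load_type == LoadStoreOpType::CpAsyncBulk;
}
```

Each test body then loads through shared memory with `tv0->cacheAfter(tma_load_type)` and calls `GTEST_SKIP()` when the guard above fires; the multi-dimensional tests (`MultiDim`, `Matmul`, `MatmulWithBroadcastedInput`) skip `CpAsyncBulk` entirely, since it only supports 1D TMA.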
liqiangxl authored Feb 15, 2025
1 parent 11cece1 commit 204d795
Showing 3 changed files with 97 additions and 28 deletions.
6 changes: 4 additions & 2 deletions csrc/device_lower/analysis/device_version.cpp
@@ -62,10 +62,12 @@ void MinimumDeviceVersion::handle(LoadStoreOp* ls_op) {
if (ls_op->opType() == LoadStoreOpType::CpAsync) {
ensureVersion(
{8, 0}, "LoadStoreOpType::CpAsync requires Ampere (8.0) or newer");
} else if (ls_op->opType() == LoadStoreOpType::CpAsyncBulkTensorTile) {
} else if (
ls_op->opType() == LoadStoreOpType::CpAsyncBulkTensorTile ||
ls_op->opType() == LoadStoreOpType::CpAsyncBulk) {
ensureVersion(
{9, 0},
"LoadStoreOpType::CpAsyncBulkTensorTile requires Hopper (9.0) or newer");
"LoadStoreOpType::CpAsyncBulk{TensorTile} requires Hopper (9.0) or newer");
}
}

9 changes: 5 additions & 4 deletions csrc/index_compute.cpp
@@ -2144,13 +2144,15 @@ kir::TensorIndex* Index::getProducerIndex(
Val* index = nullptr;
bool is_producer_tma_op = producer->definition() != nullptr &&
producer->definition()->isA<LoadStoreOp>() &&
producer->definition()->as<LoadStoreOp>()->opType() ==
LoadStoreOpType::CpAsyncBulkTensorTile;
ir_utils::isCpAsyncBulkLoad(producer->definition());
bool is_consumer_tma_op = consumer->definition() != nullptr &&
consumer->definition()->isA<LoadStoreOp>() &&
ir_utils::isCpAsyncBulkLoad(consumer->definition());

if (!ir_utils::hasRootToLoopLinearTransformations(producer) ||
(consumer->definition()->isA<MmaOp>() &&
isHopper(consumer->definition()->as<MmaOp>()->macro())) ||
is_producer_tma_op ||
is_producer_tma_op || is_consumer_tma_op ||
GpuLower::current()->idModelOptions().producerIndex()) {
index = GpuLower::current()->tensorIndexer().getLinearIndex(
producer, consumer->definition(), loops);
@@ -2678,7 +2680,6 @@ std::pair<Val*, Val*> Index::getCpAsyncBulkGmemIndex(

// 1D TMA without tensor map
if (ldst->opType() == LoadStoreOpType::CpAsyncBulk) {
NVF_ERROR(dim == 1L, "1D TMA but got more than one indices.")
if (is_load) {
std::stringstream ss;
ss << "Hopper::CpAsyncBulkG2SIndex";
110 changes: 88 additions & 22 deletions tests/cpp/test_circular_buffering.cpp
@@ -977,8 +977,13 @@ INSTANTIATE_TEST_SUITE_P(
StagesAndPrefetches(),
nonTMAName);

using TmaCircularBufferingParams =
std::tuple<int64_t, int64_t, int64_t, int64_t, CircularBufferType>;
using TmaCircularBufferingParams = std::tuple<
int64_t,
int64_t,
int64_t,
int64_t,
CircularBufferType,
LoadStoreOpType>;

class TmaCircularBufferingTest
: public NVFuserFixtureParamTest<TmaCircularBufferingParams> {
@@ -988,13 +993,15 @@ class TmaCircularBufferingTest
int64_t tensor_outer_dim = 1;
int64_t tensor_inner_dim = 1;
CircularBufferType circular_buffer_type;
LoadStoreOpType tma_load_type;

void SetUp() override {
number_of_stages = std::get<0>(GetParam());
prefetch_distance = std::get<1>(GetParam());
tensor_outer_dim = std::get<2>(GetParam());
tensor_inner_dim = std::get<3>(GetParam());
circular_buffer_type = std::get<4>(GetParam());
tma_load_type = std::get<5>(GetParam());

// NOTE: Multiple of 16 required for inner dimension
NVF_ERROR(tensor_inner_dim % 16 == 0);
@@ -1007,6 +1014,14 @@ class TmaCircularBufferingTest
.num_registers.has_value();
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk
// the memory range [srcMem, srcMem + size - 1] must not overflow the source
// memory space. Otherwise, the behavior is undefined.
bool tma1dSrcAddressOverflow(int64_t bulk_inner_dim) {
return tensor_inner_dim % bulk_inner_dim != 0 &&
tma_load_type == LoadStoreOpType::CpAsyncBulk;
}

template <typename data_type>
void compare(int64_t tensor_dim, at::Tensor result, at::Tensor reference) {
at::Tensor reference_cpu_data = reference.cpu();
@@ -1158,14 +1173,17 @@ TEST_P(TmaCircularBufferingTest, SingleDim) {
TensorView* tv1 = exp(tv0);
fusion->addOutput(tv1);

TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* tv2 = tv0->cacheAfter(tma_load_type);
tv2->setMemoryType(MemoryType::Shared);

TensorView* reference = tv1;

// Constants
constexpr size_t bulk_inner_dim = 32;

if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
return;
}
// [M] -> [M/bid, bid]
reference->split(-1, bulk_inner_dim);

@@ -1212,15 +1230,18 @@ TEST_P(TmaCircularBufferingTest, SingleDimUnroll) {
TensorView* tv1 = exp(tv0);
fusion->addOutput(tv1);

TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* tv2 = tv0->cacheAfter(tma_load_type);
tv2->setMemoryType(MemoryType::Shared);

TensorView* reference = tv1;

// Constants
constexpr size_t unroll_dim = 4;
constexpr size_t bulk_inner_dim = 32;

if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
return;
}
// [M] -> [M/bid, bid]
reference->split(-1, bulk_inner_dim);
// [M/bid, bid] -> [M/bid/unroll, unroll, bid]
@@ -1277,15 +1298,18 @@ TEST_P(TmaCircularBufferingTest, SingleDimUnswitch) {
TensorView* tv1 = exp(tv0);
fusion->addOutput(tv1);

TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* tv2 = tv0->cacheAfter(tma_load_type);
tv2->setMemoryType(MemoryType::Shared);

TensorView* reference = tv1;

// Constants
constexpr size_t unroll_dim = 4;
constexpr size_t bulk_inner_dim = 32;

if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
return;
}
// [M] -> [M/bid, bid]
reference->split(-1, bulk_inner_dim);
// [M/bid, bid] -> [M/bid/unroll, unroll, bid]
@@ -1333,6 +1357,11 @@ TEST_P(TmaCircularBufferingTest, MultiDim) {
return;
}

if (tma_load_type == LoadStoreOpType::CpAsyncBulk) {
GTEST_SKIP() << "LoadStoreOpType::CpAsyncBulk only supports 1D TMA";
return;
}

std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

@@ -1412,7 +1441,7 @@ TEST_P(TmaCircularBufferingTest, Pointwise) {
fusion->addOutput(tv2);

// Use TMA to load TV0 into shared memory
TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* tv3 = tv0->cacheAfter(tma_load_type);
tv3->setMemoryType(MemoryType::Shared);

TensorView* tv4 = tv1->cacheAfter();
Expand All @@ -1422,7 +1451,10 @@ TEST_P(TmaCircularBufferingTest, Pointwise) {

// Constants
constexpr int64_t bulk_inner_dim = 32;

if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
return;
}
// [M, N] -> [M, N/bid, bid]
reference->split(-1, bulk_inner_dim);

@@ -1488,7 +1520,7 @@ TEST_P(TmaCircularBufferingTest, PointwiseCpAsync) {
fusion->addOutput(tv2);

// Use TMA to load TV0 into shared memory
TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* tv3 = tv0->cacheAfter(tma_load_type);
tv3->setMemoryType(MemoryType::Shared);

// Load TV1 into shared memory
Expand All @@ -1499,7 +1531,10 @@ TEST_P(TmaCircularBufferingTest, PointwiseCpAsync) {

// Constants
constexpr int64_t bulk_inner_dim = 32;

if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
return;
}
// [M, N] -> [M, N/bid, bid]
reference->split(-1, bulk_inner_dim);

@@ -1555,14 +1590,19 @@ TEST_P(TmaCircularBufferingTest, InnerReduction) {
TensorView* tv1 = sum(tv0, {-1});
fusion->addOutput(tv1);

TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* tv2 = tv0->cacheAfter(tma_load_type);
tv2->setMemoryType(MemoryType::Shared);

TensorView* reference = tv1;

constexpr int64_t examples_per_cta = 4;
constexpr int64_t bulk_inner_dim = 256;

if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
return;
}

// [M, N] -> [M/epc, epc, N]
reference->split(0, examples_per_cta);
// [M/epc, epc, N] -> [M/epc, epc, N/bid, bid]
@@ -1620,12 +1660,16 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) {
TensorView* tv1 = sum(tv0, {0});
fusion->addOutput(tv1);

TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* tv2 = tv0->cacheAfter(tma_load_type);
tv2->setMemoryType(MemoryType::Shared);

TensorView* reference = tv1;

constexpr int64_t tile_size = 256;
if (tma1dSrcAddressOverflow(tile_size)) {
GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
return;
}

// [M, N] -> [M, N/bid, bid]
reference->split(1, tile_size);
@@ -1698,8 +1742,7 @@ TEST_P(TmaCircularBufferingTest, Persistent) {
fusion->addOutput(x_norm);

// Load input from global to shared memory
TensorView* x_cache_smem =
x->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
TensorView* x_cache_smem = x->cacheAfter(tma_load_type);
x_cache_smem->setMemoryType(MemoryType::Shared);

// Load input from shared memory to registers
Expand All @@ -1718,7 +1761,11 @@ TEST_P(TmaCircularBufferingTest, Persistent) {
constexpr int64_t vectorize = 4;
int64_t elem_per_compute_thread = tensor_inner_dim / width / vectorize;
constexpr int64_t examples_per_cta = 4;

constexpr int64_t tile_size = 256;
if (tma1dSrcAddressOverflow(tile_size)) {
GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
return;
}
// Since multi-dim CpAsyncBulk has a size limit of 256 per dimension,
// we require multiple TMA operations to load the entire example in shared
// memory for pointwise kernel.
Expand All @@ -1727,7 +1774,7 @@ TEST_P(TmaCircularBufferingTest, Persistent) {
// logical domain: [I1, I2]
x_cache_smem->split(0, examples_per_cta);
// split: [I0 / 4, 4, I2]
x_cache_smem->split(-1, 256);
x_cache_smem->split(-1, tile_size);
// split: [I0/4, 4, I2/256, 256]

// Schedule reference_tv
Expand Down Expand Up @@ -1805,6 +1852,11 @@ TEST_P(TmaCircularBufferingTest, Matmul) {
return;
}

if (tma_load_type == LoadStoreOpType::CpAsyncBulk) {
GTEST_SKIP() << "LoadStoreOpType::CpAsyncBulk only supports 1D TMA";
return;
}

std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

@@ -1930,6 +1982,11 @@ TEST_P(TmaCircularBufferingTest, MatmulWithBroadcastedInput) {
return;
}

if (tma_load_type == LoadStoreOpType::CpAsyncBulk) {
GTEST_SKIP() << "LoadStoreOpType::CpAsyncBulk only supports 1D TMA";
return;
}

std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

@@ -2055,14 +2112,23 @@ auto tmaCircularBufferingParams() {
WarpSpecialized(ParallelType::TIDy),
WarpSpecialized(ParallelType::TIDx, std::make_pair(40, 240)),
WarpSpecialized(ParallelType::TIDy, std::make_pair(40, 240))};
const std::vector<LoadStoreOpType> tma_types{
LoadStoreOpType::CpAsyncBulk, LoadStoreOpType::CpAsyncBulkTensorTile};
std::vector<TmaCircularBufferingParams> values;
for (int64_t i : {2, 4}) {
for (int64_t j : c10::irange(-i, i)) {
for (int64_t m : {128, 500, 1024}) {
for (int64_t n : {128, 1024}) {
values.emplace_back(
i, j, m, n, all_types[lcg_parkmiller % all_types.size()]);
lcg_parkmiller = (uint64_t)lcg_parkmiller * 48271 % 0x7fffffff;
for (auto tma_load_type : tma_types) {
values.emplace_back(
i,
j,
m,
n,
all_types[lcg_parkmiller % all_types.size()],
tma_load_type);
lcg_parkmiller = (uint64_t)lcg_parkmiller * 48271 % 0x7fffffff;
}
}
}
}
@@ -2084,7 +2150,7 @@ std::string tmaName(
<< prefetch_distance_str << "_M_"
<< std::to_string(std::get<2>(info.param)) << "_N_"
<< std::to_string(std::get<3>(info.param)) << "_"
<< std::get<4>(info.param);
<< std::get<4>(info.param) << "_" << std::get<5>(info.param);
return ss.str();
}

