Skip to content

Commit

Permalink
Use replayExprWithNewInput to replay permutes. (#1861)
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue authored Mar 1, 2024
1 parent bf7be62 commit ee63c8f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 17 deletions.
30 changes: 14 additions & 16 deletions csrc/preseg_passes/move_split_cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ir/internal_base_nodes.h>
#include <ir/utils.h>
#include <ops/alias.h>
#include <transform_replay.h>

namespace nvfuser::preseg_passes {

Expand Down Expand Up @@ -262,25 +263,22 @@ void CancelSplitCat::run() {
continue;
}

TensorView* merged_out = split_in;
Val* merged_out = split_in;
for (auto i = use_def_chain.rbegin(), end = use_def_chain.rend(); i != end;
i++) {
Expr* to_replay = *i;
// TODO(wujingyue): instead of an op-type dispatch, try a more general
// approach suggested by @jacobhinkle:
// https://github.com/NVIDIA/Fuser/pull/1782#discussion_r1496123087.
if (to_replay->isA<LoadStoreOp>()) {
auto* set_out = to_replay->output(0)->as<TensorView>();
std::vector<int64_t> permutation = *ir_utils::computePermutation(
set_out->getRootDomain(), set_out->getMaybeRFactorDomain());
merged_out = permute(merged_out, permutation);
continue;
}
NVF_ERROR(false, "Replay is not implemented for this Expr: ", to_replay);
Expr* merged = replayExprWithNewInput(*i, merged_out);
NVF_ERROR(
merged->outputs().size() == 1,
"Currently, we merge only unary ops, so it would be a programming "
"mistake when the number of outputs is ",
merged->outputs().size());
merged_out = merged->output(0);
}

ir_utils::replaceValInAllExprInputsAndFusionOutputs(
cat->output(0), merged_out);
// `cat->output(0)` may be a fusion output with allocation domain.
// Therefore, instead of replacing the output, we create a Set to preserve
// the output allocation domain.
IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set, cat->output(0), merged_out);
}
}

Expand Down
46 changes: 46 additions & 0 deletions test/test_move_split_cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

namespace nvfuser {

using testing::Contains;

using MoveSplitCatTest = NVFuserTest;

TEST_F(MoveSplitCatTest, Cancellable_Adjacent) {
Expand Down Expand Up @@ -116,6 +118,50 @@ TEST_F(MoveSplitCatTest, Cancellable_PermuteInBetween) {
EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor));
}

// Matches an expression that is a `LoadStoreOp` whose output is a
// `TensorView` for which `hasRFactor()` returns true. The tests in this file
// treat such an op as a permute. Anything else (a different expression type,
// or an output that is not a `TensorView`) does not match.
MATCHER(IsPermute, "") {
  auto* load_store = dynamic_cast<LoadStoreOp*>(arg);
  if (load_store == nullptr) {
    return false;
  }
  auto* out_tv = dynamic_cast<TensorView*>(load_store->out());
  return out_tv != nullptr && out_tv->hasRFactor();
}

// Builds slice -> permute -> cat over a {2, 3, 5} input, where both slices
// get the same permutation and the concatenated output is given an allocation
// domain that reorders its axes. Verifies that the results validate, that the
// fusion runs as a single executor whose kernel contains exactly one
// expression matching `IsPermute`, and that the output does not alias the
// input.
TEST_F(MoveSplitCatTest, Cancellable_IncompatibleAllocationOrder) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard guard(fusion.get());

  // Split the innermost axis (extent 5) into [0, 2) and [2, 5).
  TensorView* input = makeContigConcreteTensor({2, 3, 5});
  TensorView* left = slice(input, {0, 0, 0}, {2, 3, 2});
  TensorView* right = slice(input, {0, 0, 2}, {2, 3, 5});

  // Apply the same permutation to both halves before concatenating them.
  TensorView* left_permuted = permute(left, {1, 0, 2});
  TensorView* right_permuted = permute(right, {1, 0, 2});
  TensorView* concatenated = cat({left_permuted, right_permuted}, /*dim=*/-1);
  concatenated->setAllocationDomain(
      {concatenated->axis(2), concatenated->axis(0), concatenated->axis(1)},
      true);

  fusion->addInput(input);
  fusion->addOutput(concatenated);

  const auto tensor_options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn({2, 3, 5}, tensor_options);

  FusionExecutorCache executor_cache(std::move(fusion));
  auto aten_outputs = executor_cache.runFusionWithInputs({aten_input});
  testValidate(
      executor_cache.fusion(),
      aten_outputs,
      {aten_input},
      __LINE__,
      __FILE__);

  // Check the two permutes are merged to one.
  FusionKernelRuntime* kernel_runtime =
      executor_cache.getMostRecentKernelRuntime();
  ASSERT_EQ(kernel_runtime->executors().size(), 1)
      << "After merging, the whole fusion can be scheduled unsegmented.";
  kir::Kernel* compiled_kernel = kernel_runtime->executors().front().kernel();
  EXPECT_THAT(compiled_kernel->exprs(), Contains(IsPermute()).Times(1));

  // Due to the incompatible output allocation order, the output can't be an
  // alias.
  EXPECT_FALSE(aten_outputs[0].is_alias_of(aten_input));
}

TEST_F(MoveSplitCatTest, Cancellable_MultiplePermutesInBetween) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
Expand Down
2 changes: 1 addition & 1 deletion third_party/googletest
Submodule googletest updated 241 files

0 comments on commit ee63c8f

Please sign in to comment.