[XLA:GPU][NFC] Beautify code related to Triton fusions
This is a small beautification following the earlier "Split GemmRewriterTriton into 4 parts" effort.

PiperOrigin-RevId: 586278961
tdanyluk authored and tensorflower-gardener committed Nov 29, 2023
1 parent dedcf06 commit 7a7a4c1
Showing 5 changed files with 51 additions and 32 deletions.
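The cleanup is mechanical: each .cc file pulls the triton_fusion:: names it uses into its file-local anonymous namespace with using-declarations, so call sites no longer spell out the qualification. A minimal, self-contained sketch of the pattern (with hypothetical placeholder names, not the real XLA declarations):

#include <iostream>

// Stand-in for the library namespace (the real one is xla::gpu::triton_fusion).
namespace triton_fusion {
struct DimOrdersAndReqs {};  // Hypothetical placeholder type.
inline DimOrdersAndReqs GetPropagatedDimOrders() { return {}; }
}  // namespace triton_fusion

namespace {

// File-local using-declarations: call sites below drop the triton_fusion::
// prefix, and the shortened names stay private to this translation unit.
using triton_fusion::DimOrdersAndReqs;
using triton_fusion::GetPropagatedDimOrders;

void Demo() {
  DimOrdersAndReqs result = GetPropagatedDimOrders();  // No qualification needed.
  (void)result;
  std::cout << "using-declaration cleanup sketch\n";
}

}  // namespace

int main() { Demo(); }

Keeping the using-declarations inside the unnamed namespace limits their scope to one translation unit, which is why the same block is repeated per .cc file below rather than placed in a shared header.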
38 changes: 21 additions & 17 deletions third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc
@@ -60,6 +60,12 @@ namespace gpu {
 
 namespace {
 
+using triton_fusion::DimOrdersAndReqs;
+using triton_fusion::DimOrdersAndReqsOrError;
+using triton_fusion::FusionContext;
+using triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible;
+using triton_fusion::TransformDirection;
+
 using OldToNewHloMap =
     absl::flat_hash_map<const HloInstruction*, HloInstruction*>;
 
@@ -170,14 +176,13 @@ void TryToFuseWithInputsRecursively(HloInstruction& root,
       continue;
     }
     num_requeued = 0;
-    const triton_fusion::DimOrdersAndReqsOrError result =
+    const DimOrdersAndReqsOrError result =
         GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
-            *hlo, triton_fusion::TransformDirection::kOutputToInput,
+            *hlo, TransformDirection::kOutputToInput,
             /*src_operand_index=*/std::nullopt, context.dim_orders().at(hlo),
             gpu_version, context.hero_properties());
-    if (!std::holds_alternative<triton_fusion::DimOrdersAndReqs>(result) ||
-        !context.CombineDimOrdersAndReqs(
-            std::get<triton_fusion::DimOrdersAndReqs>(result))) {
+    if (!std::holds_alternative<DimOrdersAndReqs>(result) ||
+        !context.CombineDimOrdersAndReqs(std::get<DimOrdersAndReqs>(result))) {
       continue;
     }
     if (hlo->opcode() != HloOpcode::kParameter) {
@@ -236,12 +241,12 @@ StatusOr<FusionDecision> FuseDot(HloInstruction& dot,
   // differently shaped tiles but may go through same HLO graph nodes.
   // Direct dot inputs have well defined dimension orders.
 
-  auto fuse_inputs = [&](int operand_number, OldToNewHloMap& old_to_new_map)
-      -> StatusOr<triton_fusion::FusionContext> {
+  auto fuse_inputs =
+      [&](int operand_number,
+          OldToNewHloMap& old_to_new_map) -> StatusOr<FusionContext> {
     const int operand_count_before = fusion_inputs.size();
     // Direct dot inputs have well defined dimension orders.
-    auto context =
-        triton_fusion::FusionContext::FromDotOperand(dot, operand_number);
+    auto context = FusionContext::FromDotOperand(dot, operand_number);
     TryToFuseWithInputsRecursively(*dot.mutable_operand(operand_number),
                                    gpu_version, context, old_to_new_map,
                                    fusion_inputs, builder);
@@ -255,7 +260,7 @@ StatusOr<FusionDecision> FuseDot(HloInstruction& dot,
 
   // Original instruction -> fused one. Separate for each scope.
   OldToNewHloMap lhs_old_to_new_map;
-  TF_ASSIGN_OR_RETURN(const triton_fusion::FusionContext lhs_context,
+  TF_ASSIGN_OR_RETURN(const FusionContext lhs_context,
                       fuse_inputs(0, lhs_old_to_new_map));
 
   OldToNewHloMap rhs_old_to_new_map;
@@ -272,7 +277,7 @@ StatusOr<FusionDecision> FuseDot(HloInstruction& dot,
   // Fusion at dot's output.
 
   // These describe _outputs_ of corresponding HLOs.
-  auto context = triton_fusion::FusionContext::FromDotOutput(
+  auto context = FusionContext::FromDotOutput(
       dot, /*split_k=*/1, lhs_context.splittable_dimension_major_part_size());
   HloInstruction* fusion_output = &dot;
   bool output_changed = true;
@@ -285,15 +290,14 @@ StatusOr<FusionDecision> FuseDot(HloInstruction& dot,
     if (!IsDistributiveOverAddition(*user)) {
       break;
     }
-    triton_fusion::DimOrdersAndReqsOrError result =
-        triton_fusion::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
-            *user, triton_fusion::TransformDirection::kInputToOutput,
+    DimOrdersAndReqsOrError result =
+        GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
+            *user, TransformDirection::kInputToOutput,
             user->operand_index(fusion_output),
             context.dim_orders().at(fusion_output), gpu_version,
             context.hero_properties());
-    if (!std::holds_alternative<triton_fusion::DimOrdersAndReqs>(result) ||
-        !context.CombineDimOrdersAndReqs(
-            std::get<triton_fusion::DimOrdersAndReqs>(result))) {
+    if (!std::holds_alternative<DimOrdersAndReqs>(result) ||
+        !context.CombineDimOrdersAndReqs(std::get<DimOrdersAndReqs>(result))) {
       break;
     }
     for (HloInstruction* operand : user->operands()) {
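Both the rewriter above and the analysis below consume the propagation helpers through a variant-based "value or error" result (DimOrdersAndReqsOrError), probed with std::holds_alternative and unpacked with std::get. A minimal sketch of that pattern, using hypothetical stand-in types rather than the real XLA definitions:

#include <iostream>
#include <string>
#include <variant>

// Hypothetical stand-ins: the real DimOrdersAndReqs and FusionDecision carry
// far more state (dimension orders, split requirements, an explanation).
struct DimOrdersAndReqs {
  int num_dims = 0;
};
using FusionDecision = std::string;  // Plays the role of the "why not" alternative.
using DimOrdersAndReqsOrError = std::variant<DimOrdersAndReqs, FusionDecision>;

// Returns either a propagated result or a decision explaining why fusion stops.
DimOrdersAndReqsOrError Propagate(bool profitable) {
  if (!profitable) {
    return FusionDecision("not profitable to fuse");
  }
  return DimOrdersAndReqs{/*num_dims=*/2};
}

int main() {
  for (bool profitable : {true, false}) {
    DimOrdersAndReqsOrError result = Propagate(profitable);
    // Same shape as the checks in the diff: bail out unless the variant holds a value.
    if (!std::holds_alternative<DimOrdersAndReqs>(result)) {
      std::cout << "stop: " << std::get<FusionDecision>(result) << "\n";
      continue;
    }
    std::cout << "fused, dims = " << std::get<DimOrdersAndReqs>(result).num_dims << "\n";
  }
}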
41 changes: 26 additions & 15 deletions third_party/xla/xla/service/gpu/triton_fusion_analysis.cc
@@ -42,6 +42,17 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+namespace {
+
+using triton_fusion::DimOrdersAndReqs;
+using triton_fusion::DimOrdersAndReqsOrError;
+using triton_fusion::FusionContext;
+using triton_fusion::GetPropagatedDimOrdersAndRequirements;
+using triton_fusion::kNoSplitRequirement;
+using triton_fusion::TransformDirection;
+
+}  // namespace
+
 namespace triton_fusion {
 
 /*static*/ FusionContext FusionContext::FromDotOperand(
@@ -103,6 +114,7 @@ namespace triton_fusion {
 }
 
 namespace {
+
 // Tells how many new parameters does a fusion gain by fusing the operation as
 // an input.
 int64_t NumAddedParameters(const HloInstruction& hlo) {
@@ -114,6 +126,7 @@ int64_t NumAddedParameters(const HloInstruction& hlo) {
   // All other instructions add all own inputs and remove own single output.
   return hlo.operand_count() - 1;
 }
+
 }  // namespace
 
 bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) {
@@ -127,7 +140,7 @@ bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) {
   }
 
   RequirementsOrError requirements_or_error =
-      triton_fusion::CombineRequirements(requirements_, update.requirements);
+      CombineRequirements(requirements_, update.requirements);
   if (std::holds_alternative<FusionDecision>(requirements_or_error)) {
     return false;
   }
@@ -197,7 +210,7 @@ StatusOr<TritonFusionAnalysis> TritonFusionAnalysis::Execute(
 
 Status TritonFusionAnalysis::ExecuteForSoftmaxFusion(
     const HloInstruction& root) {
-  auto context = triton_fusion::FusionContext::FromSoftmaxRoot(root);
+  auto context = FusionContext::FromSoftmaxRoot(root);
   // Softmax fusion uses one tiled scope.
   TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters(
       root, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT]));
@@ -208,11 +221,10 @@ Status TritonFusionAnalysis::ExecuteForSoftmaxFusion(
 
 Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot,
                                                  const int split_k) {
-  int64_t lhs_nc_split_major_part_size = triton_fusion::kNoSplitRequirement;
+  int64_t lhs_nc_split_major_part_size = kNoSplitRequirement;
   for (const Scope scope : {Scope::LHS, Scope::RHS}) {
     const int operand_number = static_cast<int>(scope);
-    auto context = triton_fusion::FusionContext::FromDotOperand(
-        dot, operand_number, split_k);
+    auto context = FusionContext::FromDotOperand(dot, operand_number, split_k);
     TF_RETURN_IF_ERROR(context.PropagateDimensionOrdersToParameters(
         *dot.operand(operand_number), parameters_[scope], iter_specs_[scope]));
     if (scope == Scope::LHS) {
@@ -221,24 +233,21 @@ Status TritonFusionAnalysis::ExecuteForDotFusion(const HloInstruction& dot,
     }
   }
 
-  auto context = triton_fusion::FusionContext::FromDotOutput(
-      dot, split_k, lhs_nc_split_major_part_size);
+  auto context =
+      FusionContext::FromDotOutput(dot, split_k, lhs_nc_split_major_part_size);
   const HloInstruction* output = &dot;
   // Currently supported is one fusion output and one path from dot to it.
   // Propagate dimension order from dot to root.
   while (!output->IsRoot()) {
     TF_RET_CHECK(output->user_count() == 1);
     const HloInstruction* input = output;
     output = output->users()[0];
-    triton_fusion::DimOrdersAndReqsOrError result =
-        triton_fusion::GetPropagatedDimOrdersAndRequirements(
-            *output, context.dim_orders().at(input),
-            triton_fusion::TransformDirection::kInputToOutput,
-            context.hero_properties());
+    DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements(
+        *output, context.dim_orders().at(input),
+        TransformDirection::kInputToOutput, context.hero_properties());
+    TF_RET_CHECK(std::holds_alternative<DimOrdersAndReqs>(result));
     TF_RET_CHECK(
-        std::holds_alternative<triton_fusion::DimOrdersAndReqs>(result));
-    TF_RET_CHECK(context.CombineDimOrdersAndReqs(
-        std::get<triton_fusion::DimOrdersAndReqs>(result)));
+        context.CombineDimOrdersAndReqs(std::get<DimOrdersAndReqs>(result)));
   }
   TF_RET_CHECK(
       iter_specs_[Scope::OUTPUT]
@@ -267,6 +276,7 @@ const TensorIterationSpec::DimIterationSpec* TritonFusionAnalysis::IterSpec(
 }
 
 namespace {
+
 std::string IterationSpecByInstructionMapToString(
     const TritonFusionAnalysis::IterationSpecByInstructionMap& m) {
   return absl::StrCat("IterSpec{",
@@ -288,6 +298,7 @@ std::string ScopeToString(TritonFusionAnalysis::Scope s) {
       return "OUTPUT";
   }
 }
+
 }  // namespace
 
 std::string TritonFusionAnalysis::ToString() const {
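ExecuteForDotFusion above propagates the dimension order from the dot to the fusion root by walking a chain of single users (per its "one fusion output and one path from dot to it" comment). A toy sketch of that walk, with a hypothetical Node type standing in for HloInstruction:

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in: the real HloInstruction exposes users() and IsRoot()
// with the same flavor but far richer semantics.
struct Node {
  std::string name;
  std::vector<Node*> users;
  bool IsRoot() const { return users.empty(); }
};

int main() {
  // dot -> add -> convert (root): exactly one path from the dot to the root.
  Node convert{"convert", {}};
  Node add{"add", {&convert}};
  Node dot{"dot", {&add}};

  const Node* output = &dot;
  while (!output->IsRoot()) {
    if (output->users.size() != 1) {  // Analogue of TF_RET_CHECK(user_count() == 1).
      std::cerr << "unsupported: more than one path from dot to root\n";
      return 1;
    }
    const Node* input = output;
    output = output->users[0];
    // Here the real analysis propagates input's dimension order to output.
    std::cout << input->name << " -> " << output->name << "\n";
  }
  return 0;
}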
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/triton_support.h
@@ -24,6 +24,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/xla_data.pb.h"
+
 namespace xla {
 namespace gpu {
 
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/triton_tiling_propagation.cc
@@ -215,6 +215,7 @@ TensorIterationSpec DimensionOrder::ToTensorIterationSpec() const {
 }
 
 namespace {
+
 // Logical index of a dimension in `shape` labeled with `label` in the
 // `dim_order` describing the shape.
 std::optional<int> LogicalIndexOfLabeledDimension(
@@ -261,6 +262,7 @@ RequirementsOrError CombineSoftmaxRequirements(SoftmaxRequirements a,
   // SoftmaxRequirements is an empty class for now.
   return a;
 }
+
 }  // namespace
 
 RequirementsOrError CombineRequirements(Requirements a,
1 change: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/instruction_fusion.h"
 #include "xla/stream_executor/device_description.h"
+
 namespace xla {
 namespace gpu {
 
