diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 39c4bf54321..8f682ce5188 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -692,7 +692,21 @@ torch::lazy::NodePtr Gelu(const torch::lazy::Value& input) {
   auto lower_fn = [](const XlaNode& node,
                      LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
-    return node.ReturnOp(BuildGelu(xla_input), loctx);
+
+    // Building composite computation.
+    const std::string name = std::string(GetCompositeNamespace()) + "gelu";
+    xla::XlaBuilder builder(name);
+    xla::XlaOp arg = xla::Parameter(
+        &builder, 0, ShapeHelper::ShapeOfXlaOp(xla_input), "arg");
+    xla::XlaOp ret = BuildGelu(arg);
+    xla::XlaComputation computation = ConsumeValue(builder.Build(ret));
+
+    // Building call to computation.
+    std::vector<xla::XlaOp> inputs{xla_input};
+    xla::XlaOp output =
+        xla::CompositeCall(loctx->builder(), computation, inputs, name);
+
+    return node.ReturnOp(output, loctx);
   };
   return GenericOp(torch::lazy::OpKind(at::aten::gelu), {input},
                    GetXlaShape(input), std::move(lower_fn));
@@ -704,7 +718,25 @@ torch::lazy::NodePtr GeluBackward(const torch::lazy::Value& grad_output,
                      LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
-    return node.ReturnOp(BuildGeluBackward(xla_grad_output, xla_input), loctx);
+
+    // Building composite computation.
+    const std::string name =
+        std::string(GetCompositeNamespace()) + "gelu_backward";
+    xla::XlaBuilder builder(name);
+    xla::XlaOp arg_grad_output =
+        xla::Parameter(&builder, 0, ShapeHelper::ShapeOfXlaOp(xla_grad_output),
+                       "arg_grad_output");
+    xla::XlaOp arg_input = xla::Parameter(
+        &builder, 1, ShapeHelper::ShapeOfXlaOp(xla_input), "arg_input");
+    xla::XlaOp ret = BuildGeluBackward(arg_grad_output, arg_input);
+    xla::XlaComputation computation = ConsumeValue(builder.Build(ret));
+
+    // Building call to computation.
+    std::vector<xla::XlaOp> inputs{xla_grad_output, xla_input};
+    xla::XlaOp output =
+        xla::CompositeCall(loctx->builder(), computation, inputs, name);
+
+    return node.ReturnOp(output, loctx);
   };
   return GenericOp(torch::lazy::OpKind(at::aten::gelu_backward),
                    {grad_output, input}, GetXlaShape(input),
diff --git a/torch_xla/csrc/ops/softmax.cpp b/torch_xla/csrc/ops/softmax.cpp
index 1ca0f2db0ad..87d0f55f876 100644
--- a/torch_xla/csrc/ops/softmax.cpp
+++ b/torch_xla/csrc/ops/softmax.cpp
@@ -4,6 +4,7 @@
 
 #include "torch_xla/csrc/convert_ops.h"
 #include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/softmax_builder.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
@@ -44,7 +45,22 @@ torch::lazy::NodePtr Softmax::Clone(torch::lazy::OpList operands) const {
 
 XlaOpVector Softmax::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOp(LowerSoftmax(input, dim_, dtype_), loctx);
+
+  // Build computation.
+  const std::string name = std::string(GetCompositeNamespace()) + "softmax";
+  const std::string attr = "{dim = " + std::to_string(dim_) + " : i64}";
+  xla::XlaBuilder builder(name);
+  xla::XlaOp arg =
+      xla::Parameter(&builder, 0, ShapeHelper::ShapeOfXlaOp(input), "arg");
+  xla::XlaOp ret = LowerSoftmax(arg, dim_, dtype_);
+  xla::XlaComputation computation = ConsumeValue(builder.Build(ret));
+
+  // Build call to computation.
+  std::vector<xla::XlaOp> inputs{input};
+  xla::XlaOp output =
+      xla::CompositeCall(loctx->builder(), computation, inputs, name, attr);
+
+  return ReturnOp(output, loctx);
 }
 
 std::string Softmax::ToString() const {
diff --git a/torch_xla/csrc/ops/softmax_backward.cpp b/torch_xla/csrc/ops/softmax_backward.cpp
index c34d22e957a..788b8818a83 100644
--- a/torch_xla/csrc/ops/softmax_backward.cpp
+++ b/torch_xla/csrc/ops/softmax_backward.cpp
@@ -3,7 +3,9 @@
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/ops/infer_output_shape.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/softmax_builder.h"
+#include "torch_xla/csrc/torch_util.h"
 
 namespace torch_xla {
 
@@ -23,8 +25,25 @@ torch::lazy::NodePtr SoftmaxBackward::Clone(
 XlaOpVector SoftmaxBackward::Lower(LoweringContext* loctx) const {
   xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
   xla::XlaOp output = loctx->GetOutputOp(operand(1));
+
+  // Build computation.
+  const std::string name =
+      std::string(GetCompositeNamespace()) + "softmax_backward";
+  const std::string attr = "{dim = " + std::to_string(dim_) + " : i64}";
+  xla::XlaBuilder builder(name);
+  xla::XlaOp arg_grad_output = xla::Parameter(
+      &builder, 0, ShapeHelper::ShapeOfXlaOp(grad_output), "arg_grad_output");
+  xla::XlaOp arg_output = xla::Parameter(
+      &builder, 1, ShapeHelper::ShapeOfXlaOp(grad_output), "arg_output");
+  xla::XlaOp ret = BuildSoftmaxGrad(/*grad_output=*/arg_grad_output,
+                                    /*output=*/arg_output, dim_);
+  xla::XlaComputation computation = ConsumeValue(builder.Build(ret));
+
+  // Build call to computation.
+  std::vector<xla::XlaOp> inputs{grad_output, output};
   xla::XlaOp grad_input =
-      BuildSoftmaxGrad(/*grad_output=*/grad_output, /*output=*/output, dim_);
+      xla::CompositeCall(loctx->builder(), computation, inputs, name, attr);
+
   return ReturnOp(grad_input, loctx);
 }
 
diff --git a/torch_xla/csrc/torch_util.cpp b/torch_xla/csrc/torch_util.cpp
index 1d5e3616643..a8a7a56442f 100644
--- a/torch_xla/csrc/torch_util.cpp
+++ b/torch_xla/csrc/torch_util.cpp
@@ -73,6 +73,11 @@ at::Tensor MaybeWrapTensorToFunctional(const at::Tensor& tensor) {
   return at::functionalization::impl::to_functional_tensor(tensor);
 }
 
+absl::string_view GetCompositeNamespace() {
+  static const char* kCompositePrefix = "ptxla.";
+  return absl::string_view(kCompositePrefix);
+}
+
 }  // namespace torch_xla
 
 namespace torch {
diff --git a/torch_xla/csrc/torch_util.h b/torch_xla/csrc/torch_util.h
index 82b73cadfd5..598cc7048ae 100644
--- a/torch_xla/csrc/torch_util.h
+++ b/torch_xla/csrc/torch_util.h
@@ -73,6 +73,9 @@ inline bool IsDefined(const std::optional<at::Tensor>& tensor) {
   return tensor.has_value() && tensor.value().defined();
 }
 
+// The namespace used to generate composite ops.
+absl::string_view GetCompositeNamespace();
+
 }  // namespace torch_xla
 
 namespace torch {
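Note on the pattern (not part of the patch): each lowering above repeats the same three steps. It builds the op's decomposition on a fresh xla::XlaBuilder whose parameters mirror the original operands, wraps the result in an xla::XlaComputation, and emits a single xla::CompositeCall into the graph builder under the "ptxla." namespace returned by GetCompositeNamespace(). As a rough sketch only, the single-operand case could be factored into a helper like the hypothetical WrapUnaryAsComposite below; the helper name, its placement, and the assumption that lowering_context.h and debug_macros.h provide the xla::XlaBuilder declarations and ConsumeValue (as the lowerings above rely on) are illustrative, not taken from the patch.

// Hypothetical sketch, not part of the patch: factors out the composite-call
// lowering pattern used by the gelu/softmax hunks above.
#include <string>
#include <vector>

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/runtime/debug_macros.h"  // assumed source of ConsumeValue
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/torch_util.h"

namespace torch_xla {

// Wraps a single-operand lowering function `body` into a composite call named
// "<namespace><op_name>", e.g. "ptxla.gelu".
xla::XlaOp WrapUnaryAsComposite(LoweringContext* loctx, xla::XlaOp input,
                                const std::string& op_name,
                                xla::XlaOp (*body)(xla::XlaOp)) {
  const std::string name = std::string(GetCompositeNamespace()) + op_name;

  // Build the decomposition on a fresh builder, with one parameter that
  // mirrors the original operand.
  xla::XlaBuilder builder(name);
  xla::XlaOp arg =
      xla::Parameter(&builder, 0, ShapeHelper::ShapeOfXlaOp(input), "arg");
  xla::XlaComputation computation = ConsumeValue(builder.Build(body(arg)));

  // Emit a single composite call into the graph builder, so the decomposition
  // stays grouped under `name` in the exported program.
  std::vector<xla::XlaOp> inputs{input};
  return xla::CompositeCall(loctx->builder(), computation, inputs, name);
}

}  // namespace torch_xla

With such a helper, the gelu lowering above would reduce to return node.ReturnOp(WrapUnaryAsComposite(loctx, xla_input, "gelu", BuildGelu), loctx); ops that carry attributes, such as softmax's dim, would instead pass the extra attribute string that xla::CompositeCall accepts, as the softmax hunks do.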