From a92792872ce51ff8c9ea022711fb8b254dcbb8b4 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Wed, 5 Mar 2025 12:24:35 -0800 Subject: [PATCH] Use placeholder tensor in scan (#8785) --- test/run_tests.sh | 1 + test/scan/test_scan.py | 33 +++++++++ test/test_placeholder.py | 73 +++++++++++++++++++ torch_xla/core/xla_builder.py | 33 +++++++++ torch_xla/csrc/init_python_bindings.cpp | 13 ++++ .../csrc/runtime/ifrt_computation_client.h | 8 +- .../csrc/runtime/pjrt_computation_client.h | 10 ++- torch_xla/experimental/scan.py | 4 +- 8 files changed, 168 insertions(+), 7 deletions(-) create mode 100644 test/test_placeholder.py diff --git a/test/run_tests.sh b/test/run_tests.sh index 46b729338b78..999837a33f7a 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -199,6 +199,7 @@ function run_xla_op_tests2 { run_test "$CDIR/scan/test_scan_layers.py" run_test "$CDIR/test_gru.py" run_test "$CDIR/test_as_stride_use_slice.py" + run_test "$CDIR/test_placeholder.py" run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py" run_test "$CDIR/test_autocast.py" run_test "$CDIR/eager/test_eager.py" diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py index d0bb6b08e82e..7d4fecb587c5 100644 --- a/test/scan/test_scan.py +++ b/test/scan/test_scan.py @@ -613,6 +613,39 @@ def compute_outputs_and_gradients(carry, x): self.compare_pytree(grad_init, expected_grads['init']) self.compare_pytree(grad_x, expected_grads['x']) + def test_scan_tracing_does_not_allocate_device_memory(self): + """ + When scan is tracing the function to obtain an HLO, it should not allocate + device memory. + """ + + def fn1(carry, x): + carry = torch.sin(carry) + x = torch.sin(x) + return carry, x + + def fn2(carry, x): + """ + Test cases where input/outputs are aliased. + """ + return carry, x + + for fn in [fn1, fn2]: + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) + torch_xla.sync(wait=True) + met.clear_all() + self.assertFalse(met.metric_data("TransferToDeviceTime")) + # Use `scan` to lower `fn` into HLO and run it. Doing so should not + # transfer anything from host to device since `init` and `xs` are + # already on the device. + # In practice, `carry` and `x` will be placeholder tensors in `fn`. + _ = scan(fn, init, xs) + torch_xla.sync(wait=True) + self.assertFalse(met.metric_data("TransferToDeviceTime")) + if __name__ == '__main__': test = unittest.main() diff --git a/test/test_placeholder.py b/test/test_placeholder.py new file mode 100644 index 000000000000..d5506bfacd55 --- /dev/null +++ b/test/test_placeholder.py @@ -0,0 +1,73 @@ +from absl.testing import absltest +import torch +import torch_xla +from torch_xla.core.xla_builder import create_placeholder_tensor +import torch_xla.debug.metrics as met +import re + + +class TestPlaceholder(absltest.TestCase): + + def setUp(self): + super().setUp() + torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True) + + def test_create_placeholder(self): + for shape, dtype in zip( + ((1, 2), (2, 3, 4), (3, 4, 5, 6)), + (torch.float32, torch.bfloat16, torch.int8), + ): + p = create_placeholder_tensor(shape, dtype) + assert isinstance(p, torch.Tensor) + assert p.device == torch_xla.device() + self.assertEqual(p.dtype, dtype) + self.assertEqual(p.shape, shape) + self.assertTrue(torch_xla._XLAC._is_placecholder(p)) + + def test_read_value_crashes(self): + p = create_placeholder_tensor((1,), torch.bfloat16) + with self.assertRaises(RuntimeError): + p.cpu() + + def test_trace_graph(self): + met.clear_all() + self.assertFalse(met.metric_data("TransferToDeviceTime")) + + p1 = create_placeholder_tensor((2, 3), torch.bfloat16) + a = torch.sin(p1) + p2 = create_placeholder_tensor((3, 4), torch.bfloat16) + # We use p1 once and p2 twice. But the graph should still only have two parameters. + b = (a @ p2) @ p2.T + ir: str = torch_xla._XLAC._get_xla_tensors_text([b]) + self.assertEqual(ir.count("xla::device_data()"), 2) + self.assertEqual(ir.count("bf16[3,4]{1,0} xla::device_data()"), 1) + self.assertEqual(ir.count("bf16[2,3]{1,0} xla::device_data()"), 1) + hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b]) + regex = r'\(p.*: bf16\[3,4\], p.*: bf16\[2,3\]\) -> \(bf16\[2,3\]\)' + assert re.search(regex, hlo) is not None + + # There should be no buffers transferred to the device during tracing + self.assertFalse(met.metric_data("TransferToDeviceTime")) + + def test_placeholder_handle_unique(self): + p1 = create_placeholder_tensor((1,), torch.bfloat16) + p2 = create_placeholder_tensor((1,), torch.bfloat16) + h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2]) + self.assertNotEqual(h1, h2) + + def test_cannot_get_handle_from_deleted_pjrt_buffer(self): + xla_device = torch_xla.device() + t0 = torch.randn(4, 2, 2).to(xla_device) + t1 = torch.randn(4, 2, 2).to(xla_device) + self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) + self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) + _ = t0 + t1 + torch_xla.sync(wait=True) + + self.assertTrue(torch_xla._XLAC._is_placecholder(t0)) + with self.assertRaises(RuntimeError, msg='is deleted'): + torch_xla._XLAC._get_tensors_handle([t0]) + + +if __name__ == "__main__": + absltest.main() diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py index f5fdac2b1265..19b97d6b8236 100644 --- a/torch_xla/core/xla_builder.py +++ b/torch_xla/core/xla_builder.py @@ -40,6 +40,24 @@ class Type: Type.PRED: torch.bool, } +_PT_XLA_TYPE_MAP = { + torch.float32: Type.F32, + torch.float64: Type.F64, + torch.bfloat16: Type.BF16, + torch.float16: Type.F16, + torch.uint8: Type.U8, + torch.int8: Type.S8, + torch.uint16: Type.U16, + torch.int16: Type.S16, + torch.uint32: Type.U32, + torch.int32: Type.S32, + torch.uint64: Type.U64, + torch.int64: Type.S64, + torch.complex64: Type.C64, + torch.complex128: Type.C128, + torch.bool: Type.PRED, +} + class Shape(object): """Wraps a core XLA shape object to provide a more friendly API.""" @@ -751,6 +769,10 @@ def map(cls, ops, computation, dimensions, static_operands=(), builder=None): def to_torch_type(cls, dtype): return _XLA_PT_TYPE_MAP[dtype] if dtype else torch.float32 + @classmethod + def from_torch_type(cls, dtype): + return _PT_XLA_TYPE_MAP[dtype] + def create_builder(name): return torch_xla._XLAC._xla_op_create_builder(name) @@ -846,3 +868,14 @@ def fn_flattened_inputs(*flattened): if isinstance(result, list) and len(result) == 1: return result[0] return result + + +def create_placeholder_tensor(shape, dtype): + """ + Creates a placeholder tensor that does not hold any device buffer. + This is primarily useful for staging out the HLO of a user computation. + Accessing the value of the tensor will panic. + """ + dtype = Op.from_torch_type(dtype) + shape = mkshape(dtype, shape) + return torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ea9e49f3f1a0..919b9415a01f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1866,6 +1866,7 @@ void InitXlaModuleBindings(py::module m) { }); m.def("_xla_optimization_barrier_", [](std::vector& inputs) { OptimizationBarrier_(inputs); }); + // TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting // decomposed when inside a custom op. This C++ op is an escape hatch to call // XLA einsum without going through torch.einsum. We should remove this @@ -1876,6 +1877,18 @@ void InitXlaModuleBindings(py::module m) { XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors); return bridge::AtenFromXlaTensor(output); }); + + // Creates a placeholder tensor that does not hold any device buffer. + // This is primarily useful for staging out the HLO of a user computation. + // Accessing the value of the tensor will panic. + m.def("_xla_create_placeholder_tensor", [](py::object py_shape) { + xla::Shape shape = op_builder::PyShapeToShape(py_shape); + auto xla_tensor = XLATensor::Create( + torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder( + bridge::GetCurrentDevice().toString(), std::move(shape))); + return bridge::AtenFromXlaTensor(xla_tensor); + }); + m.def("_xla_set_default_device", [](const std::string& device) { return SetCurrentThreadDevice(device); }); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index c83a705abbbd..73b8e21c9f06 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -203,9 +203,11 @@ class IfrtComputationClient : public ComputationClient { sharding_(sharding) {} Handle GetHandle() override { - XLA_CHECK(HasValue()) - << "buffer with shape " << shape().ToString() << " on device " - << device() << (buffer == nullptr ? " is null" : " is deleted"); + // If the data is a placeholder, use the address of this object as the + // handle. + if (buffer == nullptr) { + return reinterpret_cast(this); + } return reinterpret_cast(buffer.get()); }; void Assign(const torch::lazy::BackendData& data) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 6530ce768b4b..9791f32381b6 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -191,9 +191,15 @@ class PjRtComputationClient : public ComputationClient { buffer(buffer) {} Handle GetHandle() override { - XLA_CHECK(HasValue()) + // If the data is a placeholder, use the address of this object as the + // handle. + if (buffer == nullptr) { + return reinterpret_cast(this); + } + + XLA_CHECK(!buffer->IsDeleted()) << "buffer with shape " << shape().ToString() << " on device " - << device() << (buffer == nullptr ? " is null" : " is deleted"); + << device() << " is deleted"; return reinterpret_cast(buffer.get()); }; void Assign(const torch::lazy::BackendData& data) override; diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index a456f1f6db77..7f8967c9388c 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -558,8 +558,8 @@ def fn(carry, x): # Abstractly trace and lower `fn`. # Later we will include `fn_computation` within the while loop body. def make_fake_tensor(v: torch.Tensor) -> torch.Tensor: - return torch.empty( - v.size(), dtype=v.dtype).to(device).requires_grad_(v.requires_grad) + t = xb.create_placeholder_tensor(v.shape, v.dtype) + return t.requires_grad_(v.requires_grad) device = torch_xla.device() fake_carry = tree_map(make_fake_tensor, init)