Use placeholder tensor in scan (pytorch#8785)
tengyifei authored Mar 5, 2025
1 parent 4540d81 commit a927928
Showing 8 changed files with 168 additions and 7 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -199,6 +199,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/scan/test_scan_layers.py"
run_test "$CDIR/test_gru.py"
run_test "$CDIR/test_as_stride_use_slice.py"
run_test "$CDIR/test_placeholder.py"
run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
33 changes: 33 additions & 0 deletions test/scan/test_scan.py
@@ -613,6 +613,39 @@ def compute_outputs_and_gradients(carry, x):
    self.compare_pytree(grad_init, expected_grads['init'])
    self.compare_pytree(grad_x, expected_grads['x'])

  def test_scan_tracing_does_not_allocate_device_memory(self):
    """
    When scan is tracing the function to obtain an HLO, it should not allocate
    device memory.
    """

    def fn1(carry, x):
      carry = torch.sin(carry)
      x = torch.sin(x)
      return carry, x

    def fn2(carry, x):
      """
      Test cases where input/outputs are aliased.
      """
      return carry, x

    for fn in [fn1, fn2]:
      init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
      xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                        requires_grad=True,
                        device=self.device)
      torch_xla.sync(wait=True)
      met.clear_all()
      self.assertFalse(met.metric_data("TransferToDeviceTime"))
      # Use `scan` to lower `fn` into HLO and run it. Doing so should not
      # transfer anything from host to device since `init` and `xs` are
      # already on the device.
      # In practice, `carry` and `x` will be placeholder tensors in `fn`.
      _ = scan(fn, init, xs)
      torch_xla.sync(wait=True)
      self.assertFalse(met.metric_data("TransferToDeviceTime"))


if __name__ == '__main__':
  test = unittest.main()
73 changes: 73 additions & 0 deletions test/test_placeholder.py
@@ -0,0 +1,73 @@
from absl.testing import absltest
import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor
import torch_xla.debug.metrics as met
import re


class TestPlaceholder(absltest.TestCase):

  def setUp(self):
    super().setUp()
    torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True)

  def test_create_placeholder(self):
    for shape, dtype in zip(
        ((1, 2), (2, 3, 4), (3, 4, 5, 6)),
        (torch.float32, torch.bfloat16, torch.int8),
    ):
      p = create_placeholder_tensor(shape, dtype)
      assert isinstance(p, torch.Tensor)
      assert p.device == torch_xla.device()
      self.assertEqual(p.dtype, dtype)
      self.assertEqual(p.shape, shape)
      self.assertTrue(torch_xla._XLAC._is_placecholder(p))

  def test_read_value_crashes(self):
    p = create_placeholder_tensor((1,), torch.bfloat16)
    with self.assertRaises(RuntimeError):
      p.cpu()

  def test_trace_graph(self):
    met.clear_all()
    self.assertFalse(met.metric_data("TransferToDeviceTime"))

    p1 = create_placeholder_tensor((2, 3), torch.bfloat16)
    a = torch.sin(p1)
    p2 = create_placeholder_tensor((3, 4), torch.bfloat16)
    # We use p1 once and p2 twice. But the graph should still only have two
    # parameters.
    b = (a @ p2) @ p2.T
    ir: str = torch_xla._XLAC._get_xla_tensors_text([b])
    self.assertEqual(ir.count("xla::device_data()"), 2)
    self.assertEqual(ir.count("bf16[3,4]{1,0} xla::device_data()"), 1)
    self.assertEqual(ir.count("bf16[2,3]{1,0} xla::device_data()"), 1)
    hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b])
    regex = r'\(p.*: bf16\[3,4\], p.*: bf16\[2,3\]\) -> \(bf16\[2,3\]\)'
    assert re.search(regex, hlo) is not None

    # There should be no buffers transferred to the device during tracing.
    self.assertFalse(met.metric_data("TransferToDeviceTime"))

  def test_placeholder_handle_unique(self):
    p1 = create_placeholder_tensor((1,), torch.bfloat16)
    p2 = create_placeholder_tensor((1,), torch.bfloat16)
    h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
    self.assertNotEqual(h1, h2)

  def test_cannot_get_handle_from_deleted_pjrt_buffer(self):
    xla_device = torch_xla.device()
    t0 = torch.randn(4, 2, 2).to(xla_device)
    t1 = torch.randn(4, 2, 2).to(xla_device)
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
    _ = t0 + t1
    torch_xla.sync(wait=True)

    self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
    with self.assertRaises(RuntimeError, msg='is deleted'):
      torch_xla._XLAC._get_tensors_handle([t0])


if __name__ == "__main__":
  absltest.main()
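The last test relies on buffer donation: once a donated input has been consumed by a synced computation, the tensor is left holding a placeholder, so its handle (and value) can no longer be read. A minimal standalone sketch of that lifecycle, assuming a single XLA device and using only the bindings exercised above:

import torch
import torch_xla

# Opt in to aliasing donated buffers with computation outputs.
torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True)

t0 = torch.randn(4, 2, 2).to(torch_xla.device())
t1 = torch.randn(4, 2, 2).to(torch_xla.device())
torch_xla._XLAC._set_buffer_donation(t0, True)
_ = t0 + t1
torch_xla.sync(wait=True)

# t0's buffer was donated into the computation, so the tensor now behaves
# like a placeholder: its shape and dtype are intact, but its value is gone.
assert torch_xla._XLAC._is_placecholder(t0)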
33 changes: 33 additions & 0 deletions torch_xla/core/xla_builder.py
@@ -40,6 +40,24 @@ class Type:
    Type.PRED: torch.bool,
}

_PT_XLA_TYPE_MAP = {
    torch.float32: Type.F32,
    torch.float64: Type.F64,
    torch.bfloat16: Type.BF16,
    torch.float16: Type.F16,
    torch.uint8: Type.U8,
    torch.int8: Type.S8,
    torch.uint16: Type.U16,
    torch.int16: Type.S16,
    torch.uint32: Type.U32,
    torch.int32: Type.S32,
    torch.uint64: Type.U64,
    torch.int64: Type.S64,
    torch.complex64: Type.C64,
    torch.complex128: Type.C128,
    torch.bool: Type.PRED,
}


class Shape(object):
"""Wraps a core XLA shape object to provide a more friendly API."""
@@ -751,6 +769,10 @@ def map(cls, ops, computation, dimensions, static_operands=(), builder=None):
  def to_torch_type(cls, dtype):
    return _XLA_PT_TYPE_MAP[dtype] if dtype else torch.float32

  @classmethod
  def from_torch_type(cls, dtype):
    return _PT_XLA_TYPE_MAP[dtype]


def create_builder(name):
  return torch_xla._XLAC._xla_op_create_builder(name)
@@ -846,3 +868,14 @@ def fn_flattened_inputs(*flattened):
  if isinstance(result, list) and len(result) == 1:
    return result[0]
  return result


def create_placeholder_tensor(shape, dtype):
  """
  Creates a placeholder tensor that does not hold any device buffer.
  This is primarily useful for staging out the HLO of a user computation.
  Accessing the value of the tensor will panic.
  """
  dtype = Op.from_torch_type(dtype)
  shape = mkshape(dtype, shape)
  return torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape)
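As a usage sketch (mirroring test_trace_graph above): a placeholder participates in lazy tracing like any other XLA tensor, but it owns no device buffer, so materializing its value raises a RuntimeError.

import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor

# Trace against a placeholder; this builds lazy IR without allocating or
# transferring any device buffer.
p = create_placeholder_tensor((2, 3), torch.bfloat16)
out = torch.sin(p)
print(torch_xla._XLAC._get_xla_tensors_text([out]))  # IR with xla::device_data()
# p.cpu() would raise a RuntimeError: there is no buffer to read back.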
13 changes: 13 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1866,6 +1866,7 @@ void InitXlaModuleBindings(py::module m) {
  });
  m.def("_xla_optimization_barrier_",
        [](std::vector<at::Tensor>& inputs) { OptimizationBarrier_(inputs); });

  // TODO(https://github.com/pytorch/xla/issues/8713): torch.einsum is getting
  // decomposed when inside a custom op. This C++ op is an escape hatch to call
  // XLA einsum without going through torch.einsum. We should remove this
@@ -1876,6 +1877,18 @@
        XLATensorPtr output = tensor_methods::einsum(equation, xla_tensors);
        return bridge::AtenFromXlaTensor(output);
      });

  // Creates a placeholder tensor that does not hold any device buffer.
  // This is primarily useful for staging out the HLO of a user computation.
  // Accessing the value of the tensor will panic.
  m.def("_xla_create_placeholder_tensor", [](py::object py_shape) {
    xla::Shape shape = op_builder::PyShapeToShape(py_shape);
    auto xla_tensor = XLATensor::Create(
        torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
            bridge::GetCurrentDevice().toString(), std::move(shape)));
    return bridge::AtenFromXlaTensor(xla_tensor);
  });

  m.def("_xla_set_default_device", [](const std::string& device) {
    return SetCurrentThreadDevice(device);
  });
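For reference, the binding can also be driven directly from Python. This sketch uses the mkshape helper already exported by torch_xla.core.xla_builder; create_placeholder_tensor above is the intended entry point and does exactly this internally.

import torch_xla
import torch_xla.core.xla_builder as xb

# Build an f32[2,3] XLA shape and pass its underlying shape object to the
# binding, as create_placeholder_tensor does.
shape = xb.mkshape(xb.Type.F32, (2, 3))
p = torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape)
assert torch_xla._XLAC._is_placecholder(p)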
8 changes: 5 additions & 3 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -203,9 +203,11 @@ class IfrtComputationClient : public ComputationClient {
           sharding_(sharding) {}

     Handle GetHandle() override {
-      XLA_CHECK(HasValue())
-          << "buffer with shape " << shape().ToString() << " on device "
-          << device() << (buffer == nullptr ? " is null" : " is deleted");
+      // If the data is a placeholder, use the address of this object as the
+      // handle.
+      if (buffer == nullptr) {
+        return reinterpret_cast<std::uintptr_t>(this);
+      }
       return reinterpret_cast<std::uintptr_t>(buffer.get());
     };
     void Assign(const torch::lazy::BackendData& data) override;
10 changes: 8 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -191,9 +191,15 @@ class PjRtComputationClient : public ComputationClient {
           buffer(buffer) {}

     Handle GetHandle() override {
-      XLA_CHECK(HasValue())
+      // If the data is a placeholder, use the address of this object as the
+      // handle.
+      if (buffer == nullptr) {
+        return reinterpret_cast<std::uintptr_t>(this);
+      }
+
+      XLA_CHECK(!buffer->IsDeleted())
           << "buffer with shape " << shape().ToString() << " on device "
-          << device() << (buffer == nullptr ? " is null" : " is deleted");
+          << device() << " is deleted";
       return reinterpret_cast<std::uintptr_t>(buffer.get());
     };
     void Assign(const torch::lazy::BackendData& data) override;
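Since a placeholder owns no PjRt buffer, GetHandle falls back to the address of the Data object itself, which keeps handles unique across live placeholders without requiring a buffer. The property is observable from Python, as test_placeholder_handle_unique checks:

import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor

# Two placeholders of identical shape and dtype still get distinct handles,
# because each handle is derived from a distinct backing Data object.
p1 = create_placeholder_tensor((1,), torch.bfloat16)
p2 = create_placeholder_tensor((1,), torch.bfloat16)
h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
assert h1 != h2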
4 changes: 2 additions & 2 deletions torch_xla/experimental/scan.py
@@ -558,8 +558,8 @@ def fn(carry, x):
   # Abstractly trace and lower `fn`.
   # Later we will include `fn_computation` within the while loop body.
   def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
-    return torch.empty(
-        v.size(), dtype=v.dtype).to(device).requires_grad_(v.requires_grad)
+    t = xb.create_placeholder_tensor(v.shape, v.dtype)
+    return t.requires_grad_(v.requires_grad)

   device = torch_xla.device()
   fake_carry = tree_map(make_fake_tensor, init)
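The user-visible effect of this change is what test_scan_tracing_does_not_allocate_device_memory asserts: lowering the scan body no longer uploads the fake carry/x tensors. A minimal sketch, assuming the inputs already live on the XLA device:

import torch
import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.experimental.scan import scan

def step(carry, x):
  return torch.sin(carry), torch.sin(x)

init = torch.zeros(2, device=torch_xla.device())
xs = torch.ones(3, 2, device=torch_xla.device())
torch_xla.sync(wait=True)
met.clear_all()

# `step` is traced against placeholder tensors, so lowering it to HLO
# performs no host-to-device transfer.
final_carry, ys = scan(step, init, xs)
torch_xla.sync(wait=True)
assert not met.metric_data("TransferToDeviceTime")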
