Cherrypick #8524 into r2.6 (#8572)
tengyifei authored Jan 15, 2025
1 parent c2e6fd3 commit c7902e3
Showing 2 changed files with 26 additions and 1 deletion.
24 changes: 24 additions & 0 deletions test/scan/test_scan.py
@@ -9,6 +9,7 @@
from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree

import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.experimental.scan import scan, value_and_grad_partitioned, tree_flatten_none

parent_folder = os.path.dirname(os.path.dirname(__file__))
@@ -178,6 +179,29 @@ def fn(carry, x):
device=self.device)
self.run_test(fn, init, xs)

def test_scan_create_tensors_no_transfers_from_device(self):
"""Test that scanning over a function that internally creates tensors
will not transfer those tensor to host, which can be potentially expensive.
"""

def fn(carry, x):
a = torch.tensor([1.0, 2.0], device=self.device)
b = torch.tensor([3.0, 4.0], device=self.device)
return carry + a, x + b

init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
met.clear_all()
self.assertFalse(met.metric_data("TransferFromDeviceTime"))
# Use `scan` to lower `fn` into HLO and run it. Doing so should not
# transfer anything from the device to the host; `carry` and `ys` still
# reference data on the device.
carry, ys = scan(fn, init, xs)
torch_xla.sync(wait=True)
self.assertFalse(met.metric_data("TransferFromDeviceTime"))
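
For reference, a minimal sketch of the metrics-checking pattern this test relies on, usable outside the test harness. It assumes an XLA device is available; `met.metric_data` returns None until the named metric has been recorded at least once, and calling `.cpu()` forces the device-to-host copy that `TransferFromDeviceTime` tracks.

import torch
import torch_xla
import torch_xla.debug.metrics as met

device = torch_xla.device()

met.clear_all()
a = torch.ones(2, 2, device=device)
b = a + a
torch_xla.sync(wait=True)  # execute the pending graph on the device

# Nothing has been copied back to the host yet, so the metric is unset.
assert met.metric_data("TransferFromDeviceTime") is None

b_host = b.cpu()  # explicit device-to-host transfer
assert met.metric_data("TransferFromDeviceTime") is not None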

def test_scan_internal_in_place_mutation(self):
"""
Test internal in-place mutations inside the `fn` to be scanned over.
3 changes: 2 additions & 1 deletion torch_xla/experimental/scan.py
@@ -521,7 +521,8 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
builder.add_param(val)

# Detect hoisted variables.
- hoisted_vars: Dict[int, torch.Tensor] = fn_ctx.parameter_id_tensor_mapping()
+ hoisted_vars: Dict[
+     int, torch.Tensor] = fn_ctx.device_parameter_id_tensor_mapping()
for v in itertools.chain(fake_carry, fake_x):
param_id = fn_ctx.tensor_parameter_id(v)
if param_id != -1:
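
The one-line swap above is the substance of the cherry-pick: when `scan` lowers `fn`, any tensor that `fn` creates or captures (like `a` and `b` in the new test) surfaces as an extra HLO parameter, and `scan` needs the parameter-id-to-tensor mapping to feed those parameters. Judging by the new test, `parameter_id_tensor_mapping()` materializes the backing tensors on the host, while `device_parameter_id_tensor_mapping()` keeps them on the device. A rough sketch of the lowering-context API as scan.py uses it; the context label and variable names are illustrative, not from the diff.

import torch
import torch_xla

device = torch_xla.device()
x = torch.randn(2, 2, device=device)
captured = torch.tensor([1.0, 2.0], device=device)
torch_xla.sync(wait=True)  # materialize both tensors as device data

y = x + captured
ctx = torch_xla._XLAC.lowering.LoweringContext("Sketch")
ctx.build([y])  # lower the pending graph for `y` into HLO

# Map each HLO parameter id back to its backing tensor without copying
# the data to the host (the fix above swaps this in for the host variant).
hoisted = ctx.device_parameter_id_tensor_mapping()
print(ctx.tensor_parameter_id(x) in hoisted)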
