diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py
index f344386c852..6223d77ebf4 100644
--- a/test/scan/test_scan.py
+++ b/test/scan/test_scan.py
@@ -9,6 +9,7 @@
 from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree
 
 import torch_xla
+import torch_xla.debug.metrics as met
 from torch_xla.experimental.scan import scan, value_and_grad_partitioned, tree_flatten_none
 
 parent_folder = os.path.dirname(os.path.dirname(__file__))
@@ -178,6 +179,29 @@ def fn(carry, x):
                       device=self.device)
     self.run_test(fn, init, xs)
 
+  def test_scan_create_tensors_no_transfers_from_device(self):
+    """Test that scanning over a function that internally creates tensors
+    will not transfer those tensors to the host, which can be expensive.
+    """
+
+    def fn(carry, x):
+      a = torch.tensor([1.0, 2.0], device=self.device)
+      b = torch.tensor([3.0, 4.0], device=self.device)
+      return carry + a, x + b
+
+    init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      requires_grad=True,
+                      device=self.device)
+    met.clear_all()
+    self.assertFalse(met.metric_data("TransferFromDeviceTime"))
+    # Use `scan` to lower `fn` into HLO and run it. Doing so should not
+    # transfer anything from the device to the host. `carry` and `ys` still
+    # reference data on the device.
+    carry, ys = scan(fn, init, xs)
+    torch_xla.sync(wait=True)
+    self.assertFalse(met.metric_data("TransferFromDeviceTime"))
+
   def test_scan_internal_in_place_mutation(self):
     """
     Test internal in-place mutations inside the `fn` to be scanned over.
diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py
index 7e872edc42d..f181d5ad764 100644
--- a/torch_xla/experimental/scan.py
+++ b/torch_xla/experimental/scan.py
@@ -521,7 +521,8 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
     builder.add_param(val)
 
   # Detect hoisted variables.
-  hoisted_vars: Dict[int, torch.Tensor] = fn_ctx.parameter_id_tensor_mapping()
+  hoisted_vars: Dict[
+      int, torch.Tensor] = fn_ctx.device_parameter_id_tensor_mapping()
   for v in itertools.chain(fake_carry, fake_x):
     param_id = fn_ctx.tensor_parameter_id(v)
     if param_id != -1:
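
Note for reviewers: the pattern the new test uses (clear metrics, run a lazy
computation, sync, then assert the transfer metric was never recorded) is
reusable. Below is a minimal sketch of a helper wrapping that pattern. It
uses only the torch_xla APIs that appear in the patch above; the helper name
`assert_no_transfers_from_device` and the `computation` parameter are
illustrative, not part of the torch_xla API, and it assumes `met.metric_data`
returns a falsy value when a metric was never recorded, as the test relies on.

    import torch_xla
    import torch_xla.debug.metrics as met

    def assert_no_transfers_from_device(computation):
      # Reset all counters and metrics so earlier work does not pollute
      # the check.
      met.clear_all()
      result = computation()
      # Force execution of the lazily recorded graph; without this, the
      # computation may not have run yet and the check would be vacuous.
      torch_xla.sync(wait=True)
      # "TransferFromDeviceTime" is only recorded when data is copied from
      # the device to the host, so its absence means no such transfer
      # happened while running `computation`.
      assert not met.metric_data("TransferFromDeviceTime")
      return result

Keeping the assertion both before and after the computation, as the test
does, additionally guards against the metric being recorded during setup
rather than during the scan itself.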