Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
log out partial fx graph when guard on data dependent during non stir…
…ct tracing (pytorch#146298) As discussed with @avikchaudhuri and @bdhirsh last week, this can be quite useful when debugging. The following code produces a data dependent error ``` import torch from torch import nn # UserError: Could not guard on data-dependent expression Eq(507 - u0, 0) (unhinted: Eq(507 - u0, 0)). (Size-like symbols: u0) class Repro(nn.Module): def __init__(self): super().__init__() def forward(self, cache, update, pos): _, _, max_seq_len, _ = cache.shape _, _, seqlen, _ = update.shape pos_item = pos[0].item() # u0 torch._check(pos_item + seqlen <= max_seq_len) # u0 + 502 <= 507 torch._check(pos_item >= 0) before = cache.narrow(2, 0, pos_item) # FAIL # Laith: why can't we make unbacked expressions size-like? after = cache.narrow(2, (pos_item + seqlen), (max_seq_len - pos_item - seqlen)) # PASS end = torch.tensor(max_seq_len - pos_item - seqlen).item() after = cache.narrow(2, (pos_item + seqlen), end) return torch.cat([before, update, after], dim=2) repro = Repro() bsz = 1 n_heads = 4 max_seq_len = 512 head_dim = 64 seqlen = 5 pos_item = 1 cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim) update = torch.ones(bsz, n_heads, seqlen, head_dim) pos = torch.tensor([pos_item]) example_inputs = (cache, update, pos) torch.export.export(repro, example_inputs, strict=False) ``` This is what it now prints out ``` class GraphModule(torch.nn.Module): def forward(self, arg0_1: "f32[1, 4, 512, 64][131072, 32768, 64, 1]cpu", arg1_1: "f32[1, 4, 5, 64][1280, 320, 64, 1]cpu", arg2_1: "i64[1][1]cpu"): # File: /data/users/bobren/a/pytorch/r1.py:14 in forward, code: pos_item = pos[0].item() # u0 select: "i64[][]cpu" = torch.ops.aten.select.int(arg2_1, 0, 0); arg2_1 = None item: "Sym(u0)" = torch.ops.aten.item.default(select); select = None # File: /data/users/bobren/a/pytorch/r1.py:15 in forward, code: torch._check(pos_item + seqlen <= max_seq_len) # u0 + 502 <= 507 add: "Sym(u0 + 5)" = item + 5 le: "Sym(u0 + 5 <= 512)" = add <= 512; add = le = None # File: /data/users/bobren/a/pytorch/r1.py:16 in forward, code: torch._check(pos_item >= 0) ge: "Sym(u0 >= 0)" = item >= 0; ge = None # File: /data/users/bobren/a/pytorch/r1.py:17 in forward, code: before = cache.narrow(2, 0, pos_item) narrow: "f32[1, 4, u0, 64][131072, 32768, 64, 1]cpu" = torch.ops.aten.narrow.default(arg0_1, 2, 0, item); narrow = None # File: /data/users/bobren/a/pytorch/r1.py:21 in forward, code: after = cache.narrow(2, (pos_item + seqlen), (max_seq_len - pos_item - seqlen)) add_1: "Sym(u0 + 5)" = item + 5 sub: "Sym(512 - u0)" = 512 - item; item = None sub_1: "Sym(507 - u0)" = sub - 5; sub = None narrow_1 = torch.ops.aten.narrow.default(arg0_1, 2, add_1, sub_1); arg0_1 = add_1 = sub_1 = narrow_1 = None Traceback (most recent call last): File "/data/users/bobren/a/pytorch/r1.py", line 45, in <module> torch.export.export(repro, example_inputs, strict=False) File "/data/users/bobren/a/pytorch/torch/export/__init__.py", line 368, in export return _export( File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1044, in wrapper raise e File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1017, in wrapper ep = fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/export/exported_program.py", line 117, in wrapper return fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 2079, in _export return _export_for_training( File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1044, in wrapper raise e File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1017, in wrapper ep = fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/export/exported_program.py", line 117, in wrapper return fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1944, in _export_for_training export_artifact = export_func( # type: ignore[operator] File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1879, in _non_strict_export aten_export_artifact = _to_aten_func( # type: ignore[operator] File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1665, in _export_to_aten_ir_make_fx gm, graph_signature = transform(_make_fx_helper)( File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1809, in _aot_export_non_strict gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags) File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1585, in _make_fx_helper gm = make_fx( File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2194, in wrapped return make_fx_tracer.trace(f, *args) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2132, in trace return self._trace_inner(f, *args) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2103, in _trace_inner t = dispatch_trace( File "/data/users/bobren/a/pytorch/torch/_compile.py", line 51, in inner return disable_fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/_dynamo/eval_frame.py", line 749, in _fn return fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1136, in dispatch_trace graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1692, in trace res = super().trace(root, concrete_args) File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 834, in trace (self.create_arg(fn(*args)),), File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1191, in wrapped out = f(*tensors) # type:ignore[call-arg] File "<string>", line 1, in <lambda> File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1488, in wrapped_fn return tuple(flat_fn(*args)) File "/data/users/bobren/a/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn tree_out = fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call out = mod(*args[params_len:], **kwargs) File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 811, in module_call_wrapper return self.call_module(mod, forward, args, kwargs) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1762, in call_module return Tracer.call_module(self, m, forward, args, kwargs) File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 529, in call_module ret_val = forward(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 804, in forward return _orig_module_call(mod, *args, **kwargs) File "/data/users/bobren/a/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl return forward_call(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1793, in forward tree_out = mod(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 811, in module_call_wrapper return self.call_module(mod, forward, args, kwargs) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1762, in call_module return Tracer.call_module(self, m, forward, args, kwargs) File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 529, in call_module ret_val = forward(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 804, in forward return _orig_module_call(mod, *args, **kwargs) File "/data/users/bobren/a/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl return forward_call(*args, **kwargs) File "/data/users/bobren/a/pytorch/r1.py", line 21, in forward after = cache.narrow(2, (pos_item + seqlen), (max_seq_len - pos_item - seqlen)) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1239, in __torch_function__ return func(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1286, in __torch_function__ return func(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/_export/non_strict_utils.py", line 654, in __torch_function__ return func(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/_ops.py", line 866, in handler return torch._library.utils.handle_dispatch_mode( File "/data/users/bobren/a/pytorch/torch/_library/utils.py", line 296, in handle_dispatch_mode return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs) File "/data/users/bobren/a/pytorch/torch/utils/_stats.py", line 27, in wrapper return fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1341, in __torch_dispatch__ return proxy_call(self, func, self.pre_dispatch, args, kwargs) File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 910, in proxy_call out = func(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/_ops.py", line 749, in __call__ return self._op(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/utils/_stats.py", line 27, in wrapper return fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1267, in __torch_dispatch__ return self.dispatch(func, types, args, kwargs) File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1808, in dispatch return self._cached_dispatch_impl(func, types, args, kwargs) File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1369, in _cached_dispatch_impl output = self._dispatch_impl(func, types, args, kwargs) File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 2282, in _dispatch_impl decomposition_table[func](*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/_decomp/decompositions.py", line 759, in slice_forward return self.as_strided(sizes, strides, storage_offset) File "/data/users/bobren/a/pytorch/torch/utils/_stats.py", line 27, in wrapper return fn(*args, **kwargs) File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1267, in __torch_dispatch__ return self.dispatch(func, types, args, kwargs) File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1808, in dispatch return self._cached_dispatch_impl(func, types, args, kwargs) File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1370, in _cached_dispatch_impl entry = self._make_cache_entry(state, key, func, args, kwargs, output) File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1640, in _make_cache_entry output_info = self._get_output_info_for_cache_entry( File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1583, in _get_output_info_for_cache_entry synth_output = self._output_from_cache_entry( File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1738, in _output_from_cache_entry return self._get_output_tensor_from_cache_entry( File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1709, in _get_output_tensor_from_cache_entry empty.set_(storage, storage_offset, shape, stride) File "/data/users/bobren/a/pytorch/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious r = self.shape_env.evaluate_expr( File "/data/users/bobren/a/pytorch/torch/fx/experimental/recording.py", line 263, in wrapper return retlog(fn(*args, **kwargs)) File "/data/users/bobren/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6468, in evaluate_expr return self._evaluate_expr( File "/data/users/bobren/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6658, in _evaluate_expr raise self._make_data_dependent_error( torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(507 - u0, 1) (unhinted: Ne(507 - u0, 1)). (Size-like symbols: u0) ``` Pull Request resolved: pytorch#146298 Approved by: https://github.com/bdhirsh
- Loading branch information