Skip to content

Commit

Permalink
log out partial fx graph when guard on data dependent during non stir…
Browse files Browse the repository at this point in the history
…ct tracing (pytorch#146298)

As discussed with @avikchaudhuri and @bdhirsh last week, this can be quite useful when debugging.

The following code produces a data dependent error

```
import torch
from torch import nn

# UserError: Could not guard on data-dependent expression Eq(507 - u0, 0) (unhinted: Eq(507 - u0, 0)).  (Size-like symbols: u0)
class Repro(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, cache, update, pos):
        _, _, max_seq_len, _ = cache.shape
        _, _, seqlen, _ = update.shape

        pos_item = pos[0].item() # u0
        torch._check(pos_item + seqlen <= max_seq_len) # u0 + 502 <= 507
        torch._check(pos_item >= 0)
        before = cache.narrow(2, 0, pos_item)

        # FAIL
        # Laith: why can't we make unbacked expressions size-like?
        after = cache.narrow(2, (pos_item + seqlen), (max_seq_len - pos_item - seqlen))

        # PASS
        end = torch.tensor(max_seq_len - pos_item - seqlen).item()
        after = cache.narrow(2, (pos_item + seqlen), end)

        return torch.cat([before, update, after], dim=2)

repro = Repro()

bsz = 1
n_heads = 4
max_seq_len = 512
head_dim = 64
seqlen = 5
pos_item = 1

cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)
update = torch.ones(bsz, n_heads, seqlen, head_dim)
pos = torch.tensor([pos_item])
example_inputs = (cache, update, pos)

torch.export.export(repro, example_inputs, strict=False)
```

This is what it now prints out

```
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[1, 4, 512, 64][131072, 32768, 64, 1]cpu", arg1_1: "f32[1, 4, 5, 64][1280, 320, 64, 1]cpu", arg2_1: "i64[1][1]cpu"):
         # File: /data/users/bobren/a/pytorch/r1.py:14 in forward, code: pos_item = pos[0].item() # u0
        select: "i64[][]cpu" = torch.ops.aten.select.int(arg2_1, 0, 0);  arg2_1 = None
        item: "Sym(u0)" = torch.ops.aten.item.default(select);  select = None

         # File: /data/users/bobren/a/pytorch/r1.py:15 in forward, code: torch._check(pos_item + seqlen <= max_seq_len) # u0 + 502 <= 507
        add: "Sym(u0 + 5)" = item + 5
        le: "Sym(u0 + 5 <= 512)" = add <= 512;  add = le = None

         # File: /data/users/bobren/a/pytorch/r1.py:16 in forward, code: torch._check(pos_item >= 0)
        ge: "Sym(u0 >= 0)" = item >= 0;  ge = None

         # File: /data/users/bobren/a/pytorch/r1.py:17 in forward, code: before = cache.narrow(2, 0, pos_item)
        narrow: "f32[1, 4, u0, 64][131072, 32768, 64, 1]cpu" = torch.ops.aten.narrow.default(arg0_1, 2, 0, item);  narrow = None

         # File: /data/users/bobren/a/pytorch/r1.py:21 in forward, code: after = cache.narrow(2, (pos_item + seqlen), (max_seq_len - pos_item - seqlen))
        add_1: "Sym(u0 + 5)" = item + 5
        sub: "Sym(512 - u0)" = 512 - item;  item = None
        sub_1: "Sym(507 - u0)" = sub - 5;  sub = None
        narrow_1 = torch.ops.aten.narrow.default(arg0_1, 2, add_1, sub_1);  arg0_1 = add_1 = sub_1 = narrow_1 = None

Traceback (most recent call last):
  File "/data/users/bobren/a/pytorch/r1.py", line 45, in <module>
    torch.export.export(repro, example_inputs, strict=False)
  File "/data/users/bobren/a/pytorch/torch/export/__init__.py", line 368, in export
    return _export(
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 2079, in _export
    return _export_for_training(
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1944, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1879, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1665, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1809, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1585, in _make_fx_helper
    gm = make_fx(
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2194, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2132, in trace
    return self._trace_inner(f, *args)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2103, in _trace_inner
    t = dispatch_trace(
  File "/data/users/bobren/a/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/_dynamo/eval_frame.py", line 749, in _fn
    return fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1136, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1692, in trace
    res = super().trace(root, concrete_args)
  File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 834, in trace
    (self.create_arg(fn(*args)),),
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1191, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1488, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/data/users/bobren/a/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 811, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1762, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 529, in call_module
    ret_val = forward(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 804, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/export/_trace.py", line 1793, in forward
    tree_out = mod(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 811, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1762, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 529, in call_module
    ret_val = forward(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/_symbolic_trace.py", line 804, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/r1.py", line 21, in forward
    after = cache.narrow(2, (pos_item + seqlen), (max_seq_len - pos_item - seqlen))
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1239, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1286, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/_export/non_strict_utils.py", line 654, in __torch_function__
    return func(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/_ops.py", line 866, in handler
    return torch._library.utils.handle_dispatch_mode(
  File "/data/users/bobren/a/pytorch/torch/_library/utils.py", line 296, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1341, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 910, in proxy_call
    out = func(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/_ops.py", line 749, in __call__
    return self._op(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1267, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1808, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1369, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 2282, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/_decomp/decompositions.py", line 759, in slice_forward
    return self.as_strided(sizes, strides, storage_offset)
  File "/data/users/bobren/a/pytorch/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1267, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1808, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1370, in _cached_dispatch_impl
    entry = self._make_cache_entry(state, key, func, args, kwargs, output)
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1640, in _make_cache_entry
    output_info = self._get_output_info_for_cache_entry(
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1583, in _get_output_info_for_cache_entry
    synth_output = self._output_from_cache_entry(
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1738, in _output_from_cache_entry
    return self._get_output_tensor_from_cache_entry(
  File "/data/users/bobren/a/pytorch/torch/_subclasses/fake_tensor.py", line 1709, in _get_output_tensor_from_cache_entry
    empty.set_(storage, storage_offset, shape, stride)
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6468, in evaluate_expr
    return self._evaluate_expr(
  File "/data/users/bobren/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6658, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(507 - u0, 1) (unhinted: Ne(507 - u0, 1)).  (Size-like symbols: u0)
```

Pull Request resolved: pytorch#146298
Approved by: https://github.com/bdhirsh
  • Loading branch information
bobrenjc93 authored and pytorchmergebot committed Feb 3, 2025
1 parent 0da07a6 commit d69c181
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion torch/fx/_symbolic_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import inspect
import math
import os
import sys
import warnings
from itertools import chain
from types import CodeType, FunctionType, ModuleType
Expand Down Expand Up @@ -818,7 +819,10 @@ def forward(*args, **kwargs):
deduplicate=False,
)
patcher.patch_method(
torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
torch.nn.Module,
"__call__",
module_call_wrapper,
deduplicate=False,
)
_patch_wrapped_functions(patcher)
_autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
Expand All @@ -835,6 +839,21 @@ def forward(*args, **kwargs):
)

self.submodule_paths = None
except RuntimeError as e:
if (
isinstance(e.args[0], str)
and "Could not guard on data-dependent" in e.args[0]
):
print(
"\n"
+ self.graph.python_code(
root_module="self",
verbose=True,
).src,
file=sys.stderr,
)

raise
finally:
_is_fx_tracing_flag = old_is_fx_tracing_flag
return self.graph
Expand Down

0 comments on commit d69c181

Please sign in to comment.