IncSubtensor causes graph break in pytorch backend #1154
Comments
The underlying function is

```python
def pytorch_funcified_fgraph(x):
    # IncSubtensor{i}(x, 1, -1)
    tensor_variable = inc_subtensor(x, tensor_constant, scalar_constant)
    return (tensor_variable,)
```

If you get the constants embedded into the code, the graph break goes away:

```python
import inspect

import torch

from pytensor.link.utils import compile_function_src

lines = inspect.getsourcelines(f.vm.jit_fn._fn)[0]
g = inspect.getclosurevars(f.vm.jit_fn._fn).globals
space = " " * 4
src = "".join([lines[0], *[space + "tensor_constant = torch.tensor(1)\n", space + "scalar_constant = -1\n"], *lines[1:]])
new_fn = compile_function_src(src, f.vm.jit_fn._fn.__name__, {**dict(x for x in g.items() if x[0] == 'inc_subtensor'), **globals()})
print(inspect.getsource(new_fn))
# def pytorch_funcified_fgraph(x):
#     tensor_constant = torch.tensor(1)
#     scalar_constant = -1
#     # IncSubtensor{i}(x, 1, -1)
#     tensor_variable = inc_subtensor(x, tensor_constant, scalar_constant)
#     return (tensor_variable,)

torch._dynamo.explain(new_fn)(torch.zeros(2)).break_reasons
# []
```

@ricardoV94's idea is to see if we can get the params out from the op / node and bake them into the returned function, or use torch check. When I initially tried, that didn't work. Will try again and post results for tracking.
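
To make the "bake the constants into the returned function" idea concrete, here is a rough sketch of what a dispatch function could do. The signature, the assumed `node.inputs` layout, and the helper names are illustrative assumptions, not the actual `pytensor.link.pytorch.dispatch.subtensor` code:

```python
# Hypothetical sketch (not the real pytensor dispatch): read constant inputs off
# the node at funcify time and close over them as plain Python numbers, so the
# compiled torch function never has to extract a scalar from a tensor at runtime.
from pytensor.graph.basic import Constant


def pytorch_funcify_inc_subtensor_sketch(op, node, **kwargs):
    idx_input = node.inputs[-1]  # assumption: the last input is the index

    if isinstance(idx_input, Constant):
        idx = int(idx_input.data)  # baked in as a Python int

        def inc_subtensor_constant_idx(x, y, *_):
            x = x.clone()
            x[idx] += y
            return x

        return inc_subtensor_constant_idx

    # Fallback: the index arrives as a runtime value, which is roughly the
    # situation that currently leads to the graph break.
    def inc_subtensor_runtime_idx(x, y, idx):
        x = x.clone()
        x[idx] += y
        return x

    return inc_subtensor_runtime_idx
```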
That sounds fine, and easier than inlining constants manually in the Op implementation. The same could be done for JAX and would simplify our codebase quite a bit.
OTOH, inlining constants as text is not possible/trivial? Say something like [...]
Ugh, yea, `inf` would be another example. Anything where [...]

Is there a way we could use a graph rewrite for this? That way, instead of there being a [...]

Or we get torch checks working. I opened an issue with torch: https://discuss.pytorch.org/t/torch-check-failing-with-torch-compile/215443. We can also keep this open for a little bit.
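
As a self-contained illustration of the underlying dynamo behaviour (my own toy example, not pytensor code): extracting a Python scalar from a tensor inside the compiled function splits the graph, while a constant Python index baked into the source does not.

```python
# Toy example (not pytensor code) contrasting a data-dependent index with a
# constant one under torch.compile / dynamo.
import torch


def inc_with_tensor_index(x, y, idx):
    x = x.clone()
    x[idx.item()] += y  # data-dependent scalar extraction -> graph break
    return x


def inc_with_python_index(x, y):
    x = x.clone()
    x[-1] += y  # constant index baked into the code -> no break
    return x


x = torch.zeros(2)
print(torch._dynamo.explain(inc_with_tensor_index)(x, 1.0, torch.tensor(1)).break_reasons)
# non-empty: data-dependent scalar extraction
print(torch._dynamo.explain(inc_with_python_index)(x, 1.0).break_reasons)
# []
```

Setting `torch._dynamo.config.capture_scalar_outputs = True` (the flag named in the error under Description) avoids this particular break, at the cost of tracing unbacked scalars.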
Description
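The original snippet did not survive this copy; the following is a hypothetical reconstruction, inferred from the `IncSubtensor{i}(x, 1, -1)` graph and the `f.vm.jit_fn._fn` access in the comments above (the `PYTORCH` mode string and the input values are assumptions):

```python
# Hypothetical reconstruction of the missing reproducer (not the original snippet).
import numpy as np
import torch

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out = pt.inc_subtensor(x[-1], 1)  # an IncSubtensor graph like the one in the comments above

f = pytensor.function([x], out, mode="PYTORCH")
f(np.zeros(2))

print(torch._dynamo.explain(f.vm.jit_fn._fn)(torch.zeros(2)).break_reasons)
```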
Will produce:

```
[GraphCompileReason(reason='data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True', user_stack=[<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 99 in inc_subtensor>], graph_break=True)]
```