
IncSubtensor causes graph break in pytorch backend #1154

Open
Ch0ronomato opened this issue Jan 16, 2025 · 4 comments · May be fixed by #1159
Labels
torch PyTorch backend

Comments

@Ch0ronomato Contributor

Description

import pytensor
import pytensor.tensor as pt
import torch

x = pt.vector('x')
y = x[-1].inc(1)
f = pytensor.function(inputs=[x], outputs=y, mode="PYTORCH")

torch._dynamo.explain(f.vm.jit_fn._fn)(torch.zeros(2)).break_reasons

Will produce: [GraphCompileReason(reason='data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True', user_stack=[<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 99 in inc_subtensor>], graph_break=True)]
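For what it's worth, the break reason itself points at a possible stopgap. This is just the dynamo config knob named in the message above, not a fix on the pytensor side, and I haven't verified it does more than hide the break:

import torch._dynamo

# Allow dynamo to capture .item()-style scalar reads instead of breaking the graph.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.explain(f.vm.jit_fn._fn)(torch.zeros(2)).break_reasons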

@Ch0ronomato Ch0ronomato added the torch PyTorch backend label Jan 16, 2025
@Ch0ronomato Contributor Author

Ch0ronomato commented Jan 16, 2025

The underlying function is:

def pytorch_funcified_fgraph(x):
    # IncSubtensor{i}(x, 1, -1)
    tensor_variable = inc_subtensor(x, tensor_constant, scalar_constant)
    return (tensor_variable,)

If you get the constants embedded directly into the generated source, the graph break goes away:

import inspect
from pytensor.link.utils import compile_function_src

lines = inspect.getsourcelines(f.vm.jit_fn._fn)[0]
g = inspect.getclosurevars(f.vm.jit_fn._fn).globals
space = " " * 4
src = "".join([lines[0], space + "tensor_constant = torch.tensor(1)\n", space + "scalar_constant = -1\n", *lines[1:]])
new_fn = compile_function_src(src, f.vm.jit_fn._fn.__name__, {**{k: v for k, v in g.items() if k == "inc_subtensor"}, **globals()})

print(inspect.getsource(new_fn))
# def pytorch_funcified_fgraph(x):
#    tensor_constant = torch.tensor(1)
#    scalar_constant = -1
#    # IncSubtensor{i}(x, 1, -1)
#    tensor_variable = inc_subtensor(x, tensor_constant, scalar_constant)
#    return (tensor_variable,)


torch._dynamo.explain(new_fn)(torch.zeros(2)).break_reasons
# []

@ricardoV94's idea is to see if we can pull the params out of the op / node and bake them into the returned function, or use torch check. When I initially tried that, it didn't work. Will try again and post results here for tracking.
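Roughly what I mean by baking the params in (an illustrative sketch with made-up names, not the actual dispatch in pytensor/link/pytorch/dispatch/subtensor.py):

from pytensor.graph.basic import Constant

# Hypothetical helper: if the index input is a Constant in the graph, read it
# out at funcify time and close over it as a plain Python int, so dynamo sees
# a literal index rather than a scalar tensor it has to materialize.
def funcify_inc_subtensor_with_constant_index(node):
    idx_var = node.inputs[-1]
    if not isinstance(idx_var, Constant):
        raise NotImplementedError("fall back to the existing implementation")
    idx = int(idx_var.data)

    def inc_subtensor_const_idx(x, y, *ignored):
        out = x.clone()
        out[idx] += y
        return out

    return inc_subtensor_const_idx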

@ricardoV94 Member

ricardoV94 commented Jan 16, 2025

That sounds fine, and easier than inlining constants manually in the Op implementation. The same could be done for JAX and would simplify our codebase quite a bit.

@ricardoV94 Member

ricardoV94 commented Jan 16, 2025

OTOH, isn't inlining constants as text not possible / non-trivial in general? Say something like np.full((1000, 1000), np.pi) or a slice object.

@Ch0ronomato Contributor Author

Ugh, yeah, inf would be another example. Anything where __repr__ returns almost-correct Python but not fully valid source.
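Quick illustration of the repr problem:

import numpy as np

repr(np.inf)                        # 'inf' -> only valid source if a name `inf` is in scope
repr(np.full((1000, 1000), np.pi))  # truncated with '...', so the text no longer round-trips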

Is there a way we could use a graph rewrite for this? That way, instead of a scalar_constant global, the generated code would call one of the dispatch methods, which we could decorate with assume_constant to hint to torch that it's okay.
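Untested sketch of the idea, assuming the decorator I have in mind is torch._dynamo.assume_constant_result and with inc_subtensor written out by hand as a stand-in for the dispatched function:

import torch

@torch._dynamo.assume_constant_result
def scalar_constant():
    # Dynamo is told to treat this return value as a compile-time constant.
    return -1

def inc_subtensor(x, y, idx):
    out = x.clone()
    out[idx] += y
    return out

def funcified_fgraph(x):
    return (inc_subtensor(x, torch.tensor(1.0), scalar_constant()),)

torch._dynamo.explain(funcified_fgraph)(torch.zeros(2)).break_reasons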

Or we get torch checks working. I opened a thread about it on the PyTorch forum: https://discuss.pytorch.org/t/torch-check-failing-with-torch-compile/215443.

We can also keep this open for a little bit.

@Ch0ronomato Ch0ronomato linked a pull request Jan 21, 2025 that will close this issue