ppe.compile Enable forward only and custom decompositions #740

Merged · 9 commits · Dec 6, 2023
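For context, a minimal usage sketch of the two new keyword-only arguments introduced by this PR, generate_backward and decompositions. The model and optimizer below are placeholders, not part of the diff:

import torch
import pytorch_pfn_extras as ppe

# Placeholder module/optimizer; any torch.nn.Module works the same way.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Forward-only compilation: no backward graph is generated, so no
# optimizer is needed.
fwd_only = ppe.compile(model, None, generate_backward=False)

# Joint forward+backward+optimizer compilation with the default
# decompositions (core_aten_decompositions() is used when None).
joint = ppe.compile(model, optimizer)
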
42 changes: 35 additions & 7 deletions pytorch_pfn_extras/_dynamo/_compile.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, cast
from typing import Any, Callable, Dict, List, Optional, cast

import torch
import torch.fx
@@ -127,12 +127,14 @@ def _compile_module(
module: torch.nn.Module,
optimizer: Optional[torch.optim.Optimizer],
user_backend: Optional[Callable[..., Any]],
generate_backward: bool,
decompositions: Optional[Dict[Any, Callable]],
) -> Callable[..., Any]:
if not isinstance(module, torch.nn.Module):
raise TypeError("module needs to be a torch.nn.Module instance")

names = []
parameters_and_buffers = []
parameters_and_buffers: List[torch.Tensor] = []

def _graph_getter(gm, inputs): # type: ignore[no-untyped-def]
parameters_optimizer = []
@@ -150,7 +152,7 @@ def _graph_getter(gm, inputs):  # type: ignore[no-untyped-def]
if _normalize_name(node.name) == n:
parameters_optimizer.append(p)

for n, p in module.named_parameters():
for _, p in module.named_parameters():
for p_n in optimizer.state[p]: # type: ignore[index]
state_tensor = optimizer.state[p][p_n] # type: ignore[index]
if state_tensor is not None:
@@ -195,13 +197,22 @@ def _model_opt_func(*args, **kwargs):  # type: ignore[no-untyped-def]
parameters_and_buffers.append(p)
names.append(n)

partitioner = _splitter.JointGraph(names)
for n, b in module.named_buffers():
parameters_and_buffers.append(b)
names.append(n)

# This may be too simplistic ..., would it be better to set a `mode`?
partitioner: _splitter._Splitter
if generate_backward:
partitioner = _splitter.JointGraph(names)
else:
partitioner = _splitter.ForwardOnly(names)

aot_backend = aot_autograd( # type: ignore[no-untyped-call]
fw_compiler=_graph_getter,
bw_compiler=_dummy_bwd_backend,
partition_fn=partitioner._no_partition,
decompositions=core_aten_decompositions(),
partition_fn=partitioner.partition,
decompositions=decompositions,
)
module_opt = torch.compile(module, fullgraph=True, backend=aot_backend) # type: ignore[attr-defined]
return cast(Callable[..., Any], module_opt) # type: ignore[redundant-cast]
@@ -211,6 +222,9 @@ def compile(
module: torch.nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
backend: Optional[Callable[..., Any]] = None,
*,
generate_backward: bool = True,
decompositions: Optional[Dict[Any, Callable]] = None,
) -> Callable[..., Any]:
"""Compiles a module and an optimizer in a single graph using the provided backend.

@@ -233,7 +247,21 @@
backend (optional):
Object to process the graph and compile it for custom devices, will
use PyTorch dynamo by default if not specified.
generate_backward:
Add the backward pass to the graph. Default is ``True``.
decompositions (optional):
Custom mapping to decompose a torch op into simpler ops. Default is
``None``, which falls back to ``torch._decomp.core_aten_decompositions()``.
"""

module_opt = _compile_module(module, optimizer, backend)
if decompositions is None:
decompositions = core_aten_decompositions()

module_opt = _compile_module(
module,
optimizer,
backend,
generate_backward,
decompositions,
)
return module_opt
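The decompositions table is forwarded as-is to aot_autograd. A hedged sketch of supplying a custom table, built here with torch._decomp.get_decompositions, which returns the Dict[op, Callable] shape expected above; the selected ops and model are only an example:

import torch
from torch._decomp import get_decompositions

import pytorch_pfn_extras as ppe

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Decompose only a few selected ops instead of the full core ATen set.
custom_decomps = get_decompositions(
    [torch.ops.aten.gelu, torch.ops.aten.native_layer_norm]
)

compiled = ppe.compile(model, optimizer, decompositions=custom_decomps)
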
93 changes: 89 additions & 4 deletions pytorch_pfn_extras/_dynamo/_splitter.py
@@ -3,19 +3,37 @@
import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._functorch.partitioners import _is_primal, _is_tangent
from torch._functorch.partitioners import (
_is_primal,
_is_tangent,
default_partition,
)


class JointGraph:
def __init__(self, parameter_names: Optional[List[str]] = None):
class _Splitter:
def partition(
self,
joint_module: torch.fx.GraphModule,
_joint_inputs: Any,
*,
num_fwd_outputs: int,
) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
raise NotImplementedError("Splitters must override partition")


class JointGraph(_Splitter):
def __init__(
self,
parameter_names: Optional[List[str]] = None,
):
if parameter_names is not None:
self._parameter_names = [
n.replace(".", "__dot__") for n in parameter_names
]
else:
self._parameter_names = []

def _no_partition(
def partition(
self,
joint_module: torch.fx.GraphModule,
_joint_inputs: Any,
@@ -131,3 +149,70 @@ def _no_partition(
bwd_graph.output(out_nodes)
bwd_module = torch.fx.GraphModule(joint_module, bwd_graph)
return (fwd_module, bwd_module)


class ForwardOnly(_Splitter):
def __init__(
self,
parameter_names: Optional[List[str]] = None,
):
if parameter_names is not None:
self._parameter_names = [
n.replace(".", "__dot__") for n in parameter_names
]
else:
self._parameter_names = []

def partition(
self,
joint_module: torch.fx.GraphModule,
_joint_inputs: Any,
*,
num_fwd_outputs: int,
) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
fwd_module, _ = default_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
)
fwd_graph = fwd_module.graph
# Change the input names in the fwd graph
primal_inputs: List[torch.fx.Node] = list(
filter(_is_primal, fwd_graph.nodes)
)
for i, node in enumerate(primal_inputs):
node.name = (
self._parameter_names[i]
if i < len(self._parameter_names)
else f"input_{i - len(self._parameter_names)}"
)

# The joint graph has the forward and backward outputs together as its output values.
# By accessing the output node of the graph (there is only one output node),
# we get the list of all the values the graph returns in node.args.
# https://pytorch.org/docs/stable/fx.html#a-quick-primer-on-graphs
output_node = [node for node in fwd_graph.nodes if node.op == "output"][
0
]
outputs = pytree.tree_flatten(output_node.args)[0]
fwd_graph.erase_node(output_node)
# Select only the return values from the forward pass
fwd_graph.output(outputs[:num_fwd_outputs])
fwd_module = torch.fx.GraphModule(joint_module, fwd_graph)

# We now create a dummy graph that returns the outputs of the backward pass
# Notice that the graph needs to return as many values as the inputs of the
# forward pass. The joint graph returns additional values besides
# the gradients.
bwd_graph = torch.fx.Graph()
# We need to create one gradient for each input element
bwd_outs = []
for i_node in primal_inputs:
bwd_outs.append(
bwd_graph.call_function(
torch.zeros, (i_node.meta.get("tensor_meta").shape,)
)
)
bwd_graph.output(tuple(bwd_outs))
bwd_module = torch.fx.GraphModule(joint_module, bwd_graph)
return (fwd_module, bwd_module)
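For reference, a hedged sketch of the new _Splitter contract from the caller's side: a subclass only needs to provide partition() returning a (forward, backward) pair of GraphModules. The DefaultSplit class below is illustrative, not part of the PR, and imports the private _splitter module only to show the interface:

from typing import Any, List, Optional, Tuple

import torch.fx
from torch._functorch.partitioners import default_partition

from pytorch_pfn_extras._dynamo import _splitter


class DefaultSplit(_splitter._Splitter):
    # Illustrative splitter that reuses AOTAutograd's default partition,
    # showing the (fwd_module, bwd_module) contract a _Splitter must honor.
    def __init__(self, parameter_names: Optional[List[str]] = None):
        self._parameter_names = parameter_names or []

    def partition(
        self,
        joint_module: torch.fx.GraphModule,
        _joint_inputs: Any,
        *,
        num_fwd_outputs: int,
    ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        return default_partition(
            joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
        )
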
34 changes: 34 additions & 0 deletions tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py
@@ -110,3 +110,37 @@ def test_compile_with_optimizer_and_split_graph():
# This executes forward+backward+optimizer step
with pytest.raises(torch._dynamo.exc.Unsupported):
joint_module(x)


@pytest.mark.skipif(
not ppe.requires("2.0.0") or sys.platform == "win32",
reason="torch.compile interface its only added in PyTorch>2.0 and linux",
)
def test_compile_forward_only():
torch._dynamo.reset()
x = torch.randn(10, requires_grad=True)
torch_module = _DummyModule()
module_initial_state = torch_module.state_dict()
compiled_module = _DummyModule()
compiled_module.load_state_dict(module_initial_state)

y = torch_module(x)

n_outs = 0

# Verify that the graph has deleted the unneeded outputs (grads)
def test_backend(gm, inputs):
nonlocal n_outs
n_outs = len(
[node for node in gm.graph.nodes if node.op == "output"][0].args
)
return gm

# This executes forward step only
fwd_module = ppe.compile(
compiled_module, None, backend=test_backend, generate_backward=False
)
compiled_y = fwd_module(x)
assert n_outs == 1

assert torch.allclose(y, compiled_y)