From 99cd32ee4fb9eb4cd6693d18cd1a5ad7b6efae50 Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Wed, 21 Jun 2023 04:29:38 +0000 Subject: [PATCH 1/9] Initial commit --- pytorch_pfn_extras/__init__.py | 3 + pytorch_pfn_extras/_dynamo/__init__.py | 1 + pytorch_pfn_extras/_dynamo/_compile.py | 236 +++++++++++++++ pytorch_pfn_extras/_dynamo/_optimizer.py | 274 ++++++++++++++++++ pytorch_pfn_extras/_dynamo/_splitter.py | 119 ++++++++ pytorch_pfn_extras/testing.py | 43 +++ stubs/torch/fx/__init__.pyi | 6 + stubs/torch/fx/graph.pyi | 37 +++ stubs/torch/fx/graph_module.pyi | 17 ++ stubs/torch/fx/node.pyi | 11 + stubs/torch/fx/proxy.pyi | 10 + .../dynamo_tests/test_compile.py | 76 +++++ .../training_tests/test_trainer.py | 51 +--- 13 files changed, 842 insertions(+), 42 deletions(-) create mode 100644 pytorch_pfn_extras/_dynamo/__init__.py create mode 100644 pytorch_pfn_extras/_dynamo/_compile.py create mode 100644 pytorch_pfn_extras/_dynamo/_optimizer.py create mode 100644 pytorch_pfn_extras/_dynamo/_splitter.py create mode 100644 pytorch_pfn_extras/testing.py create mode 100644 stubs/torch/fx/__init__.pyi create mode 100644 stubs/torch/fx/graph.pyi create mode 100644 stubs/torch/fx/graph_module.pyi create mode 100644 stubs/torch/fx/node.pyi create mode 100644 stubs/torch/fx/proxy.pyi create mode 100644 tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py diff --git a/pytorch_pfn_extras/__init__.py b/pytorch_pfn_extras/__init__.py index 6898f25c4..30379c702 100644 --- a/pytorch_pfn_extras/__init__.py +++ b/pytorch_pfn_extras/__init__.py @@ -26,3 +26,6 @@ from pytorch_pfn_extras._version import __version__ # NOQA from pytorch_pfn_extras.runtime._map import map # NOQA from pytorch_pfn_extras.runtime._to import to # NOQA + +if requires("2.0.0"): + from pytorch_pfn_extras._dynamo import compile # NOQA diff --git a/pytorch_pfn_extras/_dynamo/__init__.py b/pytorch_pfn_extras/_dynamo/__init__.py new file mode 100644 index 000000000..868fb867b --- /dev/null +++ b/pytorch_pfn_extras/_dynamo/__init__.py @@ -0,0 +1 @@ +from pytorch_pfn_extras._dynamo._compile import compile # NOQA diff --git a/pytorch_pfn_extras/_dynamo/_compile.py b/pytorch_pfn_extras/_dynamo/_compile.py new file mode 100644 index 000000000..fbf59233d --- /dev/null +++ b/pytorch_pfn_extras/_dynamo/_compile.py @@ -0,0 +1,236 @@ +from typing import Any, Callable, List, Optional, cast + +import torch +import torch.fx +import torch.fx.GraphModule +import torch.utils._pytree as pytree +from functorch.compile import make_boxed_func +from pytorch_pfn_extras._dynamo import _optimizer, _splitter +from torch._decomp import core_aten_decompositions # type: ignore[attr-defined] +from torch._dynamo.backends.common import aot_autograd +from torch._functorch.partitioners import _is_primal + + +def _dummy_bwd_backend( + gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] +) -> Any: + # The bwd pass is dummy, so we just return the inputs as they are + def run_graph(*args, **kwargs): # type: ignore[no-untyped-def] + return args[:-1] + + return make_boxed_func(run_graph) + + +def _join_graphs( + module_graph: torch.fx.Graph, optimizer_graph: torch.fx.Graph +) -> torch.fx.Graph: + module_inputs: List[torch.fx.Node] = list( + filter(_is_primal, module_graph.nodes) + ) + module_outputs = pytree.tree_flatten( + [node.args for node in module_graph.nodes if node.op == "output"] + )[0] + grads = {} + # Fuse the two graphs + # 1. 
Look for the gradients in the outputs + for out in module_outputs: + if out.name.startswith("grad_"): + grads[out.name] = out + + parameters = {} + prefix_len = len("grad_") + for grad_name in grads: + for inp in module_inputs: + if grad_name[prefix_len:] == inp.name: + parameters[inp.name] = inp + + # Look in the optimizer graph for the nodes corresponding to the gradient obtention and + # the parameters (usually inputs) They are the `getattr` function call in the parameter + # In place ops can be ignored at this stage and substituted later in the + # compilation backend since they are returning the result + opt_grad_nodes = set() + opt_param_nodes = set() + opt_to_model = {} + for node in optimizer_graph.nodes: + if node.op == "call_function" and node.target is getattr: + if "grad" in node.args: + opt_grad_nodes.add(node) + # This will allow us to just add the same operations of these nodes to the real graph. + # Note that the updates are done INPLACE so backends for custom devices need to + # be careful + opt_param_nodes.add(node.args[0]) + # Save a correspondence of optimizer graph to model graph + opt_to_model[node] = grads["grad_" + node.args[0].name] + opt_to_model[node.args[0]] = parameters[node.args[0].name] + + # Find insertion points in the graph to add the optimizer required inputs + last_input = None + for in_node in module_graph.nodes: + if _is_primal(in_node): # type: ignore[no-untyped-call] + last_input = in_node + # Find insertion points in the graph to add the optimizer computation + last_node = None + model_output_node = None + for node in module_graph.nodes: + if node.op == "output": + model_output_node = node + break + last_node = node + + assert model_output_node is not None + outputs = pytree.tree_flatten(model_output_node.args)[0] + + # Merge the optimizer and model graphs + for node in optimizer_graph.nodes: + # Skip grad obtainer + if node.op == "call_function" and node.target is getattr: + if "grad" in node.args: + continue + if _is_primal(node): # type: ignore[no-untyped-call] + # Add the optimizer inputs to the module inputs + if node.name not in parameters: + # Look the inserting point + module_graph.inserting_after(last_input) + new_node = module_graph.placeholder(node.name) + opt_to_model[node] = new_node + new_node.meta = node.meta + last_input = new_node + continue + if node.op == "output": + # Combine model and optimizer outputs + outputs.extend(pytree.tree_flatten(node.args)[0]) + continue + + module_graph.inserting_after(last_node) + args = tuple( + opt_to_model[arg] if arg in opt_to_model else arg + for arg in node.args + ) + res = module_graph.create_node( + node.op, node.target, args, node.kwargs, node.name + ) + res.meta = node.meta + opt_to_model[node] = res + last_node = res + + # Remove the original outputs node and add the combined one + module_graph.erase_node(model_output_node) + module_graph.inserting_after(last_node) + module_graph.output(outputs) + return module_graph + + +def _normalize_name(name: str) -> str: + return name.replace("param_out_", "").replace("__dot__", ".") + + +def _compile_module( + module: torch.nn.Module, + optimizer: Optional[torch.optim.Optimizer], + user_backend: Optional[Callable[..., Any]], +) -> Callable[..., Any]: + if not isinstance(module, torch.nn.Module): + raise TypeError("module needs to be a torch.nn.Module instance") + + names = [] + parameters_and_buffers = [] + + def _graph_getter(gm, inputs): # type: ignore[no-untyped-def] + parameters_optimizer = [] + state_optimizer = [] + # TODO(ecastill) call the 
optimizer compiler here! + if optimizer is not None: + opt_graph, opt_outputs = _optimizer._compile_optimizer( + module, optimizer + ) + # gm.graph is modified in place with the added optimizer steps + _join_graphs(gm.graph, opt_graph) + n_opt_outs = len(opt_outputs) + for node in opt_outputs: + for n, p in module.named_parameters(): + if _normalize_name(node.name) == n: + parameters_optimizer.append(p) + + for n, p in module.named_parameters(): + for p_n in optimizer.state[p]: # type: ignore[index] + state_tensor = optimizer.state[p][p_n] # type: ignore[index] + if state_tensor is not None: + state_optimizer.append(state_tensor) + + # Create the function that deals with the optimizer outputs + # TODO(set this as arg) + supports_inplace = True + gm.recompile() # Sync the module to the graph changes + if user_backend is None: + func = gm + else: + func = user_backend(gm, inputs) + supports_inplace = False + n_params = len(parameters_optimizer) + + def _model_opt_func(*args, **kwargs): # type: ignore[no-untyped-def] + # Need to retrieve the optimizer state and concat it to the + # arguments + outs = func(*(args + tuple(state_optimizer)), **kwargs) + # Iterate the returned parameters and copy them into the + # Model real ones (sync) + if optimizer is not None: + opt_outs = outs[-n_opt_outs:] + if not supports_inplace: + for i in range(n_opt_outs): + if i < n_params: + parameters_optimizer[i].data.copy_(opt_outs[i]) + else: + state_optimizer[i - n_params].data.copy_( + opt_outs[i] + ) + return outs[:n_opt_outs] + return outs + + return make_boxed_func(_model_opt_func) + + # These are the first arguments passed to the functions + # They will be the names of the inputs, replacing the primals + # Extract the parameters name that the graph will use + for n, p in module.named_parameters(): + parameters_and_buffers.append(p) + names.append(n) + + partitioner = _splitter.JointGraph(names) + + aot_backend = aot_autograd( # type: ignore[no-untyped-call] + fw_compiler=_graph_getter, + bw_compiler=_dummy_bwd_backend, + partition_fn=partitioner._no_partition, + decompositions=core_aten_decompositions(), + ) + module_opt = torch.compile(module, backend=aot_backend) # type: ignore[attr-defined] + return cast(Callable[..., Any], module_opt) # type: ignore[redundant-cast] + + +def compile( + module: torch.nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + backend: Optional[Callable[..., Any]] = None, +) -> Callable[..., Any]: + """Compiles a module and an optimizer in a single graph using the provided backend. + + .. note:: + The backend object needs to be a callable accepting a ``torch.fx.GraphModule`` + and a list of ``torch.Tensor`` and return a ``Callable`` as specified by + https://pytorch.org/docs/2.0/dynamo/custom-backends.html#custom-backends + + Args: + module: + torch.nn.Module to be compiled + optimizer: + Optimizer object associated to the module. It will be traced and its + operations included in the module graph. Some dry run operations + may be performed to fully initialize the optimizer status. + backend (optional): + Object to process the graph and compile it for custom devices, will + use PyTorch dynamo by default if not specified. 
+ """ + + module_opt = _compile_module(module, optimizer, backend) + return module_opt diff --git a/pytorch_pfn_extras/_dynamo/_optimizer.py b/pytorch_pfn_extras/_dynamo/_optimizer.py new file mode 100644 index 000000000..683059ae8 --- /dev/null +++ b/pytorch_pfn_extras/_dynamo/_optimizer.py @@ -0,0 +1,274 @@ +import contextlib +import types +from typing import Any, Dict, Generator, List, Tuple + +import torch +import torch.fx +import torch.fx.GraphModule + + +# patch the torch.optim.SGD._init_group function to avoid the +# symbolically traced variables cannot be used as inputs to control flow error +# by replacing this function in SGD optimizer instances +def _sgd_init_group( # type: ignore[no-untyped-def] + self, group, params_with_grad, d_p_list, momentum_buffer_list +): + has_sparse_grad = False + + for p in group["params"]: + if p.grad is not None: + params_with_grad.append(p) + d_p_list.append(p.grad) + # if p.grad.is_sparse: + # has_sparse_grad = True + has_sparse_grad = p.grad.is_sparse + + state = self.state[p] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + return has_sparse_grad + + +def _get_parameters_names( + model: torch.nn.Module, +) -> Dict[torch.Tensor, str]: + param_names: Dict[torch.Tensor, str] = {} + for n, p in model.named_parameters(): + param_names[p] = n + return param_names + + +@contextlib.contextmanager +def _initialize_optimizer( + optimizer: torch.optim.Optimizer, module: torch.nn.Module +) -> Generator[Dict[torch.Tensor, str], None, None]: + if isinstance(optimizer, torch.optim.SGD): + optimizer._init_group = types.MethodType(_sgd_init_group, optimizer) # type: ignore[attr-defined] + + # Replace the optimizer parameters with zero tensors so that the step functions + # will initialize the state but doesn"t modify the module real weights + # param_to_dummy = {} # Keeps references so that `state` tensors can be reset + param_groups: List[List[torch.Tensor]] = [] + param_to_dummy: Dict[torch.Tensor, torch.Tensor] = {} + with torch.fx.experimental.proxy_tensor.maybe_disable_fake_tensor_mode(): # type: ignore[attr-defined,no-untyped-call] + names = _get_parameters_names(module) + + for p_group in optimizer.param_groups: + param_groups.append([]) + for i, param in enumerate(p_group["params"]): + dummy = torch.zeros_like(param) + dummy.grad = torch.zeros_like(param) + # param_to_dummy[dummy] = param + param_groups[-1].append(param) + names[dummy] = names[param] + param_to_dummy[param] = dummy + p_group["params"][i] = dummy + + # This call will initialize the `.state` values so fx can trace its ops + optimizer.step() + + yield names + + with torch.fx.experimental.proxy_tensor.maybe_disable_fake_tensor_mode(): # type: ignore[attr-defined,no-untyped-call] + # Reset the optimizer original parameters + for i, p_group in enumerate(optimizer.param_groups): + for j, _ in enumerate(p_group["params"]): + # param = param_to_dummy[dummy] + param = param_groups[i][j] + p_group["params"][j] = param + dummy = param_to_dummy[param] + optimizer.state[param] = optimizer.state[dummy] # type: ignore[index] + + +def _create_meta(tensor: torch.Tensor) -> Dict[str, Any]: + return { + "val": None, + "tensor_meta": torch.fx.passes.shape_prop.TensorMetadata( # type: ignore[attr-defined, arg-type] + tensor.shape, + tensor.dtype, + tensor.requires_grad, + tensor.stride(), # type: ignore[arg-type] + torch.preserve_format, + tensor.data.is_quantized, + {}, + ), + } + + +def 
_get_shape_inference_inputs_and_metadata( + optimizer: torch.optim.Optimizer, +) -> Tuple[Dict[torch.Tensor, Dict[str, Any]], List[torch.Tensor]]: + params_meta = {} + inputs = [] + + with torch.fx.experimental.proxy_tensor.maybe_disable_fake_tensor_mode(): # type: ignore[attr-defined,no-untyped-call] + for p_group in optimizer.param_groups: + for i in range(len(p_group["params"])): + param_tensor = p_group["params"][i] + params_meta[param_tensor] = _create_meta(param_tensor) + inputs.append(param_tensor) + + for p_group in optimizer.param_groups: + for i in range(len(p_group["params"])): + param_tensor = p_group["params"][i] + for p_n in optimizer.state[param_tensor]: + state_tensor = optimizer.state[param_tensor][p_n] + if state_tensor is not None: + params_meta[state_tensor] = _create_meta(state_tensor) + inputs.append(state_tensor) + + return params_meta, inputs + + +def _create_placeholders_for_parameters_and_state( + optimizer: torch.optim.Optimizer, + names: Dict[torch.Tensor, str], + opt_graph: torch.fx.Graph, + params_meta: Dict[torch.Tensor, Dict[str, Any]], + tracer: torch.fx.proxy.GraphAppendingTracer, +) -> Tuple[List[torch.fx.Node], List[torch.fx.Node]]: + placeholders = [] + state = [] + + params_to_proxy = {} + for p_group in optimizer.param_groups: + for i in range(len(p_group["params"])): + # Find param in list + # May need to replace `.` with `@` + param_tensor = p_group["params"][i] + p_name = names[param_tensor].replace(".", "__dot__") + placeholders.append(opt_graph.placeholder(p_name)) + placeholders[-1].meta = params_meta[param_tensor] + proxy = torch.fx.Proxy(placeholders[i], tracer) + optimizer.state[proxy] = optimizer.state[param_tensor].copy() # type: ignore[index] + params_to_proxy[param_tensor] = proxy + + for p_group in optimizer.param_groups: + for i in range(len(p_group["params"])): + # Find param in list + # May need to replace `.` with `@` + param_tensor = p_group["params"][i] + p_name = names[param_tensor].replace(".", "__dot__") + proxy = params_to_proxy[param_tensor] + for p in optimizer.state[proxy]: # type: ignore[index] + state_tensor = optimizer.state[param_tensor][p] + if state_tensor is not None: + state.append(opt_graph.placeholder(f"state_{p}_{p_name}")) + optimizer.state[proxy][p] = torch.fx.Proxy( # type: ignore[index] + state[-1], tracer + ) + state[-1].meta = params_meta[state_tensor] + p_group["params"][i] = proxy + + return placeholders, state + + +def _is_inplace(node: torch.fx.Node, arg: torch.fx.Node) -> bool: + return ( + node.op == "call_method" # type: ignore[return-value] + and node.args[0] == arg + and node.target[-1] == "_" # type: ignore[index] + ) + + +def _get_last_inplace_update( + opt_graph: torch.fx.Graph, +) -> Dict[torch.fx.Node, torch.fx.Node]: + last_inplace = {} + for node in opt_graph.nodes: + last_node = node + for o_node in opt_graph.nodes: + # If its an inplace modifying op, then its likely to be the update + if _is_inplace(o_node, last_node): + last_node = o_node + last_inplace[node] = o_node + + return last_inplace + + +def _adjust_inplace_ops( + opt_graph: torch.fx.Graph, last_inplace: Dict[torch.fx.Node, torch.fx.Node] +) -> None: + for node in opt_graph.nodes: + args = list(node.args) + modified = False + for i, a in enumerate(args): + # find the node that lastly modified the arg inplace before the current node + if a in last_inplace: + last = a + for p_node in opt_graph.nodes: + if p_node == node: + break + # Identify the previous update for the current node arg + if _is_inplace(p_node, last): + last = p_node 
+ args[i] = last + modified = True + + if modified: + node.args = tuple(args) + + +def _compile_optimizer( + module: torch.nn.Module, optimizer: torch.optim.Optimizer +) -> Tuple[torch.fx.Graph, List[torch.fx.Node]]: + # Do all the optimizer crap here + if not isinstance(module, torch.nn.Module): + raise RuntimeError( + "Optimizer needs module to be instance of torch.nn.Module" + ) + + opt_graph = torch.fx.Graph() + tracer = torch.fx.proxy.GraphAppendingTracer(opt_graph) + + # Gets all the optimizer registered inputs so we can run shape inference + with _initialize_optimizer(optimizer, module) as param_names: + params_meta, inputs = _get_shape_inference_inputs_and_metadata( + optimizer + ) + ( + param_placeholders, + state_placeholders, + ) = _create_placeholders_for_parameters_and_state( + optimizer, param_names, opt_graph, params_meta, tracer + ) + + # Trace the computation + optimizer.step() + # Look for the parameters and return their last known value + outputs = [] + + last_inplace = _get_last_inplace_update(opt_graph) + + # The last inplace update will be the optimizer outputs + def _get_last_update( + node_set: List[torch.fx.Node], out_prefix: str + ) -> None: + for node in node_set: + last_inplace[node].name = f"{out_prefix}{node.name}" + outputs.append(last_inplace[node]) + + # Add the last node updating a parameter to the graph outputs + _get_last_update(param_placeholders, "param_out_") + # Add the last node updating the state (e.g. momentum) to the graph outputs + _get_last_update(state_placeholders, "param_out_") + + # Make the nodes that have inplace ops as arguments to use the last inplace + # update right before the node itself. in some devices, inplace updates + # may not be real in-place ops and pytorch graph uses inplace ops nodes + # without caring about the order + # b = a.inplace_op() + # c = b.inplace_op() + # d = a + x # Here we use a, but since the value is updated it is equivalent + # # to use c + _adjust_inplace_ops(opt_graph, last_inplace) + opt_graph.output(outputs) + + with torch.fx.experimental.proxy_tensor.maybe_disable_fake_tensor_mode(): # type: ignore[attr-defined,no-untyped-call] + opt_module = torch.fx.GraphModule(module, opt_graph) + torch.fx.passes.shape_prop.ShapeProp(opt_module).propagate(*inputs) # type: ignore[attr-defined, no-untyped-call] + + return opt_graph, outputs diff --git a/pytorch_pfn_extras/_dynamo/_splitter.py b/pytorch_pfn_extras/_dynamo/_splitter.py new file mode 100644 index 000000000..c4a97b5b5 --- /dev/null +++ b/pytorch_pfn_extras/_dynamo/_splitter.py @@ -0,0 +1,119 @@ +from typing import Any, List, Optional, Tuple + +import torch +import torch.fx +import torch.fx.GraphModule +import torch.utils._pytree as pytree +from torch._functorch.partitioners import _is_primal, _is_tangent + + +class JointGraph: + def __init__(self, parameter_names: Optional[List[str]] = None): + if parameter_names is not None: + self._parameter_names = [ + n.replace(".", "__dot__") for n in parameter_names + ] + else: + self._parameter_names = [] + + def _no_partition( + self, + joint_module: torch.fx.GraphModule, + _joint_inputs: Any, + *, + num_fwd_outputs: int, + ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + primal_inputs: List[torch.fx.Node] = list( + filter(_is_primal, joint_module.graph.nodes) + ) + tangent_inputs: List[torch.fx.Node] = list( + filter(_is_tangent, joint_module.graph.nodes) + ) + outputs = pytree.tree_flatten( + [ + node.args + for node in joint_module.graph.nodes + if node.op == "output" + ] + )[0] + combined_graph = 
torch.fx.Graph() + env = {} + for i, node in enumerate(primal_inputs): + new_node = combined_graph.placeholder( + self._parameter_names[i] + if i < len(self._parameter_names) + else f"input_{i - len(self._parameter_names)}" + ) + new_node.meta = node.meta + env[node] = new_node + # The tangents will be transformed to constant ops + # Depending on the module this is different :( + # Maybe we can retrieve shape? + for node in tangent_inputs: + new_node = combined_graph.call_function( + torch.ops.aten.ones, args=(node.meta.get("tensor_meta").shape,) # type: ignore[union-attr] + ) + new_node.meta = node.meta + env[node] = new_node + + assert len(tangent_inputs) == num_fwd_outputs + + for node in joint_module.graph.nodes: + if node in primal_inputs or node in tangent_inputs: + continue + if node.op != "output": + env[node] = combined_graph.node_copy(node, lambda x: env[x]) + + outs = set() + combined_outputs = [] + # Some outputs are repeated, just return them only once + # We use `env[node]` to use newly created nodes for tangent if we are + # going to return them + for node in outputs: + if node is None: + # This is the case where the corresponding input doesn"t need a grad + continue # type: ignore[unreachable] + if env[node] not in outs: + combined_outputs.append(env[node]) + outs.add(env[node]) + + for i, node in enumerate(combined_outputs[:num_fwd_outputs]): + node.name = f"fwd_out_{i}" + for i, node in enumerate(combined_outputs[num_fwd_outputs:]): + if node is not None: + node.name = "grad_" + ( + self._parameter_names[i] + if i < len(self._parameter_names) + else f"input_{i - len(self._parameter_names)}" + ) + + combined_graph.output(combined_outputs) + fwd_module = torch.fx.GraphModule(joint_module, combined_graph) + # Now we create a graph for backward that is just the identities of the original inputs + # Since they are now known + bwd_graph = torch.fx.Graph() + bwd_outputs = outputs[num_fwd_outputs:] + out_nodes = [] + env = {} + + for node in bwd_outputs: + # Rename it + if node is None: + out_nodes.append(node) # type: ignore[unreachable] + continue + if node not in env: + new_node = bwd_graph.placeholder(node.name) + if node not in tangent_inputs: + env[node] = new_node + new_node.meta = node.meta + out_nodes.append(new_node) + + for node in tangent_inputs: + if node not in env: + new_node = bwd_graph.placeholder(node.name) + new_node.meta = node.meta + env[node] = new_node + + bwd_graph.output(out_nodes) + bwd_module = torch.fx.GraphModule(joint_module, bwd_graph) + return (fwd_module, bwd_module) diff --git a/pytorch_pfn_extras/testing.py b/pytorch_pfn_extras/testing.py new file mode 100644 index 000000000..73bca63dd --- /dev/null +++ b/pytorch_pfn_extras/testing.py @@ -0,0 +1,43 @@ +from typing import Any, Dict, List, Tuple, Union + +import torch + + +def _compare_states( + s1: Union[Dict[Any, Any], List[Any], Tuple[Any]], + s2: Union[Dict[Any, Any], List[Any], Tuple[Any]], + strict: bool = False, +) -> bool: + def allclose(a: torch.Tensor, b: torch.Tensor) -> bool: + if strict: + return bool((a == b).all()) + else: + return torch.allclose(a, b) + + if isinstance(s1, dict): + keys = list(s1.keys()) + assert isinstance(s2, dict) + if set(keys) != set(s2.keys()): + return False + elif isinstance(s1, (list, tuple)): + keys = list(range(len(s1))) + if len(s1) != len(s2): + return False + + all_equal = True + for k in keys: + if isinstance(s1[k], dict): + if not isinstance(s2[k], dict): + return False + all_equal = all_equal and _compare_states(s1[k], s2[k]) + elif 
isinstance(s1[k], (list, tuple)): + if not isinstance(s2[k], (list, tuple)): + return False + all_equal = all_equal and _compare_states(s1[k], s2[k]) + elif isinstance(s1[k], torch.Tensor): + all_equal = all_equal and allclose(s1[k], s2[k]) + else: + all_equal = all_equal and s1[k] == s2[k] + if not all_equal: + return all_equal + return all_equal diff --git a/stubs/torch/fx/__init__.pyi b/stubs/torch/fx/__init__.pyi new file mode 100644 index 000000000..3d82b2809 --- /dev/null +++ b/stubs/torch/fx/__init__.pyi @@ -0,0 +1,6 @@ +# flake8: noqa +from .graph import Graph as Graph +from .graph_module import GraphModule +from .node import Node as Node +from .proxy import GraphAppendingTracer as GraphAppendingTracer +from .proxy import Proxy as Proxy diff --git a/stubs/torch/fx/graph.pyi b/stubs/torch/fx/graph.pyi new file mode 100644 index 000000000..99e45703b --- /dev/null +++ b/stubs/torch/fx/graph.pyi @@ -0,0 +1,37 @@ +# flake8: noqa +from typing import Any, Callable, Dict, List, Optional, Tuple + +from .node import Node + +class Graph: + nodes: List[Node] + + def call_function( + self, + the_function: Callable[..., Any], + args: Optional[Any] = None, + kwargs: Optional[Any] = None, + type_expr: Optional[Any] = None, + ) -> Node: ... + def output(self, result: Any, type_expr: Optional[Any] = None) -> Node: ... + def placeholder( + self, + name: str, + type_expr: Optional[Any] = None, + default_value: Optional[Any] = None, + ) -> Node: ... + def node_copy( + self, node: Node, arg_transform: Callable[[Node], "Argument"] + ) -> Node: ... + def inserting_after(self, n: Optional[Node] = None): ... + def create_node( + self, + op: str, + target: Any, + args: Optional[Tuple[Any, ...]] = None, + kwargs: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: ... + def erase_node(self, to_erase: Node) -> None: ... + ... diff --git a/stubs/torch/fx/graph_module.pyi b/stubs/torch/fx/graph_module.pyi new file mode 100644 index 000000000..276ec7d25 --- /dev/null +++ b/stubs/torch/fx/graph_module.pyi @@ -0,0 +1,17 @@ +# flake8: noqa +from typing import Any, Dict + +import torch + +from .graph import Graph + +class GraphModule: + graph: Graph + + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ): ... + ... diff --git a/stubs/torch/fx/node.pyi b/stubs/torch/fx/node.pyi new file mode 100644 index 000000000..88af2a432 --- /dev/null +++ b/stubs/torch/fx/node.pyi @@ -0,0 +1,11 @@ +# flake8: noqa +from typing import Any, Callable, Dict, List, Optional, Tuple + +class Node: + op: str + name: str + target: Any + meta: Any + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + ... diff --git a/stubs/torch/fx/proxy.pyi b/stubs/torch/fx/proxy.pyi new file mode 100644 index 000000000..0c0dc1dea --- /dev/null +++ b/stubs/torch/fx/proxy.pyi @@ -0,0 +1,10 @@ +# flake8: noqa +from typing import Any + +from .graph import Graph, Node + +class GraphAppendingTracer: + def __init__(self, graph: Graph): ... + +class Proxy: + def __init__(self, node: Node, tracer: GraphAppendingTracer): ... 
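
The _compare_states helper introduced in pytorch_pfn_extras/testing.py above is shared by the dynamo and trainer tests that follow. As a minimal usage sketch (the two Linear modules are illustrative; the calls mirror how the tests in this series use the helper)::

    import torch
    from pytorch_pfn_extras import testing

    # Two modules loaded from the same state compare equal; strict=True
    # demands exact equality instead of torch.allclose for tensor entries.
    m1 = torch.nn.Linear(4, 4)
    m2 = torch.nn.Linear(4, 4)
    m2.load_state_dict(m1.state_dict())
    assert testing._compare_states(m1.state_dict(), m2.state_dict())
    assert testing._compare_states(m1.state_dict(), m2.state_dict(), strict=True)
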
diff --git a/tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py b/tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py new file mode 100644 index 000000000..cb2ff02f6 --- /dev/null +++ b/tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py @@ -0,0 +1,76 @@ +import sys + +import pytest +import pytorch_pfn_extras as ppe +import torch +from pytorch_pfn_extras import testing + + +class _DummyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x).sum() * torch.tensor(5.0) + + +@pytest.mark.skipif( + not ppe.requires("2.0.0") or sys.platform == "win32", + reason="torch.compile interface its only added in PyTorch>2.0 and linux", +) +def test_compile_with_optimizer(): + torch._dynamo.reset() + x = torch.randn(10, requires_grad=True) + torch_module = _DummyModule() + module_initial_state = torch_module.state_dict() + compiled_module = _DummyModule() + compiled_module.load_state_dict(module_initial_state) + + opt = torch.optim.SGD(torch_module.parameters(), lr=0.5, momentum=0.01) + y = torch_module(x) + y.backward() + opt.step() + + opt = torch.optim.SGD(compiled_module.parameters(), lr=0.5, momentum=0.01) + joint_module = ppe.compile(compiled_module, opt) + # This executes forward+backward+optimizer step + compiled_y = joint_module(x) + assert torch.allclose(y, compiled_y) + assert testing._compare_states( + torch_module.state_dict(), compiled_module.state_dict() + ) + # Run one more step and check that the weights now difer + compiled_y = joint_module(x) + assert not testing._compare_states( + torch_module.state_dict(), compiled_module.state_dict() + ) + + +@pytest.mark.skipif( + not ppe.requires("2.0.0") or sys.platform == "win32", + reason="torch.compile interface its only added in PyTorch>2.0 and linux", +) +def test_compile_without_optimizer(): + torch._dynamo.reset() + x = torch.randn(10, requires_grad=True) + torch_module = _DummyModule() + module_initial_state = torch_module.state_dict() + compiled_module = _DummyModule() + compiled_module.load_state_dict(module_initial_state) + + y = torch_module(x) + y.backward() + + joint_module = ppe.compile(compiled_module, None) + compiled_y = joint_module(x) + # Call backward so the dummy graph is executed and the gradients are set + # To all the tensors + compiled_y.backward() + assert torch.allclose(y, compiled_y) + assert torch.allclose( + torch_module.linear.weight.grad, compiled_module.linear.weight.grad + ) + assert torch.allclose( + torch_module.linear.bias.grad, compiled_module.linear.bias.grad + ) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py index cfcf4825b..5115690db 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py @@ -5,7 +5,7 @@ import pytest import pytorch_pfn_extras as ppe import torch -from pytorch_pfn_extras import engine, training +from pytorch_pfn_extras import engine, testing, training from pytorch_pfn_extras.training import triggers from torch import nn from torch.nn import functional as F @@ -360,41 +360,6 @@ def get_result_from_training_loop(): assert torch.equal(a, e) -def _compare_states(s1, s2, strict=False): - def allclose(a, b): - if strict: - return (a == b).all() - else: - return torch.allclose(a, b) - - if isinstance(s1, dict): - keys = s1.keys() - if set(keys) != set(s2.keys()): - return False - elif 
isinstance(s1, (list, tuple)): - keys = range(len(s1)) - if len(s1) != len(s2): - return False - - all_equal = True - for k in keys: - if isinstance(s1[k], dict): - if not isinstance(s2[k], dict): - return False - all_equal = all_equal and _compare_states(s1[k], s2[k]) - elif isinstance(s1[k], (list, tuple)): - if not isinstance(s2[k], (list, tuple)): - return False - all_equal = all_equal and _compare_states(s1[k], s2[k]) - elif isinstance(s1[k], torch.Tensor): - all_equal = all_equal and allclose(s1[k], s2[k]) - else: - all_equal = all_equal and s1[k] == s2[k] - if not all_equal: - return all_equal - return all_equal - - class TestTrainerState: def _get_trainer( self, @@ -442,11 +407,11 @@ def test_trainer_state(self, path): torch.manual_seed(0) new_trainer = self._get_trainer(10, path) new_trainer.run(data) - assert not _compare_states(state, new_trainer.state_dict()) + assert not testing._compare_states(state, new_trainer.state_dict()) new_trainer = self._get_trainer(20, path) new_trainer.load_state_dict(trainer.state_dict()) new_trainer.run(data) - assert _compare_states(state, new_trainer.state_dict()) + assert testing._compare_states(state, new_trainer.state_dict()) def test_trainer_autoload(self, path): trainer = self._get_trainer(20, path) @@ -471,7 +436,9 @@ def test_trainer_autoload(self, path): # This forces engine initialization new_trainer._setup_manager(len(data)) assert new_trainer.epoch == 20 - assert _compare_states(trainer.state_dict(), new_trainer.state_dict()) + assert testing._compare_states( + trainer.state_dict(), new_trainer.state_dict() + ) def test_trainer_autoload_training_results_consistency(self, path): snapshot_epoch = 10 @@ -504,7 +471,7 @@ def test_trainer_autoload_training_results_consistency(self, path): print(trainer.state_dict().keys()) trainer_state_dict = trainer.state_dict() new_trainer_state_dict = new_trainer.state_dict() - assert _compare_states( + assert testing._compare_states( trainer_state_dict["models"], new_trainer_state_dict["models"], strict=True, @@ -544,7 +511,7 @@ def test_trainer_autoload_training_results_consistency_with_gpu(self, path): print(trainer.state_dict().keys()) trainer_state_dict = trainer.state_dict() new_trainer_state_dict = new_trainer.state_dict() - assert _compare_states( + assert testing._compare_states( trainer_state_dict["models"], new_trainer_state_dict["models"], strict=True, @@ -610,7 +577,7 @@ def test_trainer_autoload_training_results_consistency_with_gradscaler( trainer_state_dict = trainer.state_dict() new_trainer_state_dict = new_trainer.state_dict() - assert _compare_states( + assert testing._compare_states( trainer_state_dict["models"], new_trainer_state_dict["models"], strict=True, From bf572b424c123877ca88a7183bba21c3c2245653 Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Tue, 27 Jun 2023 08:09:56 +0000 Subject: [PATCH 2/9] cleanup --- stubs/torch/fx/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stubs/torch/fx/__init__.pyi b/stubs/torch/fx/__init__.pyi index 3d82b2809..c2e500fc1 100644 --- a/stubs/torch/fx/__init__.pyi +++ b/stubs/torch/fx/__init__.pyi @@ -1,6 +1,6 @@ # flake8: noqa from .graph import Graph as Graph -from .graph_module import GraphModule +from .graph_module import GraphModule as GraphModule from .node import Node as Node from .proxy import GraphAppendingTracer as GraphAppendingTracer from .proxy import Proxy as Proxy From c89858b23d59a15d66a904feba8bab50cba57277 Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Tue, 27 Jun 2023 08:12:23 +0000 
Subject: [PATCH 3/9] cleanup --- pytorch_pfn_extras/_dynamo/_compile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_pfn_extras/_dynamo/_compile.py b/pytorch_pfn_extras/_dynamo/_compile.py index fbf59233d..26f180302 100644 --- a/pytorch_pfn_extras/_dynamo/_compile.py +++ b/pytorch_pfn_extras/_dynamo/_compile.py @@ -2,7 +2,6 @@ import torch import torch.fx -import torch.fx.GraphModule import torch.utils._pytree as pytree from functorch.compile import make_boxed_func from pytorch_pfn_extras._dynamo import _optimizer, _splitter From c51afb7c09f452a5204695a21d9890b49cbbd240 Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Tue, 27 Jun 2023 08:15:41 +0000 Subject: [PATCH 4/9] cleanup --- pytorch_pfn_extras/_dynamo/_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_pfn_extras/_dynamo/_optimizer.py b/pytorch_pfn_extras/_dynamo/_optimizer.py index 683059ae8..da435fbdb 100644 --- a/pytorch_pfn_extras/_dynamo/_optimizer.py +++ b/pytorch_pfn_extras/_dynamo/_optimizer.py @@ -4,7 +4,6 @@ import torch import torch.fx -import torch.fx.GraphModule # patch the torch.optim.SGD._init_group function to avoid the From d69e212d048315754f938dedc10c2571c1a14ae4 Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Tue, 27 Jun 2023 08:16:57 +0000 Subject: [PATCH 5/9] cleanup --- pytorch_pfn_extras/_dynamo/_splitter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_pfn_extras/_dynamo/_splitter.py b/pytorch_pfn_extras/_dynamo/_splitter.py index c4a97b5b5..a4f548d03 100644 --- a/pytorch_pfn_extras/_dynamo/_splitter.py +++ b/pytorch_pfn_extras/_dynamo/_splitter.py @@ -2,7 +2,6 @@ import torch import torch.fx -import torch.fx.GraphModule import torch.utils._pytree as pytree from torch._functorch.partitioners import _is_primal, _is_tangent From ccb99dbb9444c4843ae59f98343c8907cf6fd79d Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Wed, 12 Jul 2023 02:29:51 +0000 Subject: [PATCH 6/9] Review fixes I --- pytorch_pfn_extras/_dynamo/_compile.py | 2 +- pytorch_pfn_extras/_dynamo/_optimizer.py | 37 +++++++++++++++--------- pytorch_pfn_extras/_dynamo/_splitter.py | 15 ++++++++++ 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/pytorch_pfn_extras/_dynamo/_compile.py b/pytorch_pfn_extras/_dynamo/_compile.py index 26f180302..50e095d66 100644 --- a/pytorch_pfn_extras/_dynamo/_compile.py +++ b/pytorch_pfn_extras/_dynamo/_compile.py @@ -15,7 +15,7 @@ def _dummy_bwd_backend( ) -> Any: # The bwd pass is dummy, so we just return the inputs as they are def run_graph(*args, **kwargs): # type: ignore[no-untyped-def] - return args[:-1] + return gm(*args, **kwargs) return make_boxed_func(run_graph) diff --git a/pytorch_pfn_extras/_dynamo/_optimizer.py b/pytorch_pfn_extras/_dynamo/_optimizer.py index da435fbdb..a1114ee30 100644 --- a/pytorch_pfn_extras/_dynamo/_optimizer.py +++ b/pytorch_pfn_extras/_dynamo/_optimizer.py @@ -60,7 +60,6 @@ def _initialize_optimizer( for i, param in enumerate(p_group["params"]): dummy = torch.zeros_like(param) dummy.grad = torch.zeros_like(param) - # param_to_dummy[dummy] = param param_groups[-1].append(param) names[dummy] = names[param] param_to_dummy[param] = dummy @@ -75,7 +74,6 @@ def _initialize_optimizer( # Reset the optimizer original parameters for i, p_group in enumerate(optimizer.param_groups): for j, _ in enumerate(p_group["params"]): - # param = param_to_dummy[dummy] param = param_groups[i][j] p_group["params"][j] = param dummy = param_to_dummy[param] @@ -105,16 +103,16 @@ def _get_shape_inference_inputs_and_metadata( with 
torch.fx.experimental.proxy_tensor.maybe_disable_fake_tensor_mode(): # type: ignore[attr-defined,no-untyped-call] for p_group in optimizer.param_groups: - for i in range(len(p_group["params"])): - param_tensor = p_group["params"][i] + for param in p_group["params"]: + param_tensor = param params_meta[param_tensor] = _create_meta(param_tensor) inputs.append(param_tensor) for p_group in optimizer.param_groups: - for i in range(len(p_group["params"])): - param_tensor = p_group["params"][i] - for p_n in optimizer.state[param_tensor]: - state_tensor = optimizer.state[param_tensor][p_n] + for param in p_group["params"]: + param_tensor = param + optimizer_state: Dict[Any, Any] = optimizer.state[param_tensor] + for state_tensor in optimizer_state.values(): if state_tensor is not None: params_meta[state_tensor] = _create_meta(state_tensor) inputs.append(state_tensor) @@ -134,10 +132,14 @@ def _create_placeholders_for_parameters_and_state( params_to_proxy = {} for p_group in optimizer.param_groups: - for i in range(len(p_group["params"])): + for i, param in enumerate(p_group["params"]): # Find param in list - # May need to replace `.` with `@` - param_tensor = p_group["params"][i] + param_tensor = param + # Dynamo uses the parameters names to create a python function + # if special symbols in the parameter names are not replaced + # the definition will be ill-formed and look like: + # def forward(self, linear.weight, ...) + # causing a syntax error p_name = names[param_tensor].replace(".", "__dot__") placeholders.append(opt_graph.placeholder(p_name)) placeholders[-1].meta = params_meta[param_tensor] @@ -146,10 +148,10 @@ def _create_placeholders_for_parameters_and_state( params_to_proxy[param_tensor] = proxy for p_group in optimizer.param_groups: - for i in range(len(p_group["params"])): + for i, param in enumerate(p_group["params"]): # Find param in list # May need to replace `.` with `@` - param_tensor = p_group["params"][i] + param_tensor = param p_name = names[param_tensor].replace(".", "__dot__") proxy = params_to_proxy[param_tensor] for p in optimizer.state[proxy]: # type: ignore[index] @@ -166,6 +168,8 @@ def _create_placeholders_for_parameters_and_state( def _is_inplace(node: torch.fx.Node, arg: torch.fx.Node) -> bool: + # There is no easy way to detect inplace ops in torch, but they are + # defined as tensor methods with a "_" suffix. ("add_", "mul_") return ( node.op == "call_method" # type: ignore[return-value] and node.args[0] == arg @@ -191,6 +195,11 @@ def _get_last_inplace_update( def _adjust_inplace_ops( opt_graph: torch.fx.Graph, last_inplace: Dict[torch.fx.Node, torch.fx.Node] ) -> None: + # This is to avoid cases such as: + # b = a.add_() # a is modified in place and returns the value + # # a=b with dynamo, but it is possible to be a!=b in other backends + # c = torch.exp(a) # if in-place is supported this is correct, but if not we want to be torch.exp(b) + # This behavior is seen in momentum update for SGD for node in opt_graph.nodes: args = list(node.args) modified = False @@ -262,7 +271,7 @@ def _get_last_update( # b = a.inplace_op() # c = b.inplace_op() # d = a + x # Here we use a, but since the value is updated it is equivalent - # # to use c + # # to use c, and c should be used to ensure correctness. 
_adjust_inplace_ops(opt_graph, last_inplace) opt_graph.output(outputs) diff --git a/pytorch_pfn_extras/_dynamo/_splitter.py b/pytorch_pfn_extras/_dynamo/_splitter.py index a4f548d03..8b4b80c6f 100644 --- a/pytorch_pfn_extras/_dynamo/_splitter.py +++ b/pytorch_pfn_extras/_dynamo/_splitter.py @@ -22,6 +22,21 @@ def _no_partition( *, num_fwd_outputs: int, ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + """The calculation graph, traced in an end-to-end manner, + of the forward-backward computation is divided into the forward graph and + the backward graph. The forward graph includes the whole calculation process + up until the computation of gradients. The backward graph is split as + an identity function concerning the gradients. + + Args: + joint_module: The end-to-end calculation graph. + _joint_inputs: Example inputs. + num_fwd_outputs: Number of forward outputs. + Returns: + The two returned GraphModule objects must adhere to the following interface: + - forward_graph_module: ((*primal_inputs) -> any_outputs) + - backward_graph_module: ((subset_of_forward_outputs, *tangent_inputs) -> parameter_grads) + """ primal_inputs: List[torch.fx.Node] = list( filter(_is_primal, joint_module.graph.nodes) ) From ee966201d1404f8b0f397c2867a7aa01b48b5629 Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Wed, 12 Jul 2023 02:42:27 +0000 Subject: [PATCH 7/9] Review fixes II --- pytorch_pfn_extras/_dynamo/_compile.py | 6 +++- .../dynamo_tests/test_compile.py | 36 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pytorch_pfn_extras/_dynamo/_compile.py b/pytorch_pfn_extras/_dynamo/_compile.py index 50e095d66..a0767aa77 100644 --- a/pytorch_pfn_extras/_dynamo/_compile.py +++ b/pytorch_pfn_extras/_dynamo/_compile.py @@ -203,7 +203,7 @@ def _model_opt_func(*args, **kwargs): # type: ignore[no-untyped-def] partition_fn=partitioner._no_partition, decompositions=core_aten_decompositions(), ) - module_opt = torch.compile(module, backend=aot_backend) # type: ignore[attr-defined] + module_opt = torch.compile(module, fullgraph=True, backend=aot_backend) # type: ignore[attr-defined] return cast(Callable[..., Any], module_opt) # type: ignore[redundant-cast] @@ -219,6 +219,10 @@ def compile( and a list of ``torch.Tensor`` and return a ``Callable`` as specified by https://pytorch.org/docs/2.0/dynamo/custom-backends.html#custom-backends + .. note:: + Modules that are split in multiple graphs are not supported. ``torch.compiled`` + is called with the ``fullgraph=True`` argument. 
+ Args: module: torch.nn.Module to be compiled diff --git a/tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py b/tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py index cb2ff02f6..aea5624fe 100644 --- a/tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py +++ b/tests/pytorch_pfn_extras_tests/dynamo_tests/test_compile.py @@ -15,6 +15,18 @@ def forward(self, x): return self.linear(x).sum() * torch.tensor(5.0) +class _DummyModuleSplit(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x): + if x.sum() < 0: + return self.linear(x).sum() * torch.tensor(5.0) + else: + return self.linear(x).sum() * torch.tensor(1.0) + + @pytest.mark.skipif( not ppe.requires("2.0.0") or sys.platform == "win32", reason="torch.compile interface its only added in PyTorch>2.0 and linux", @@ -74,3 +86,27 @@ def test_compile_without_optimizer(): assert torch.allclose( torch_module.linear.bias.grad, compiled_module.linear.bias.grad ) + + +@pytest.mark.skipif( + not ppe.requires("2.0.0") or sys.platform == "win32", + reason="torch.compile interface its only added in PyTorch>2.0 and linux", +) +def test_compile_with_optimizer_and_split_graph(): + torch._dynamo.reset() + x = torch.randn(10, requires_grad=True) + torch_module = _DummyModuleSplit() + module_initial_state = torch_module.state_dict() + compiled_module = _DummyModuleSplit() + compiled_module.load_state_dict(module_initial_state) + + opt = torch.optim.SGD(torch_module.parameters(), lr=0.5, momentum=0.01) + y = torch_module(x) + y.backward() + opt.step() + + opt = torch.optim.SGD(compiled_module.parameters(), lr=0.5, momentum=0.01) + joint_module = ppe.compile(compiled_module, opt) + # This executes forward+backward+optimizer step + with pytest.raises(torch._dynamo.exc.Unsupported): + joint_module(x) From 395287ea56b0ca54762c14d015da30f2d8076936 Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Wed, 12 Jul 2023 02:56:37 +0000 Subject: [PATCH 8/9] use root module --- pytorch_pfn_extras/_dynamo/_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pfn_extras/_dynamo/_optimizer.py b/pytorch_pfn_extras/_dynamo/_optimizer.py index a1114ee30..e8cc8240b 100644 --- a/pytorch_pfn_extras/_dynamo/_optimizer.py +++ b/pytorch_pfn_extras/_dynamo/_optimizer.py @@ -276,7 +276,7 @@ def _get_last_update( opt_graph.output(outputs) with torch.fx.experimental.proxy_tensor.maybe_disable_fake_tensor_mode(): # type: ignore[attr-defined,no-untyped-call] - opt_module = torch.fx.GraphModule(module, opt_graph) + opt_module = torch.fx.GraphModule(torch.nn.Module(), opt_graph) torch.fx.passes.shape_prop.ShapeProp(opt_module).propagate(*inputs) # type: ignore[attr-defined, no-untyped-call] return opt_graph, outputs From f687c729b39efbe3258c21007267b504ed742ac9 Mon Sep 17 00:00:00 2001 From: Emilio Castillo Date: Wed, 12 Jul 2023 02:59:07 +0000 Subject: [PATCH 9/9] typing --- pytorch_pfn_extras/_dynamo/_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pfn_extras/_dynamo/_compile.py b/pytorch_pfn_extras/_dynamo/_compile.py index a0767aa77..c7a9f2daa 100644 --- a/pytorch_pfn_extras/_dynamo/_compile.py +++ b/pytorch_pfn_extras/_dynamo/_compile.py @@ -15,7 +15,7 @@ def _dummy_bwd_backend( ) -> Any: # The bwd pass is dummy, so we just return the inputs as they are def run_graph(*args, **kwargs): # type: ignore[no-untyped-def] - return gm(*args, **kwargs) + return gm(*args, **kwargs) # type: ignore[operator] return 
make_boxed_func(run_graph)
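
Taken together, the series exposes the workflow below. This is a sketch based on the tests and the compile docstring added above: DummyModule mirrors _DummyModule from the tests, and my_backend is a hypothetical placeholder that only follows the documented backend signature::

    import pytorch_pfn_extras as ppe
    import torch


    class DummyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x):
            # Scalar output, as in the tests, so the joint graph can seed the
            # embedded backward pass with a constant tangent.
            return self.linear(x).sum() * torch.tensor(5.0)


    module = DummyModule()
    optimizer = torch.optim.SGD(module.parameters(), lr=0.5, momentum=0.01)

    # Trace module and optimizer into a single joint graph; one call runs
    # forward + backward + optimizer step and updates the module parameters.
    joint_module = ppe.compile(module, optimizer)
    x = torch.randn(10, requires_grad=True)
    loss = joint_module(x)

    # A custom backend receives the combined torch.fx.GraphModule and the
    # example inputs and must return a callable; when one is supplied, the
    # wrapper copies the updated parameters back after each call instead of
    # relying on in-place ops.
    def my_backend(gm, example_inputs):
        return gm.forward

    joint_module = ppe.compile(module, optimizer, backend=my_backend)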