[DORT] Use new FX-to-ONNX exporter (#16450)
The ONNX exporter used by DORT has been moved into PyTorch as a formal
feature. We therefore switch to consuming the exporter from PyTorch
instead of maintaining two duplicate implementations.
wschin authored Jul 4, 2023
1 parent d540c7d commit a0a5f57
Showing 7 changed files with 198 additions and 250 deletions.
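For orientation, here is a minimal usage sketch of DORT after this change, adapted from the registration comment and the test further down in this diff. The import path and the toy function are illustrative assumptions, not part of the commit itself.

import torch
import torch._dynamo

# Assumed import path for the registered DORT backend; it mirrors the package
# layout under orttraining/orttraining/python/training/torchdynamo.
from onnxruntime.training.torchdynamo.register_backend import aot_ort

def elementwise_model(tensor_x: torch.Tensor):
    # Small aten-only function, similar in spirit to the elementwise test model below.
    return (tensor_x.sigmoid() * tensor_x + 1.5).sigmoid()

# Let TorchDynamo capture FX graphs and hand them to ORT via the aot_ort backend.
compiled_fn = torch._dynamo.optimize(aot_ort)(elementwise_model)
result = compiled_fn(torch.rand(2, 2, dtype=torch.float, requires_grad=True))
result.sum().backward()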
11 changes: 0 additions & 11 deletions orttraining/orttraining/python/training/torchdynamo/__init__.py
@@ -2,14 +2,3 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from typing import Set

# set of custom ops supported in DORT.
# It can contain, for example, `aten::custom_add`.
custom_symbols: Set[str] = set()


# register custom ops in DORT
def register_custom_op_in_dort(custom_op_name: str):
custom_symbols.add(custom_op_name)
253 changes: 52 additions & 201 deletions orttraining/orttraining/python/training/torchdynamo/ort_backend.py
@@ -5,7 +5,7 @@

import dataclasses
import logging
from typing import Any, Callable, Dict, Mapping, Set, Tuple, Union
from typing import Any, Dict, Mapping, Tuple, Union

import numpy as np
import onnx
@@ -16,19 +16,47 @@
import torch.fx
import torch.jit
import torch.onnx

# TODO(wschin,justinchuby): Since the internal APIs are not stable, please
# contact us if you hit errors.
import torch.onnx._internal
import torch.onnx._internal.diagnostics
import torch.onnx._internal.exporter
import torch.onnx._internal.fx.decomposition_table
import torch.onnx._internal.fx.passes
import torch.onnx._onnx_supported_ops
from torch._decomp import decomposition_table
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.onnx._globals import GLOBALS as ONNX_GLOBALS

import onnxruntime # type: ignore
from onnxruntime.capi import _pybind_state as ORTC

from . import custom_symbols
# DEFAULT_ONNX_EXPORTER_OPTIONS contains shared information between the exporter and DORT.
# For example, they should use the same decomposition table to maintain the same set
# of operators when
# 1. capturing the FX graph in torch.compile, and
# 2. calling the exporter's API to convert `torch.fx.GraphModule` to an ONNX model.
DEFAULT_ONNX_EXPORTER_OPTIONS = torch.onnx._internal.exporter.ResolvedExportOptions(
torch.onnx._internal.exporter.ExportOptions()
)

# TODO(wechi): This line must generate a result identical to the call of
# _create_onnx_supports_op_overload_table(...) inside
# create_onnx_friendly_decomposition_table(...) in
# torch/onnx/_internal/fx/decomposition_table.py.
_SUPPORT_DICT = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
DEFAULT_ONNX_EXPORTER_OPTIONS.onnx_registry
) # type: ignore

_EXTRA_SUPPORT_DICT: Dict[str, Any] = {
"getattr": None,
"_operator.getitem": None,
}

DORT_DECOMPOSITION_TABLE = DEFAULT_ONNX_EXPORTER_OPTIONS.decomposition_table

_NP_DTYPE = {
torch.float16: np.float16,
@@ -68,7 +96,7 @@ def _get_ort_device_type(device_type: str):
return ORTC.OrtDevice.cpu() # type: ignore
# ort pytorch device is mapped to NPU OrtDevice type
if device_type == "ort":
return ORTC.OrtDevice.npu()
return ORTC.OrtDevice.npu() # type: ignore
raise ValueError("Unsupported device type: " + device_type)


@@ -78,130 +106,6 @@ def _get_ort_device_type(device_type: str):
# logger.setLevel(logging.INFO)


def _get_onnx_supported_table() -> Set[str]:
# TODO(wechi): this entire function should be replaced by a formal exporter API.

onnx_supported_ops: Set[str] = set()
for aten_op_name, schema in torch.onnx._onnx_supported_ops.all_symbolics_schemas().items():
# TODO(wechi): aten_op_name could be prim::add in addition to aten::add.
# We should build another dictionary for storing support table for prim ops.
# Currently, we only consider aten ops as before.
if aten_op_name not in custom_symbols and not aten_op_name.startswith("aten::"):
logger.info(
"Skip %s in support table because it's not in aten domain or supported custom ops %s",
aten_op_name,
custom_symbols,
)
continue
short_op_name = aten_op_name.split("::")[1]
if aten_op_name.startswith("aten::") and not hasattr(torch.ops.aten, short_op_name): # type: ignore
# Some aten ops are not in torch.ops.aten. Those are excluded until we
# figure out why.
logger.info("Skip %s in support table because it's not found in torch.ops.aten.", aten_op_name)
continue
# aten_op_name is aten symbol's name; e.g., "sum" for aten::sum.
# opsets_string is the ONNX opsets that can express info[0]; e.g., "15 16 17"
# indicates that opset 15, opset 16, and opset 17 can all express aten_op_name.
if ONNX_GLOBALS.export_onnx_opset_version in schema.opsets:
logger.info("Add %s to support table.", aten_op_name)
onnx_supported_ops.add(aten_op_name)
return onnx_supported_ops


def _get_support_dictionaries_and_decomposition_tables() -> (
Tuple[
Dict[torch._ops.OpOverload, Any],
Dict[str, Any],
Dict[torch._ops.OpOverload, Callable],
Dict[torch._ops.OpOverload, Callable],
]
):
# The keys of this dictionary are OpOverload's which can be
# exported by ONNX exporter. Type of key is torch._ops.OpOverload.
# For example, if torch.ops.aten.add.default is a key in support_dict,
# all torch.fx.Node's with torch.ops.aten.add.default as target will
# be selected by CapabilityBasedPartitioner and sent to ORT for
# computation.
# We choose torch._ops.OpOverload as the key because
# 1. torch._ops.OpOverload uniquely identifies an op. We don't want
# to use OpOverloadPacket because it contains overloads of the same op.
# This allows us to select supported ops at the finest grain.
# 2. torch._ops.OpOverload is what we get from torch.fx.Node.target. Getting
# qualified name using _get_qualified_name is not needed.
support_dictionary: Dict[torch._ops.OpOverload, Any] = {}
for aten_op_name in _get_onnx_supported_table():
if aten_op_name.startswith("aten::"):
short_op_name = aten_op_name.split("aten::")[1]
op_overload_packet = getattr(torch.ops.aten, short_op_name) # type: ignore
# Due to the lack of overload name in exporting function's name, assume
# each exporting function (e.g., torch.onnx.symbolic_opset9.add) support
# all overloads (e.g., in torch.ops.aten.add).
# Thus, we register all torch._ops.OpOverload's for the same exporting function.
# Please manually exclude torch._ops.OpOverload if exporter fails.
for overload in op_overload_packet.overloads():
op_overload = getattr(op_overload_packet, overload)
support_dictionary[op_overload] = None

elif aten_op_name in custom_symbols:
op_namespace = aten_op_name.split("::")[0]
short_op_name = aten_op_name.split("::")[1]
# Get the custom ops from: torch.ops.custom_namespace
custom_op_namespace = getattr(torch.ops, op_namespace)
op_overload_packet = getattr(custom_op_namespace, short_op_name) # type: ignore
for overload in op_overload_packet.overloads():
op_overload = getattr(op_overload_packet, overload)
support_dictionary[op_overload] = None

# No decomposition table. OpOverload in this table shouldn't be found
# in aten2aten_decomposition_table.
# The symbols in this set will be replaced by torch.ops.aten.to.dtype in replace_to_copy_with_to because
# only aten.to has ONNX exporter.
# If the replacement fails, ONNX exporter will fail because only aten.to has ONNX exporter.
# TODO(wechi): For a long-term solution, we need to ensure every op used in op decomposition has
# an exporter.
no_decomposition_table: Set[torch._ops.OpOverload] = {
torch.ops.aten._to_copy.default, # type: ignore
torch.ops.aten._to_copy.out, # type: ignore
}

# decomposition_table currently contains both aten2aten and aten2prim decompositions
# This is a hack to separate them, as ONNX only recognizes aten symbols.
aten2aten_decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
aten2prim_decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}

for op_overload, decomp_fn in decomposition_table.items():
if op_overload in support_dictionary:
# ONNX can express this op, no need to decompose.
continue
if "torch._refs" in decomp_fn.__module__:
aten2prim_decomposition_table[op_overload] = decomp_fn
else:
if op_overload in no_decomposition_table:
continue
# Assume ONNX can express ops after decomposition.
# If no, exporter will fail and the user need to
# remove this decomposition rule.
aten2aten_decomposition_table[op_overload] = decomp_fn

# Some torch.fx.Node's are converted to ONNX-compatible ops
# by torch.jit.script. They don't have direct ONNX exporting
# functions but still runnable in ORT.
extra_support_dictionary: Dict[str, Any] = {
"getattr": None,
"_operator.getitem": None,
}

return support_dictionary, extra_support_dictionary, aten2aten_decomposition_table, aten2prim_decomposition_table


(
_SUPPORT_DICT,
_EXTRA_SUPPORT_DICT,
ATEN2ATEN_DECOMP,
ATEN2PRIM_DECOMP,
) = _get_support_dictionaries_and_decomposition_tables()


class OrtOperatorSupport(OperatorSupport):
"""
Operator support for the ONNXRuntime backend. It makes a two-level support decision.
@@ -234,31 +138,6 @@ def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: tor
return False


def _jit_graph_to_onnx_model(graph, operator_export_type):
r"""
This function exports torch::jit::Graph object
to serialized ONNX ModelProto.
It only keeps the essential parts for IR graph conversions.
It also does not interact with actual PyTorch modules nor
PyTorch tensor inputs.
"""
graph = torch.onnx.utils._optimize_graph(graph, operator_export_type, params_dict={})
proto, _, _, _ = graph._export_onnx( # type: ignore
{},
ONNX_GLOBALS.export_onnx_opset_version,
{},
False,
operator_export_type,
False,
False,
{},
True,
"",
{},
)
return proto


def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
"""
In torch.fx.Graph, placeholder is a special assignment node. If it's not
Expand Down Expand Up @@ -316,42 +195,6 @@ def _replace_to_copy_with_to(fx_module: torch.fx.GraphModule) -> None:
fx_module.recompile()


def _fx_to_torchscript(
fx_module: torch.fx.GraphModule,
) -> torch.jit.ScriptModule:
"""Convert torch.fx.Graph to torch.jit.ScriptModule."""

for node in fx_module.graph.nodes:
new_kwargs = {}
for k, v in node.kwargs.items():
if isinstance(v, torch.device):
v = v.type # noqa: PLW2901
new_kwargs[k] = v
node.kwargs = new_kwargs
for node in fx_module.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
fx_module.graph.lint()
fx_module.recompile()
return torch.jit.script(fx_module) # type: ignore


def _decorate_script_module(script_module: torch.jit.ScriptModule, expected_inputs, expected_outputs):
for i, input_value in enumerate(script_module.graph.inputs()): # type: ignore
if input_value.debugName() == "self":
script_module.graph.eraseInput(i) # type: ignore
break
for input_value, expected_input in zip(script_module.graph.inputs(), expected_inputs): # type: ignore
input_value.setType(torch._C.TensorType.create_from_tensor(expected_input))
for output_value, expected_output in zip(script_module.graph.outputs(), expected_outputs): # type: ignore
output_value.setType(torch._C.TensorType.create_from_tensor(expected_output))


def _create_onnx_proto(script_module):
onnx_proto = _jit_graph_to_onnx_model(script_module.graph, torch.onnx.OperatorExportTypes.ONNX)
return onnx_proto


def _create_onnx_model(onnx_proto):
return onnx.ModelProto.FromString(onnx_proto)

@@ -554,17 +397,25 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar
# Re-raise the FakeTensorProp failure because it is not currently handled.
raise
self._ort_execution_info.example_outputs[graph_module] = prim_outputs
# Compile the torch.fx.GraphModule into a torch.jit.ScriptModule.
script_module = _fx_to_torchscript(graph_module)
# Post-processing step to add expected input and output type information
# to the graph in torch.jit.ScriptModule. Expected inputs is "args" and "kwargs"
# while expected outputs is "prim_outputs".
if isinstance(prim_outputs, tuple):
_decorate_script_module(script_module, args, prim_outputs)
else:
_decorate_script_module(script_module, args, (prim_outputs,))
# Generate ONNX ModelProto from torch._C.Graph.
onnx_proto = _create_onnx_proto(script_module)

from torch.onnx._internal.fx import fx_onnx_interpreter

# Create the object that iterates through the nodes in the graph one by one
# and calls the corresponding ONNX exporter for each node.
fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
diagnostic_context=DEFAULT_ONNX_EXPORTER_OPTIONS.diagnostic_context
)
# Start the per-node exporting process. It's conceptually a for loop
# scanning through the nodes in the graph.
exported = fx_interpreter.run(
fx_graph_module=graph_module,
onnxfunction_dispatcher=DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher,
op_level_debug=DEFAULT_ONNX_EXPORTER_OPTIONS.op_level_debug,
)
# Convert the exported result to ONNX ModelProto.
onnx_proto = exported.to_model_proto(
opset_version=DEFAULT_ONNX_EXPORTER_OPTIONS.opset_version
).SerializeToString()

# Initialize an ORT session to execute this ONNX model.
# TorchDynamo assumes all inputs/outputs are on the same device,
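For readers skimming the diff, the following is a condensed, self-contained sketch of the new FX-to-ONNX path that `_ort_acclerated_call` now follows. It reuses only names that appear in the hunk above (ResolvedExportOptions, ExportOptions, FxOnnxInterpreter, to_model_proto); the helper name itself is hypothetical, and, as the TODO near the imports warns, these torch.onnx internals are not stable.

import torch
import torch.fx
import torch.onnx._internal.exporter


def export_fx_to_onnx_sketch(graph_module: torch.fx.GraphModule) -> bytes:
    # graph_module is assumed to be an aten-level graph produced by
    # TorchDynamo/AOTAutograd, as in DORT.
    options = torch.onnx._internal.exporter.ResolvedExportOptions(
        torch.onnx._internal.exporter.ExportOptions()
    )
    from torch.onnx._internal.fx import fx_onnx_interpreter

    # Walk the FX nodes one by one and dispatch each to its ONNX exporter function.
    interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
        diagnostic_context=options.diagnostic_context
    )
    exported = interpreter.run(
        fx_graph_module=graph_module,
        onnxfunction_dispatcher=options.onnxfunction_dispatcher,
        op_level_debug=options.op_level_debug,
    )
    # Serialize to an ONNX ModelProto that ORT can consume.
    return exported.to_model_proto(opset_version=options.opset_version).SerializeToString()

The returned bytes can then be passed to onnxruntime.InferenceSession, which is roughly what the backend does next when it initializes the ORT session mentioned above.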
@@ -6,7 +6,7 @@
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd

from .ort_backend import ATEN2ATEN_DECOMP, OrtBackend
from .ort_backend import DORT_DECOMPOSITION_TABLE, OrtBackend

# This should be the underlying compiler for ALL graphs if
# the user uses ORT to accelerate PyTorch via Dynamo.
@@ -28,8 +28,11 @@
# compiled_model = torch._dynamo.optimize(aot_ort)(model)
# result = compiled_model(torch.rand(2, 2, dtype=torch.float)
# result.sum().backward()

aot_ort = aot_autograd(
fw_compiler=DEFAULT_BACKEND, partition_fn=min_cut_rematerialization_partition, decompositions=ATEN2ATEN_DECOMP
fw_compiler=DEFAULT_BACKEND,
partition_fn=min_cut_rematerialization_partition,
decompositions=DORT_DECOMPOSITION_TABLE,
)

# Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called).
4 changes: 2 additions & 2 deletions orttraining/orttraining/test/python/orttraining_test_dort.py
@@ -24,11 +24,11 @@ def test_elementwise_model(self):
def run_elementwise_model():
# A function to test DORT.
def elementwise_model(tensor_x: torch.Tensor):
tensor_w = tensor_x.relu()
tensor_w = tensor_x.sigmoid()
tensor_y = tensor_w * tensor_w + 1.5
tensor_z = tensor_y + tensor_x
tensor_p = tensor_z * tensor_x
tensor_q = tensor_p.relu()
tensor_q = tensor_p.sigmoid()
return tensor_q

@torch._dynamo.optimize(aot_ort)