diff --git a/orttraining/orttraining/python/training/torchdynamo/__init__.py b/orttraining/orttraining/python/training/torchdynamo/__init__.py index 1f87244430..862c45ce31 100644 --- a/orttraining/orttraining/python/training/torchdynamo/__init__.py +++ b/orttraining/orttraining/python/training/torchdynamo/__init__.py @@ -2,14 +2,3 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- - -from typing import Set - -# set of custom ops supported in DORT. -# It can contain, for example, `aten::custom_add`. -custom_symbols: Set[str] = set() - - -# register custom ops in DORT -def register_custom_op_in_dort(custom_op_name: str): - custom_symbols.add(custom_op_name) diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py index 377871d26d..541f159d5e 100644 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py @@ -5,7 +5,7 @@ import dataclasses import logging -from typing import Any, Callable, Dict, Mapping, Set, Tuple, Union +from typing import Any, Dict, Mapping, Tuple, Union import numpy as np import onnx @@ -16,19 +16,47 @@ import torch.fx import torch.jit import torch.onnx + +# TODO(wschin,justinchuby): Since the internal APIs are not stable, please +# contact us if you hit errors. +import torch.onnx._internal +import torch.onnx._internal.diagnostics +import torch.onnx._internal.exporter +import torch.onnx._internal.fx.decomposition_table +import torch.onnx._internal.fx.passes import torch.onnx._onnx_supported_ops -from torch._decomp import decomposition_table from torch._subclasses.fake_tensor import FakeTensor from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch.onnx._globals import GLOBALS as ONNX_GLOBALS import onnxruntime # type: ignore from onnxruntime.capi import _pybind_state as ORTC -from . import custom_symbols +# DEFAULT_ONNX_EXPORTER_OPTIONS contains shared information between exporter and DORT. +# For example, they should use the same decomposition table to maintain the same set +# operators when +# 1. capturing FX graph in torch.compile +# 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model. +DEFAULT_ONNX_EXPORTER_OPTIONS = torch.onnx._internal.exporter.ResolvedExportOptions( + torch.onnx._internal.exporter.ExportOptions() +) + +# TODO(wechi): This line must generate result identical to the call of +# _create_onnx_supports_op_overload_table(...) inside +# create_onnx_friendly_decomposition_table(...) in +# torch/onnx/_internal/fx/decomposition_table.py. +_SUPPORT_DICT = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( + DEFAULT_ONNX_EXPORTER_OPTIONS.onnx_registry +) # type: ignore + +_EXTRA_SUPPORT_DICT: Dict[str, Any] = { + "getattr": None, + "_operator.getitem": None, +} + +DORT_DECOMPOSITION_TABLE = DEFAULT_ONNX_EXPORTER_OPTIONS.decomposition_table _NP_DTYPE = { torch.float16: np.float16, @@ -68,7 +96,7 @@ def _get_ort_device_type(device_type: str): return ORTC.OrtDevice.cpu() # type: ignore # ort pytorch device is mapped to NPU OrtDevice type if device_type == "ort": - return ORTC.OrtDevice.npu() + return ORTC.OrtDevice.npu() # type: ignore raise ValueError("Unsupported device type: " + device_type) @@ -78,130 +106,6 @@ def _get_ort_device_type(device_type: str): # logger.setLevel(logging.INFO) -def _get_onnx_supported_table() -> Set[str]: - # TODO(wechi): this entire function should be replaced by a formal a exporter API. - - onnx_supported_ops: Set[str] = set() - for aten_op_name, schema in torch.onnx._onnx_supported_ops.all_symbolics_schemas().items(): - # TODO(wechi): aten_op_name could be prim::add in addition to aten::add. - # We should build another dictionary for storing support table for prim ops. - # Currently, we only consider aten ops as before. - if aten_op_name not in custom_symbols and not aten_op_name.startswith("aten::"): - logger.info( - "Skip %s in support table because it's not in aten domain or supported custom ops %s", - aten_op_name, - custom_symbols, - ) - continue - short_op_name = aten_op_name.split("::")[1] - if aten_op_name.startswith("aten::") and not hasattr(torch.ops.aten, short_op_name): # type: ignore - # Some aten ops are not in torch.ops.aten. Those are excluded until we - # figure out why. - logger.info("Skip %s in support table because it's not found in torch.ops.aten.", aten_op_name) - continue - # aten_op_name is aten symbol's name; e.g., "sum" for aten::sum. - # opsets_string is the ONNX opsets that can express info[0]; e.g., "15 16 17" - # indicates that opset 15, opset 16, and opset 17 can all express aten_op_name. - if ONNX_GLOBALS.export_onnx_opset_version in schema.opsets: - logger.info("Add %s to support table.", aten_op_name) - onnx_supported_ops.add(aten_op_name) - return onnx_supported_ops - - -def _get_support_dictionaries_and_decomposition_tables() -> ( - Tuple[ - Dict[torch._ops.OpOverload, Any], - Dict[str, Any], - Dict[torch._ops.OpOverload, Callable], - Dict[torch._ops.OpOverload, Callable], - ] -): - # The keys of this dictionary are OpOverload's which can be - # exported by ONNX exporter. Type of key is torch._ops.OpOverload. - # For example, if torch.ops.aten.add.default is a key in support_dict, - # all torch.fx.Node's with torch.ops.aten.add.default as target will - # be selected by CapabilityBasedPartitioner and sent to ORT for - # computation. - # We choose torch._ops.OpOverload as the key because - # 1. torch._ops.OpOverload uniquely identifies an op. We don't want - # to use OpOverloadPacket because it contains overloads of the same op. - # This allows us to select supported ops at the finest grain. - # 2. torch._ops.OpOverload is what we get from torch.fx.Node.target. Getting - # qualified name using _get_qualified_name is not needed. - support_dictionary: Dict[torch._ops.OpOverload, Any] = {} - for aten_op_name in _get_onnx_supported_table(): - if aten_op_name.startswith("aten::"): - short_op_name = aten_op_name.split("aten::")[1] - op_overload_packet = getattr(torch.ops.aten, short_op_name) # type: ignore - # Due to the lack of overload name in exporting function's name, assume - # each exporting function (e.g., torch.onnx.symbolic_opset9.add) support - # all overloads (e.g., in torch.ops.aten.add). - # Thus, we register all torch._ops.OpOverload's for the same exporting function. - # Please manually exclude torch._ops.OpOverload if exporter fails. - for overload in op_overload_packet.overloads(): - op_overload = getattr(op_overload_packet, overload) - support_dictionary[op_overload] = None - - elif aten_op_name in custom_symbols: - op_namespace = aten_op_name.split("::")[0] - short_op_name = aten_op_name.split("::")[1] - # Get the custom ops from: torch.ops.custom_namespace - custom_op_namespace = getattr(torch.ops, op_namespace) - op_overload_packet = getattr(custom_op_namespace, short_op_name) # type: ignore - for overload in op_overload_packet.overloads(): - op_overload = getattr(op_overload_packet, overload) - support_dictionary[op_overload] = None - - # No decomposition table. OpOverload in this table shouldn't be found - # in aten2aten_decomposition_table. - # The symbols in this set will be replaced by torch.ops.aten.to.dtype in replace_to_copy_with_to because - # only aten.to has ONNX exporter. - # If the replacement fails, ONNX exporter will fail because only aten.to has ONNX exporter. - # TODO(wechi): For a long-term solution, we need to ensure every op used in op decomposision has - # an exporter. - no_decomposition_table: Set[torch._ops.OpOverload] = { - torch.ops.aten._to_copy.default, # type: ignore - torch.ops.aten._to_copy.out, # type: ignore - } - - # decomposition_table currently contains both aten2aten and aten2prim decompositions - # This is a hack to separate them, as ONNX only recognizes aten symbols. - aten2aten_decomposition_table: Dict[torch._ops.OpOverload, Callable] = {} - aten2prim_decomposition_table: Dict[torch._ops.OpOverload, Callable] = {} - - for op_overload, decomp_fn in decomposition_table.items(): - if op_overload in support_dictionary: - # ONNX can express this op, no need to decompose. - continue - if "torch._refs" in decomp_fn.__module__: - aten2prim_decomposition_table[op_overload] = decomp_fn - else: - if op_overload in no_decomposition_table: - continue - # Assume ONNX can express ops after decomposition. - # If no, exporter will fail and the user need to - # remove this decomposition rule. - aten2aten_decomposition_table[op_overload] = decomp_fn - - # Some torch.fx.Node's are converted to ONNX-compatible ops - # by torch.jit.script. They don't have direct ONNX exporting - # functions but still runnable in ORT. - extra_support_dictionary: Dict[str, Any] = { - "getattr": None, - "_operator.getitem": None, - } - - return support_dictionary, extra_support_dictionary, aten2aten_decomposition_table, aten2prim_decomposition_table - - -( - _SUPPORT_DICT, - _EXTRA_SUPPORT_DICT, - ATEN2ATEN_DECOMP, - ATEN2PRIM_DECOMP, -) = _get_support_dictionaries_and_decomposition_tables() - - class OrtOperatorSupport(OperatorSupport): """ Operator support for ONNXRuntime backend. It has two-level of support decision. @@ -234,31 +138,6 @@ def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: tor return False -def _jit_graph_to_onnx_model(graph, operator_export_type): - r""" - This function exports torch::jit::Graph object - to serialized ONNX ModelProto. - It only keeps the essential parts for IR graph conversions. - It also does not interact with actual PyTorch modules nor - PyTorch tensor inputs. - """ - graph = torch.onnx.utils._optimize_graph(graph, operator_export_type, params_dict={}) - proto, _, _, _ = graph._export_onnx( # type: ignore - {}, - ONNX_GLOBALS.export_onnx_opset_version, - {}, - False, - operator_export_type, - False, - False, - {}, - True, - "", - {}, - ) - return proto - - def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: """ In torch.fx.Graph, placehoder is a special assignment node. If it's not @@ -316,42 +195,6 @@ def _replace_to_copy_with_to(fx_module: torch.fx.GraphModule) -> None: fx_module.recompile() -def _fx_to_torchscript( - fx_module: torch.fx.GraphModule, -) -> torch.jit.ScriptModule: - """Convert torch.fx.Graph to torch.jit.ScriptModule.""" - - for node in fx_module.graph.nodes: - new_kwargs = {} - for k, v in node.kwargs.items(): - if isinstance(v, torch.device): - v = v.type # noqa: PLW2901 - new_kwargs[k] = v - node.kwargs = new_kwargs - for node in fx_module.graph.nodes: - if isinstance(node.target, torch._ops.OpOverload): - node.target = node.target.overloadpacket - fx_module.graph.lint() - fx_module.recompile() - return torch.jit.script(fx_module) # type: ignore - - -def _decorate_script_module(script_module: torch.jit.ScriptModule, expected_inputs, expected_outputs): - for i, input_value in enumerate(script_module.graph.inputs()): # type: ignore - if input_value.debugName() == "self": - script_module.graph.eraseInput(i) # type: ignore - break - for input_value, expected_input in zip(script_module.graph.inputs(), expected_inputs): # type: ignore - input_value.setType(torch._C.TensorType.create_from_tensor(expected_input)) - for output_value, expected_output in zip(script_module.graph.outputs(), expected_outputs): # type: ignore - output_value.setType(torch._C.TensorType.create_from_tensor(expected_output)) - - -def _create_onnx_proto(script_module): - onnx_proto = _jit_graph_to_onnx_model(script_module.graph, torch.onnx.OperatorExportTypes.ONNX) - return onnx_proto - - def _create_onnx_model(onnx_proto): return onnx.ModelProto.FromString(onnx_proto) @@ -554,17 +397,25 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar # rethrow FakeTensorProb failure because it is not yet currently handled. raise self._ort_execution_info.example_outputs[graph_module] = prim_outputs - # Compile the torch.fx.GraphModule into a torch.jit.ScriptModule. - script_module = _fx_to_torchscript(graph_module) - # Post-processing step to add expected input and output type information - # to the graph in torch.jit.ScriptModule. Expected inputs is "args" and "kwargs" - # while expected outputs is "prim_outputs". - if isinstance(prim_outputs, tuple): - _decorate_script_module(script_module, args, prim_outputs) - else: - _decorate_script_module(script_module, args, (prim_outputs,)) - # Generate ONNX ModelProto from torch._C.Graph. - onnx_proto = _create_onnx_proto(script_module) + + from torch.onnx._internal.fx import fx_onnx_interpreter + + # Create the object to iterate through the nodes in graph one-by-one + # and calls the corresponding ONNX exporter for each node. + fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter( + diagnostic_context=DEFAULT_ONNX_EXPORTER_OPTIONS.diagnostic_context + ) + # Start the per-node exporting process. It's conceptually a for loop + # scanning through the nodes in the graph. + exported = fx_interpreter.run( + fx_graph_module=graph_module, + onnxfunction_dispatcher=DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher, + op_level_debug=DEFAULT_ONNX_EXPORTER_OPTIONS.op_level_debug, + ) + # Convert the exported result to ONNX ModelProto. + onnx_proto = exported.to_model_proto( + opset_version=DEFAULT_ONNX_EXPORTER_OPTIONS.opset_version + ).SerializeToString() # Initialize a ORT session to execute this ONNX model. # TorchDynamo assumes all inputs/outputs are on the same device, diff --git a/orttraining/orttraining/python/training/torchdynamo/register_backend.py b/orttraining/orttraining/python/training/torchdynamo/register_backend.py index ae9a1522a3..1aa2692e70 100644 --- a/orttraining/orttraining/python/training/torchdynamo/register_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/register_backend.py @@ -6,7 +6,7 @@ from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd -from .ort_backend import ATEN2ATEN_DECOMP, OrtBackend +from .ort_backend import DORT_DECOMPOSITION_TABLE, OrtBackend # This should be the underlying compiler for ALL graphs if # the user uses ORT to accelerate PyTorch via Dynamo. @@ -28,8 +28,11 @@ # compiled_model = torch._dynamo.optimize(aot_ort)(model) # result = compiled_model(torch.rand(2, 2, dtype=torch.float) # result.sum().backward() + aot_ort = aot_autograd( - fw_compiler=DEFAULT_BACKEND, partition_fn=min_cut_rematerialization_partition, decompositions=ATEN2ATEN_DECOMP + fw_compiler=DEFAULT_BACKEND, + partition_fn=min_cut_rematerialization_partition, + decompositions=DORT_DECOMPOSITION_TABLE, ) # Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called). diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index ae6d1ac3c4..ed2450ae17 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -24,11 +24,11 @@ def test_elementwise_model(self): def run_elementwise_model(): # A function to test DORT. def elementwise_model(tensor_x: torch.Tensor): - tensor_w = tensor_x.relu() + tensor_w = tensor_x.sigmoid() tensor_y = tensor_w * tensor_w + 1.5 tensor_z = tensor_y + tensor_x tensor_p = tensor_z * tensor_x - tensor_q = tensor_p.relu() + tensor_q = tensor_p.sigmoid() return tensor_q @torch._dynamo.optimize(aot_ort) diff --git a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py index 79554b0440..d8c8d395fa 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py @@ -5,18 +5,64 @@ import sys import unittest +import onnxscript import torch +import torch._dynamo from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd -from torch.onnx import register_custom_op_symbolic - -import onnxruntime as onnxrt -from onnxruntime.training.torchdynamo import register_custom_op_in_dort -from onnxruntime.training.torchdynamo.ort_backend import ATEN2ATEN_DECOMP, OrtBackend - - -def onnx_custom_add(g, x, y): - return g.op("test.customop::CustomOpOne", x, y, outputs=1) +from torch.library import Library + +import onnxruntime +from onnxruntime.training.torchdynamo.ort_backend import ( + _SUPPORT_DICT, + DEFAULT_ONNX_EXPORTER_OPTIONS, + DORT_DECOMPOSITION_TABLE, + OrtBackend, +) + +# Dummy operator set to map aten::mul.Tensor to test.customop::CustomOpOne +# in ONNX model executed by DORT. +# Print the output of to_model_proto in ort_backend.py for the generated +# ONNX model. +custom_opset = onnxscript.values.Opset(domain="test.customop", version=1) + + +# Exporter for torch.ops.aten.mul.Tensor. +@onnxscript.script(custom_opset) +def custom_exporter_for_aten_add_Tensor(x, y): + # This function represents an ONNX function. Register below + # set this function as the FX-to-ONNX exporter of "aten::mul.Tensor". + return custom_opset.CustomOpOne(x, y) + + +# Register custom_exporter_for_aten_add_Tensor as "aten::mul.Tensor"'s +# exporter. +# Use custom_exporter_for_aten_add_Tensor.to_function_proto() to investigate +# function representing "aten::mul.Tensor". +DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher.onnx_registry.register( + "aten::mul.Tensor", + DEFAULT_ONNX_EXPORTER_OPTIONS.opset_version, + custom_exporter_for_aten_add_Tensor, + True, +) + + +# Exporter for torch.ops.foo.bar.default. +@onnxscript.script(custom_opset) +def custom_exporter_for_foo_bar_default(x): + # This function represents an ONNX function. Register below + # set this function as the FX-to-ONNX exporter of "aten::mul.Tensor". + return custom_opset.CustomOpOne(x, x) + + +# Ask exporter to map "torch.ops.foo.bar" to +# custom_exporter_for_foo_bar_default. +DEFAULT_ONNX_EXPORTER_OPTIONS.onnxfunction_dispatcher.onnx_registry.register( + "foo::bar", + DEFAULT_ONNX_EXPORTER_OPTIONS.opset_version, + custom_exporter_for_foo_bar_default, + True, +) class TestTorchDynamoOrtCustomOp(unittest.TestCase): @@ -26,15 +72,19 @@ def setUp(self): # Make computation deterministic. torch.manual_seed(42) - def test_DORT_custom_ops(self): - torch._dynamo.reset() + @staticmethod + def search_for_custom_op_library_path(): + """Searches for the path of the custom op library file. - # register custom op in onnx - register_custom_op_symbolic("aten::mul", onnx_custom_add, opset_version=14) + The returned path may change depending on the platform of the CI. - # register custom op in dort - register_custom_op_in_dort("test.customop::CustomOpOne") + Returns: + str: The path of the custom op library file. + Raises: + FileNotFoundError: If the custom op library file is not found + in the expected location. + """ if sys.platform.startswith("win"): shared_library = "custom_op_library.dll" if not os.path.exists(shared_library): @@ -50,27 +100,83 @@ def test_DORT_custom_ops(self): if not os.path.exists(shared_library): raise FileNotFoundError(f"Unable to find '{shared_library}'") - session_options = onnxrt.SessionOptions() - session_options.register_custom_ops_library(shared_library) + return shared_library + + @staticmethod + def create_onnxruntime_session_options(): + """Creates an ONNXRuntime session options object. + + The returned option object is configured to enable custom + operator's implementation visible in ONNXRuntime. + + Returns: + onnxruntime.SessionOptions: An ONNXRuntime session options object. + """ + custom_op_library_path = TestTorchDynamoOrtCustomOp.search_for_custom_op_library_path() + session_options = onnxruntime.SessionOptions() + session_options.register_custom_ops_library(custom_op_library_path) + return session_options + + def test_DORT_custom_ops(self): + torch._dynamo.reset() + + session_options = TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options() ort_backend = OrtBackend(ep="CPUExecutionProvider", session_options=session_options) aot_ort = aot_autograd( - fw_compiler=ort_backend, partition_fn=min_cut_rematerialization_partition, decompositions=ATEN2ATEN_DECOMP + fw_compiler=ort_backend, + partition_fn=min_cut_rematerialization_partition, + decompositions=DORT_DECOMPOSITION_TABLE, ) - def custom_add(tensor_x: torch.Tensor, tensor_y: torch.Tensor): + def one_mul(tensor_x: torch.Tensor, tensor_y: torch.Tensor): return torch.mul(tensor_x, tensor_y) - opt_add = torch._dynamo.optimize(aot_ort)(custom_add) + opt_mul = torch._dynamo.optimize(aot_ort)(one_mul) tensor_x = torch.ones((64, 64), dtype=torch.float32) tensor_y = torch.ones((64, 64), dtype=torch.float32) for _ in range(5): result_ref = torch.add(tensor_x, tensor_y) - result_ort = opt_add(tensor_x, tensor_y) + result_ort = opt_mul(tensor_x, tensor_y) torch.testing.assert_close(result_ref, result_ort) + def test_dort_with_custom_torch_op_library(self): + torch._dynamo.reset() + + foo_lib = Library("foo", "DEF") + bar_name = foo_lib.define("bar(Tensor self) -> Tensor") + + def bar_impl(self: torch.Tensor) -> torch.Tensor: + # foo::bar.default will be mapped to test.customop::CustomOpOne. + # In ORT, test.customop::CustomOpOne is simply an Add for testing. + return torch.add(self, self) + + foo_lib.impl(bar_name, bar_impl, "CompositeExplicitAutograd") + + # TODO(wechi): Redesign API to expose this better. + _SUPPORT_DICT.add(torch.ops.foo.bar.default) + + session_options = TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options() + ort_backend = OrtBackend(ep="CPUExecutionProvider", session_options=session_options) + aot_ort = aot_autograd( + fw_compiler=ort_backend, + partition_fn=min_cut_rematerialization_partition, + decompositions=DORT_DECOMPOSITION_TABLE, + ) + + def one_foo(tensor_x: torch.Tensor): + return torch.ops.foo.bar(tensor_x) + + opt_foo = torch._dynamo.optimize(aot_ort)(one_foo) + + for _ in range(5): + x = torch.randn(3, 2, device="cpu") + expected = torch.ops.foo.bar(x) + actual = opt_foo(x) + torch.testing.assert_close(expected, actual) + if __name__ == "__main__": unittest.main() diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml index e72b65a1ca..7b8d5ade55 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml @@ -69,7 +69,8 @@ jobs: bash -c " export PYTHONPATH=/build/Release && \ /opt/python/cp39-cp39/bin/python3.9 -m pip install /build/Release/dist/*.whl && \ - /opt/python/cp39-cp39/bin/python3.9 /onnxruntime_src/orttraining/orttraining/test/python/orttraining_test_dort.py" + /opt/python/cp39-cp39/bin/python3.9 /onnxruntime_src/orttraining/orttraining/test/python/orttraining_test_dort.py && \ + cd /build/Release && /opt/python/cp39-cp39/bin/python3.9 /onnxruntime_src/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py" workingDirectory: $(Build.SourcesDirectory) condition: succeededOrFailed() diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh index e9f6f054b1..88110954f2 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh @@ -31,6 +31,16 @@ fi export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" +/opt/python/cp39-cp39/bin/python3.9 -m pip install transformers + +cd /usr/local/ +echo "Cloning ONNX Script" +git clone --recursive https://github.com/microsoft/onnxscript.git +cd onnxscript +/opt/python/cp39-cp39/bin/python3.9 -m pip install -r requirements-dev.txt +/opt/python/cp39-cp39/bin/python3.9 setup.py install +cd ~ && /opt/python/cp39-cp39/bin/python3.9 -c "import onnxscript; print(f'Installed ONNX Script: {onnxscript.__version__}')" + cd /usr/local echo "Cloning Pytorch" git clone --recursive https://github.com/pytorch/pytorch.git @@ -42,17 +52,5 @@ echo "Building and installing Pytorch" VERBOSE=1 BUILD_LAZY_TS_BACKEND=1 /opt/python/cp39-cp39/bin/python3.9 setup.py install cd ~ && /opt/python/cp39-cp39/bin/python3.9 -c "import torch; print(f'Installed Pytorch: {torch.__version__}')" -cd /usr/local/ -echo "Cloning TorchDynamo" -git clone --recursive https://github.com/pytorch/torchdynamo.git -cd torchdynamo -echo "Installing TorchDynamo requirements" -/opt/python/cp39-cp39/bin/python3.9 -m pip install transformers -/opt/python/cp39-cp39/bin/python3.9 -m pip install -r requirements.txt -echo "Installing TorchDynamo" -/opt/python/cp39-cp39/bin/python3.9 setup.py install -cd ~ && /opt/python/cp39-cp39/bin/python3.9 -c "import torch; print(f'Installed Pytorch: {torch.__version__}')" -cd ~ && /opt/python/cp39-cp39/bin/python3.9 -c "import torchdynamo; print(f'Installed TorchDynamo: {torchdynamo.__path__}')" - cd / rm -rf /tmp/src