From 2eb94b5735a7d63cbddcfa713d848d7a888053e4 Mon Sep 17 00:00:00 2001
From: niboshi
Date: Tue, 31 Oct 2023 05:38:22 +0000
Subject: [PATCH 1/3] Add float8 types

---
 pytorch_pfn_extras/onnx/pfto_exporter/export.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
index 0b0a8300..fd5e46ce 100644
--- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py
+++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
@@ -168,6 +168,14 @@ def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto:
 
     if t.scalarType() is None:
         ret.tensor_type.elem_type = onnx.TensorProto.DataType.UNDEFINED  # type: ignore[attr-defined]
+    elif t.scalarType() == "Float8_e4m3fn":
+        ret.tensor_type.elem_type = int(  # type: ignore
+            sym_hel._C_onnx.TensorProtoDataType.FLOAT8E4M3FN
+        )
+    elif t.scalarType() == "Float8_e5m2":
+        ret.tensor_type.elem_type = int(  # type: ignore
+            sym_hel._C_onnx.TensorProtoDataType.FLOAT8E5M2
+        )
     else:
         ret.tensor_type.elem_type = int(  # type: ignore
             sym_hel.cast_pytorch_to_onnx[t.scalarType()]  # type: ignore[index]
@@ -221,6 +229,8 @@ def onnx_node_doc_string(onnx_node: torch._C.Node, torch_node: torch._C.Node) ->
     torch.float16: onnx.TensorProto.DataType.FLOAT16,  # type: ignore[attr-defined]
     torch.complex64: onnx.TensorProto.DataType.COMPLEX64,  # type: ignore[attr-defined]
     torch.complex128: onnx.TensorProto.DataType.COMPLEX128,  # type: ignore[attr-defined]
+    torch.torch.float8_e4m3fn: onnx.TensorProto.DataType.FLOAT8E4M3FN,  # type: ignore[attr-defined]
+    torch.torch.float8_e5m2: onnx.TensorProto.DataType.FLOAT8E5M2,  # type: ignore[attr-defined]
 }

From 239ae63d85d3bbd76fe380cddaf07cf6f14d6b48 Mon Sep 17 00:00:00 2001
From: niboshi
Date: Tue, 31 Oct 2023 06:15:26 +0000
Subject: [PATCH 2/3] Use those defined in onnx package

---
 pytorch_pfn_extras/onnx/pfto_exporter/export.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
index fd5e46ce..8db422d1 100644
--- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py
+++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
@@ -170,11 +170,11 @@ def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto:
         ret.tensor_type.elem_type = onnx.TensorProto.DataType.UNDEFINED  # type: ignore[attr-defined]
     elif t.scalarType() == "Float8_e4m3fn":
         ret.tensor_type.elem_type = int(  # type: ignore
-            sym_hel._C_onnx.TensorProtoDataType.FLOAT8E4M3FN
+            onnx.TensorProto.DataType.FLOAT8E4M3FN
         )
     elif t.scalarType() == "Float8_e5m2":
         ret.tensor_type.elem_type = int(  # type: ignore
-            sym_hel._C_onnx.TensorProtoDataType.FLOAT8E5M2
+            onnx.TensorProto.DataType.FLOAT8E5M2
         )
     else:
         ret.tensor_type.elem_type = int(  # type: ignore

From 98c54754b0e1a5b78c2d237e91874f09f47056c2 Mon Sep 17 00:00:00 2001
From: niboshi
Date: Tue, 31 Oct 2023 06:48:15 +0000
Subject: [PATCH 3/3] Check torch and onnx versions for float8 types

---
 .../onnx/pfto_exporter/export.py              | 38 ++++++++++---------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
index 8db422d1..808921ef 100644
--- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py
+++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
@@ -148,6 +148,20 @@ def _tensor_to_proto(t: torch.Tensor, name: Optional[ONNXValueID] = None) -> onn
     return onnx.numpy_helper.from_array(t.detach().cpu().numpy(), name)
 
 
+def _scalar_type_to_elem_type(scalar_type: Optional[str]) -> int:
+    if scalar_type is None:
+        return cast(int, onnx.TensorProto.DataType.UNDEFINED)  # type: ignore[attr-defined]
+    typ = sym_hel.cast_pytorch_to_onnx.get(scalar_type)
+    if typ is not None:
+        return cast(int, typ.value)
+    if pytorch_pfn_extras.requires("14.0", "onnx"):
+        if scalar_type == "Float8_e4m3fn":
+            return cast(int, onnx.TensorProto.DataType.FLOAT8E4M3FN)  # type: ignore[attr-defined]
+        if scalar_type == "Float8_e5m2":
+            return cast(int, onnx.TensorProto.DataType.FLOAT8E5M2)  # type: ignore[attr-defined]
+    raise ValueError(f"Unsupported scalar type: {scalar_type}")
+
+
 def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto:
     if t.kind() == "NoneType":
         return onnx.TypeProto()
@@ -166,21 +180,7 @@ def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto:
 
     assert t.kind() == "TensorType", f"Not Tensor type(actual: {t.kind()}): {t}"
 
-    if t.scalarType() is None:
-        ret.tensor_type.elem_type = onnx.TensorProto.DataType.UNDEFINED  # type: ignore[attr-defined]
-    elif t.scalarType() == "Float8_e4m3fn":
-        ret.tensor_type.elem_type = int(  # type: ignore
-            onnx.TensorProto.DataType.FLOAT8E4M3FN
-        )
-    elif t.scalarType() == "Float8_e5m2":
-        ret.tensor_type.elem_type = int(  # type: ignore
-            onnx.TensorProto.DataType.FLOAT8E5M2
-        )
-    else:
-        ret.tensor_type.elem_type = int(  # type: ignore
-            sym_hel.cast_pytorch_to_onnx[t.scalarType()]  # type: ignore[index]
-        )
-
+    ret.tensor_type.elem_type = _scalar_type_to_elem_type(t.scalarType())
     ret.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
     if t.sizes() is not None:
         for s in t.sizes():  # type: ignore
@@ -229,8 +229,12 @@ def onnx_node_doc_string(onnx_node: torch._C.Node, torch_node: torch._C.Node) ->
     torch.float16: onnx.TensorProto.DataType.FLOAT16,  # type: ignore[attr-defined]
     torch.complex64: onnx.TensorProto.DataType.COMPLEX64,  # type: ignore[attr-defined]
     torch.complex128: onnx.TensorProto.DataType.COMPLEX128,  # type: ignore[attr-defined]
-    torch.torch.float8_e4m3fn: onnx.TensorProto.DataType.FLOAT8E4M3FN,  # type: ignore[attr-defined]
-    torch.torch.float8_e5m2: onnx.TensorProto.DataType.FLOAT8E5M2,  # type: ignore[attr-defined]
+    **(
+        {
+            torch.float8_e4m3fn: onnx.TensorProto.DataType.FLOAT8E4M3FN,  # type: ignore[attr-defined]
+            torch.float8_e5m2: onnx.TensorProto.DataType.FLOAT8E5M2,  # type: ignore[attr-defined]
+        } if pytorch_pfn_extras.requires("2.1") and pytorch_pfn_extras.requires("14.0", "onnx") else {}
+    ),
 }
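
For illustration only (not part of the patches): a minimal standalone sketch of
the version-gated float8 mapping that PATCH 3/3 builds. It reuses
pytorch_pfn_extras.requires() exactly as the patches do, and assumes the float8
dtypes and FLOAT8 enums exist only on torch >= 2.1 and onnx >= 1.14; the helper
name elem_type_for is hypothetical.

import onnx
import torch

import pytorch_pfn_extras

# Base mapping, always available (abridged from the exporter's full table).
_DTYPE_TO_ONNX = {
    torch.float32: onnx.TensorProto.DataType.FLOAT,
    torch.float16: onnx.TensorProto.DataType.FLOAT16,
}

# Float8 entries are merged in only when both runtime checks pass, so that
# importing this module never touches dtypes or enums that may not exist.
if pytorch_pfn_extras.requires("2.1") and pytorch_pfn_extras.requires("14.0", "onnx"):
    _DTYPE_TO_ONNX.update({
        torch.float8_e4m3fn: onnx.TensorProto.DataType.FLOAT8E4M3FN,
        torch.float8_e5m2: onnx.TensorProto.DataType.FLOAT8E5M2,
    })

def elem_type_for(dtype: torch.dtype) -> int:
    # Hypothetical helper: resolve a torch dtype to an ONNX elem_type value,
    # failing loudly for dtypes the installed torch/onnx pair cannot express.
    typ = _DTYPE_TO_ONNX.get(dtype)
    if typ is None:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return int(typ)

# Example: with new enough torch and onnx this resolves to the FLOAT8E4M3FN
# enum value; on older installations it raises ValueError instead.
# print(elem_type_for(torch.float8_e4m3fn))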