Skip to content

Commit

Permalink
Use those defined in onnx package
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Oct 31, 2023
1 parent 2eb94b5 commit 239ae63
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch_pfn_extras/onnx/pfto_exporter/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto:
ret.tensor_type.elem_type = onnx.TensorProto.DataType.UNDEFINED # type: ignore[attr-defined]
elif t.scalarType() == "Float8_e4m3fn":
ret.tensor_type.elem_type = int( # type: ignore
sym_hel._C_onnx.TensorProtoDataType.FLOAT8E4M3FN
onnx.TensorProto.DataType.FLOAT8E4M3FN
)
elif t.scalarType() == "Float8_e5m2":
ret.tensor_type.elem_type = int( # type: ignore
sym_hel._C_onnx.TensorProtoDataType.FLOAT8E5M2
onnx.TensorProto.DataType.FLOAT8E5M2
)
else:
ret.tensor_type.elem_type = int( # type: ignore
Expand Down

0 comments on commit 239ae63

Please sign in to comment.