From bfcc7d85a6974cb880867002696f427d2df00b35 Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Fri, 26 Oct 2018 01:42:52 -0400 Subject: [PATCH 1/4] fix fill --- onnx_tf/handlers/frontend/fill.py | 12 ++++++++++++ test/frontend/test_node.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/onnx_tf/handlers/frontend/fill.py b/onnx_tf/handlers/frontend/fill.py index 8345916d0..e75bc982c 100644 --- a/onnx_tf/handlers/frontend/fill.py +++ b/onnx_tf/handlers/frontend/fill.py @@ -1,3 +1,5 @@ +import number + import numpy as np from onnx_tf.common import exception @@ -12,6 +14,9 @@ class Fill(FrontendHandler): @classmethod def args_check(cls, node, **kwargs): + output_shape = node.attr["_output_shapes"][0] + for dim_size in output_shape: + assert isinstance(dim_size, numbers.Number) if node.inputs[1] not in kwargs["consts"]: exception.CONST_NOT_FOUND_EXCEPT(node.inputs[1], node.op_type) @@ -20,3 +25,10 @@ def version_1(cls, node, **kwargs): value = float(np.asscalar(kwargs["consts"][node.inputs[1]])) return cls.make_node_from_tf_node( node, [node.inputs[0]], input_as_shape=1, value=value) + + @classmethod + def version_9(cls, node, **kwargs): + value = float(np.asscalar(kwargs["consts"][node.inputs[1]])) + outputs = cls.get_outputs_names(node) + return cls.make_node( + "ConstantLike", node, outputs, shape=output_shape, value=value) diff --git a/test/frontend/test_node.py b/test/frontend/test_node.py index 06434b883..080486ad5 100644 --- a/test/frontend/test_node.py +++ b/test/frontend/test_node.py @@ -95,6 +95,7 @@ def do_test_expected(self): return do_test_expected + # yapf: disable # organized as a tuple of the format: # (test_name, tensorflow_op, output_node_name, LIST of inputs, MAP of attributes) @@ -153,6 +154,7 @@ def do_test_expected(self): ("test_concat", tf.concat, "concat", [[get_rnd([1, 10]),get_rnd([10, 10]),get_rnd([20, 10])], 0], {}), ("test_bias_add_nchw", tf.nn.bias_add, "BiasAdd", [get_rnd([10, 32, 10, 10]),get_rnd([32])], {"data_format":"NCHW"}), ("test_bias_add_nhwc", tf.nn.bias_add, "BiasAdd", [get_rnd([10, 10, 10, 32]),get_rnd([32])], {"data_format":"NHWC"}), +("test_fill", tf.fill, "Fill", [[3, 10], 5], {}), ] if not legacy_opset_pre_ver(6): From d900b929a72d97a204cf294a3a4539941d989c82 Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Fri, 26 Oct 2018 01:43:59 -0400 Subject: [PATCH 2/4] revert unintended style change --- test/frontend/test_node.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/frontend/test_node.py b/test/frontend/test_node.py index 080486ad5..459a2ecbe 100644 --- a/test/frontend/test_node.py +++ b/test/frontend/test_node.py @@ -95,7 +95,6 @@ def do_test_expected(self): return do_test_expected - # yapf: disable # organized as a tuple of the format: # (test_name, tensorflow_op, output_node_name, LIST of inputs, MAP of attributes) From a812dd169a8ac94cb0535610c18e6db3576eb265 Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Fri, 26 Oct 2018 01:47:04 -0400 Subject: [PATCH 3/4] import error, bug fix --- onnx_tf/handlers/frontend/fill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_tf/handlers/frontend/fill.py b/onnx_tf/handlers/frontend/fill.py index e75bc982c..535de2e19 100644 --- a/onnx_tf/handlers/frontend/fill.py +++ b/onnx_tf/handlers/frontend/fill.py @@ -1,4 +1,4 @@ -import number +import numbers import numpy as np From 7663c926a68b70d7fd62ee9f6d4a10d35fe3f8cb Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Fri, 26 Oct 2018 02:00:36 -0400 Subject: [PATCH 4/4] fixes --- onnx_tf/handlers/frontend/fill.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnx_tf/handlers/frontend/fill.py b/onnx_tf/handlers/frontend/fill.py index 535de2e19..0c144a1a7 100644 --- a/onnx_tf/handlers/frontend/fill.py +++ b/onnx_tf/handlers/frontend/fill.py @@ -28,7 +28,8 @@ def version_1(cls, node, **kwargs): @classmethod def version_9(cls, node, **kwargs): + output_shape = node.attr["_output_shapes"][0] value = float(np.asscalar(kwargs["consts"][node.inputs[1]])) outputs = cls.get_outputs_names(node) - return cls.make_node( - "ConstantLike", node, outputs, shape=output_shape, value=value) + return cls.make_node_from_tf_node( + node, [], outputs, op_type="ConstantLike", shape=output_shape, value=value)