diff --git a/onnx_tf/handlers/backend/slice.py b/onnx_tf/handlers/backend/slice.py index 85263b1c2..1f70ae7db 100644 --- a/onnx_tf/handlers/backend/slice.py +++ b/onnx_tf/handlers/backend/slice.py @@ -24,6 +24,8 @@ def version_1(cls, node, **kwargs): axes = node.attrs.get("axes", list(range(slice_len))) for i in range(slice_len): + starts[i] = full_sizes[ + axes[i]] + starts[i] if starts[i] < 0 else starts[i] ends[i] = full_sizes[axes[i]] + ends[i] if ends[i] < 0 else ends[i] if full_sizes[axes[i]] is not None: ends[i] = np.min([full_sizes[axes[i]], ends[i]]) diff --git a/test/backend/test_node.py b/test/backend/test_node.py index e556bef82..1047d15f9 100644 --- a/test/backend/test_node.py +++ b/test/backend/test_node.py @@ -307,8 +307,8 @@ def test_dot(self): def test_dynamic_slice(self): if defs.onnx_opset_version() < 9: raise unittest.SkipTest( - "ONNX version {} doesn't support DynamicSlice." - .format(defs.onnx_opset_version())) + "ONNX version {} doesn't support DynamicSlice.".format( + defs.onnx_opset_version())) axes = np.array([0, 1], dtype=np.long) starts = np.array([1, 0], dtype=np.long) ends = np.array([2, 3], dtype=np.long) @@ -737,12 +737,24 @@ def test_size(self): np.testing.assert_almost_equal(output["Y"], np.size(x)) def test_slice(self): - # TODO: API update or fix onnx version - return - node_def = helper.make_node("Slice", ["X", "Y", "Z", "W"], ["S"]) + # test case 1 with normal inputs + axes = [0, 1, 2] + starts = [0, 0, 0] + ends = [2, 2, 2] + node_def = helper.make_node( + "Slice", ["X"], ["S"], axes=axes, starts=starts, ends=ends) x = self._get_rnd([1000]).reshape([10, 10, 10]) - output = run_node(node_def, [x, [0, 1, 2], [0, 0, 0], [2, 2, 2]]) + output = run_node(node_def, [x]) np.testing.assert_almost_equal(output["S"], x[0:2, 0:2, 0:2]) + # test case 2 with negative, out-of-bound and default inputs + axes = [0, 2] + starts = [0, -7] + ends = [-8, 20] + node_def = helper.make_node( + "Slice", ["X"], ["S"], axes=axes, starts=starts, ends=ends) + x = self._get_rnd([1000]).reshape([10, 10, 10]) + output = run_node(node_def, [x]) + np.testing.assert_almost_equal(output["S"], x[0:-8, :, -7:20]) def test_softplus(self): node_def = helper.make_node("Softplus", ["X"], ["Y"])