Skip to content

Commit

Permalink
Fix slice negative starts (#275)
Browse files Browse the repository at this point in the history
Fix slice negative starts, which is documented in ONNX
https://github.com/onnx/onnx/blob/master/docs/Operators.md#Slice
Also add unit test for slice.
  • Loading branch information
chinhuang007 authored and tjingrant committed Oct 6, 2018
1 parent 6013e3a commit e26f4d1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
2 changes: 2 additions & 0 deletions onnx_tf/handlers/backend/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def version_1(cls, node, **kwargs):
axes = node.attrs.get("axes", list(range(slice_len)))

for i in range(slice_len):
starts[i] = full_sizes[
axes[i]] + starts[i] if starts[i] < 0 else starts[i]
ends[i] = full_sizes[axes[i]] + ends[i] if ends[i] < 0 else ends[i]
if full_sizes[axes[i]] is not None:
ends[i] = np.min([full_sizes[axes[i]], ends[i]])
Expand Down
24 changes: 18 additions & 6 deletions test/backend/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def test_dot(self):
def test_dynamic_slice(self):
if defs.onnx_opset_version() < 9:
raise unittest.SkipTest(
"ONNX version {} doesn't support DynamicSlice."
.format(defs.onnx_opset_version()))
"ONNX version {} doesn't support DynamicSlice.".format(
defs.onnx_opset_version()))
axes = np.array([0, 1], dtype=np.long)
starts = np.array([1, 0], dtype=np.long)
ends = np.array([2, 3], dtype=np.long)
Expand Down Expand Up @@ -737,12 +737,24 @@ def test_size(self):
np.testing.assert_almost_equal(output["Y"], np.size(x))

def test_slice(self):
# TODO: API update or fix onnx version
return
node_def = helper.make_node("Slice", ["X", "Y", "Z", "W"], ["S"])
# test case 1 with normal inputs
axes = [0, 1, 2]
starts = [0, 0, 0]
ends = [2, 2, 2]
node_def = helper.make_node(
"Slice", ["X"], ["S"], axes=axes, starts=starts, ends=ends)
x = self._get_rnd([1000]).reshape([10, 10, 10])
output = run_node(node_def, [x, [0, 1, 2], [0, 0, 0], [2, 2, 2]])
output = run_node(node_def, [x])
np.testing.assert_almost_equal(output["S"], x[0:2, 0:2, 0:2])
# test case 2 with negative, out-of-bound and default inputs
axes = [0, 2]
starts = [0, -7]
ends = [-8, 20]
node_def = helper.make_node(
"Slice", ["X"], ["S"], axes=axes, starts=starts, ends=ends)
x = self._get_rnd([1000]).reshape([10, 10, 10])
output = run_node(node_def, [x])
np.testing.assert_almost_equal(output["S"], x[0:-8, :, -7:20])

def test_softplus(self):
node_def = helper.make_node("Softplus", ["X"], ["Y"])
Expand Down

0 comments on commit e26f4d1

Please sign in to comment.