Skip to content

Commit

Permalink
Comments after code review
Browse files Browse the repository at this point in the history
Signed-off-by: neNasko1 <[email protected]>
  • Loading branch information
neNasko1 committed Jan 3, 2025
1 parent 9aee9fd commit 393ffdf
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 9 deletions.
6 changes: 5 additions & 1 deletion src/spox/opset/ai/onnx/v17.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,15 +1785,19 @@ def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:

body = self.attrs.body.value

# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = c_typ if is_constant_shape else c_typ._with_shape(None)
output_types[name] = (
c_typ if shape_unchanged_between_iterations else c_typ._with_shape(None)
)

return output_types

Expand Down
72 changes: 66 additions & 6 deletions tests/type_inference/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@
import spox.opset.ai.onnx.v17 as op17
import spox.opset.ai.onnx.v19 as op19
import spox.opset.ai.onnx.v21 as op21
from spox import Tensor, argument
from spox import Optional, Sequence, Tensor, argument


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_inference(op):
x, y, zs = op.loop(
v_initial=[argument(Tensor(float, (None,))), argument(Tensor(int, ("N", 2)))],
v_initial=[
argument(Tensor(np.float64, (None,))),
argument(Tensor(np.int64, ("N", 2))),
],
body=lambda i, c, a, b: [op.const(True), a, op.add(i, b), i],
)
assert x.type == Tensor(float, (None,))
assert y.type == Tensor(int, ("N", 2))
assert zs.type == Tensor(int, (None, 1))
assert x.type == Tensor(np.float64, (None,))
assert y.type == Tensor(np.int64, ("N", 2))
assert zs.type == Tensor(np.int64, (None, 1))


@pytest.mark.parametrize("op", [op17, op19, op21])
Expand All @@ -33,4 +36,61 @@ def test_loop_concat(op):
)[0]

# type can change, so we cannot infer anything
assert result.type == Tensor(int, None)
assert result.type == Tensor(np.int64, None)


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_sequence(op):
num_iters = op.const(1)
v = op.sequence_empty(dtype=np.int64)

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (op.const(True), op.sequence_insert(x, op.const([1]))),
)[0]

assert result.type == Sequence(Tensor(np.int64, None))


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_optional(op):
num_iters = op.const(1)
v = op.optional(type=Tensor(np.int64, (1, 2)))

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (
op.const(True),
op.if_(
op.optional_has_element(x),
then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))],
else_branch=lambda: [op.optional(op.const([[1, 1]]))],
)[0],
),
)[0]

assert result.type == Optional(Tensor(np.int64, (1, 2)))


@pytest.mark.parametrize("op", [op17, op19, op21])
def test_loop_optional_no_shape(op):
num_iters = op.const(1)
v = op.optional(type=Tensor(np.int64, (1, 2)))

result = op.loop(
num_iters,
v_initial=[v],
body=lambda i, c, x: (
op.const(True),
op.if_(
op.optional_has_element(x),
then_branch=lambda: [op.optional(type=Tensor(np.int64, (1, 2)))],
else_branch=lambda: [op.optional(op.const([[1]]))],
)[0],
),
)[0]

# shape can change, we cannot infer type
assert result.type == Optional(Tensor(np.int64, None))
8 changes: 6 additions & 2 deletions tools/templates/type_inference/loop16-fix.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ output_names = list(self.outputs.get_var_infos())

body = self.attrs.body.value

# We skip the iteration_num and condition as they are correctly inferred
initial_types = [v.type for v in list(body.requested_arguments)[2:]]
# We skip the returned condition as it is correctly inferred
carried_types = [v.type for v in list(body.requested_results.values())[1:]]

is_constant_shape = all(i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types))
shape_unchanged_between_iterations = all(
i_typ == c_typ for i_typ, c_typ in zip(initial_types, carried_types)
)

for name, _, c_typ in zip(output_names, initial_types, carried_types):
output_types[name] = c_typ if is_constant_shape else c_typ._with_shape(None)
output_types[name] = c_typ if shape_unchanged_between_iterations else c_typ._with_shape(None)

return output_types

0 comments on commit 393ffdf

Please sign in to comment.