Skip to content

Commit

Permalink
[WIP] Fix subgraphs
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Oct 23, 2023
1 parent 701e8ba commit d18e4e2
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Module implementing a base for standard ONNX operators, which use the functionality of ONNX node-level inference."""
import logging
from typing import TYPE_CHECKING, Callable, Dict, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple

import numpy
import onnx
Expand Down Expand Up @@ -47,7 +47,7 @@ def min_output(self) -> int:
return self.schema.min_output

def to_singleton_onnx_model(
self, *, dummy_outputs: bool = True, with_dummy_subgraphs: bool = True
self, *, dummy_outputs: bool = True, captured_subgraph_values_are_inputs=True
) -> Tuple[onnx.ModelProto, Scope]:
"""
Build a singleton model consisting of just this StandardNode. Used for type inference.
Expand Down Expand Up @@ -78,8 +78,7 @@ def to_singleton_onnx_model(
# Subgraphs are not fully built for possibly significant performance gains.
# However, this uses a trick so that they type correctly.
# This may throw if we are building ``not with_subgraphs``.
build_subgraph = _make_dummy_subgraph if with_dummy_subgraphs else None
(node_proto,) = self.to_onnx(scope, build_subgraph=build_subgraph)
(node_proto,) = self.to_onnx(scope, build_subgraph=_make_dummy_subgraph)
finally:
self.attrs = self_attrs
# Create a singleton graph for type inference with our node
Expand All @@ -89,6 +88,9 @@ def to_singleton_onnx_model(
for key, var in self.inputs.get_vars().items()
]

if captured_subgraph_values_are_inputs:
input_info += _make_value_info_from_subgraph_captured_variables(self)

# Output types with placeholder empty TypeProto (or actual type if not using dummies)
def out_value_info(curr_key, curr_var):
if dummy_outputs or curr_var.type is None or not curr_var.type._is_concrete:
Expand Down Expand Up @@ -168,7 +170,7 @@ def propagate_values_onnx(self) -> Dict[str, PropValueType]:
if next(iter(self.subgraphs), None) is not None:
# Cannot do propagation with subgraphs implicitly for performance - should be reimplemented
return {}
model, scope = self.to_singleton_onnx_model(with_dummy_subgraphs=False)
model, scope = self.to_singleton_onnx_model()
wrap_feed, run, unwrap_feed = _value_prop.get_backend_calls()
input_feed = {
scope.var[var]: wrap_feed(var._value)
Expand Down Expand Up @@ -253,3 +255,21 @@ def _make_dummy_subgraph(_node: Node, key: str, graph: "Graph") -> onnx.GraphPro
nodes.append(onnx.helper.make_node("Identity", [outer], [out]))

return onnx.helper.make_graph(nodes, f"__dummy_{key}", inputs, outputs)


def _make_value_info_from_subgraph_captured_variables(
node: Node,
) -> List[onnx.ValueInfoProto]:
infos = []
from spox._attributes import AttrGraph

for key, attr in node.attrs.get_fields().items():
if attr is not None:
if isinstance(attr, AttrGraph):
graph = attr.value

for i, arr in enumerate(graph.requested_results.values()):
outer = f"__dummy_outer_output{i}"
infos.append(arr.unwrap_type()._to_onnx_value_info(outer))

return infos

0 comments on commit d18e4e2

Please sign in to comment.