Skip to content

Commit

Permalink
Fix TorchGraph handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
bauerfe committed Oct 31, 2024
1 parent c23acf8 commit b24edb8
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,12 @@ def _need_dvs_node(self, model: nn.Module, dvs_input: bool) -> bool:

return not has_dvs_layer and dvs_input

def _get_name_2_indx_map(self, nir_graph: nirtorch.graph.TorchGraph) -> Dict[str, int]:
def _get_name_2_indx_map(self, nir_graph: TorchGraph) -> Dict[str, int]:
"""Assign unique index to each node and return mapper from name to index.
Parameters
----------
- nir_graph (nirtorch.graph.TorchGraph): a NIR graph representation of `spiking_model`.
- nir_graph (TorchGraph): a NIR graph representation of `spiking_model`.
Returns
----------
Expand All @@ -362,14 +362,14 @@ def _get_name_2_indx_map(self, nir_graph: nirtorch.graph.TorchGraph) -> Dict[str
}

def _get_edges_from_nir(
self, nir_graph: nirtorch.graph.TorchGraph, name_2_indx_map: Dict[str, int]
self, nir_graph: TorchGraph, name_2_indx_map: Dict[str, int]
) -> Set[Edge]:
"""Standardize the representation of `nirtorch.graph.TorchGraph` into a list of edges,
"""Standardize the representation of TorchGraph` into a list of edges,
representing nodes by their indices.
Parameters
----------
- nir_graph (nirtorch.graph.TorchGraph): a NIR graph representation of `spiking_model`.
- nir_graph (TorchGraph): a NIR graph representation of `spiking_model`.
- name_2_indx_map (dict): Map from node names to unique indices.
Returns
Expand Down

0 comments on commit b24edb8

Please sign in to comment.