diff --git a/sinabs/backend/dynapcnn/dynapcnn_layer.py b/sinabs/backend/dynapcnn/dynapcnn_layer.py
index d2412640..ea573701 100644
--- a/sinabs/backend/dynapcnn/dynapcnn_layer.py
+++ b/sinabs/backend/dynapcnn/dynapcnn_layer.py
@@ -96,7 +96,7 @@ def __init__(
         super().__init__()

         self.in_shape = in_shape
-        self._pool = pool
+        self.pool = pool
         self._discretize = discretize
         self._rescale_weights = rescale_weights
@@ -135,20 +135,16 @@ def __init__(
         if self._discretize:
             conv, spk = discretize_conv_spike_(conv, spk, to_int=False)

-        self._conv = conv
-        self._spk = spk
+        self.conv = conv
+        self.spk = spk

     @property
     def conv_layer(self):
-        return self._conv
+        return self.conv

     @property
     def spk_layer(self):
-        return self._spk
-
-    @property
-    def pool(self):
-        return self._pool
+        return self.spk

     @property
     def discretize(self):
@@ -175,7 +171,7 @@ def forward(self, x) -> List[torch.Tensor]:
         x = self.conv_layer(x)
         x = self.spk_layer(x)

-        for pool in self._pool:
+        for pool in self.pool:

             if pool == 1:
                 # no pooling is applied.
@@ -215,7 +211,7 @@ def get_output_shape(self) -> List[Tuple[int, int, int]]:
         neuron_shape = self.get_neuron_shape()
         # this is the actual output shape, including pooling
         output_shape = []
-        for pool in self._pool:
+        for pool in self.pool:
             output_shape.append(
                 neuron_shape[0],
                 neuron_shape[1] // pool,
@@ -227,7 +223,7 @@ def summary(self) -> dict:
         """Returns a summary of the convolution's/pooling's kernel sizes and
         the output shape of the spiking layer."""
         return {
-            "pool": (self._pool),
+            "pool": (self.pool),
             "kernel": list(self.conv_layer.weight.data.shape),
             "neuron": self._get_conv_output_shape(),  # neuron layer output has the same shape as the convolution layer ouput.
         }
diff --git a/sinabs/backend/dynapcnn/dynapcnn_network.py b/sinabs/backend/dynapcnn/dynapcnn_network.py
index 32932fca..d848e39f 100644
--- a/sinabs/backend/dynapcnn/dynapcnn_network.py
+++ b/sinabs/backend/dynapcnn/dynapcnn_network.py
@@ -132,6 +132,13 @@ def exit_layers(self):
             self.dynapcnn_layers[i] for i in self._dynapcnn_module.get_exit_layers()
         ]

+    @property
+    def is_deployed_on_dynapcnn_device(self):
+        return (
+            hasattr(self, "device")
+            and parse_device_id(self.device)[0] in ChipFactory.supported_devices
+        )
+
     @property
     def layer_destination_map(self):
         return self._dynapcnn_module.destination_map
@@ -212,10 +219,7 @@ def forward(
            structure as if `return_complete` is `True`, but only with entries
            where the destination is marked as final.
        """
-        if (
-            hasattr(self, "device")
-            and parse_device_id(self.device)[0] in ChipFactory.supported_devices
-        ):
+        if self.is_deployed_on_dynapcnn_device:
            return self.hw_forward(x)
        else:
            # Forward pass through software DynapcnnLayer instance
@@ -654,12 +658,37 @@ def _to_device(self, device: torch.device) -> None:

     def __str__(self):
         pretty_print = ""
+        if self.dvs_layer is not None:
+            pretty_print += (
+                "-------------------------- [ DVSLayer ] --------------------------\n"
+            )
+            pretty_print += f"{self.dvs_layer}\n\n"
         for idx, layer_data in self.dynapcnn_layers.items():
             pretty_print += f"----------------------- [ DynapcnnLayer {idx} ] -----------------------\n"
+            if self.is_deployed_on_dynapcnn_device:
+                pretty_print += f"Core {self.layer2core_map[idx]}\n"
             pretty_print += f"{layer_data}\n\n"

         return pretty_print

+    def __repr__(self):
+        if self.is_deployed_on_dynapcnn_device:
+            layer_info = "\n\n".join(
+                f"{idx} - core: {self.layer2core_map[idx]}\n{pformat(layer)}"
+                for idx, layer in self.dynapcnn_layers.items()
+            )
+            device_info = f" deployed on {self.device},"
+        else:
+            layer_info = "\n\n".join(
+                f"Index: {idx}\n{pformat(layer)}"
+                for idx, layer in self.dynapcnn_layers.items()
+            )
+            device_info = f" on {self.device}," if hasattr(self, "device") else ""
+        return (
+            f"DynapCNN Network{device_info} containing:\nDVS Layer: {pformat(self.dvs_layer)}"
+            "\n\nDynapCNN Layers:\n\n" + layer_info
+        )
+

 class DynapcnnCompatibleNetwork(DynapcnnNetwork):
     """Deprecated class, use DynapcnnNetwork instead."""
diff --git a/sinabs/backend/dynapcnn/dynapcnnnetwork_module.py b/sinabs/backend/dynapcnn/dynapcnnnetwork_module.py
index 16a728c8..bd03d5d7 100644
--- a/sinabs/backend/dynapcnn/dynapcnnnetwork_module.py
+++ b/sinabs/backend/dynapcnn/dynapcnnnetwork_module.py
@@ -1,7 +1,7 @@
 # author : Willian Soares Girao
 # contact : wsoaresgirao@gmail.com

-from collections import defaultdict
+from pprint import pformat
 from typing import Dict, List, Optional, Set, Union
 from warnings import warn

@@ -372,3 +372,8 @@ def remap(key):
             remap(node): [remap(src) for src in sources]
             for node, sources in self._node_source_map.items()
         }
+
+    def __repr__(self):
+        return f"DVS Layer: {pformat(self.dvs_layer)}\n\nDynapCNN Layers:\n" + pformat(
+            self.dynapcnn_layers
+        )
diff --git a/sinabs/backend/dynapcnn/nir_graph_extractor.py b/sinabs/backend/dynapcnn/nir_graph_extractor.py
index 94579241..51d12c17 100644
--- a/sinabs/backend/dynapcnn/nir_graph_extractor.py
+++ b/sinabs/backend/dynapcnn/nir_graph_extractor.py
@@ -119,6 +119,9 @@ def __init__(
         # retrieves what the I/O shape for each node's module is.
         self._nodes_io_shapes = self._get_nodes_io_shapes(dummy_input)

+        # Verify that graph is compatible
+        self.verify_graph_integrity()
+
    ####################################################### Publich Methods #######################################################

    @property
@@ -180,11 +183,9 @@ def get_dynapcnn_network_module(
          - The DynapcnnNetworkModule based on graph representation of this
            `GraphExtractor`
        """
-        # Make sure all nodes are supported
+        # Make sure all nodes are supported and there are no isolated nodes.
         self.verify_node_types()
-
-        # Verify that graph is compatible
-        self.verify_graph_integrity()
+        self.verify_no_isolated_nodes()

         # create a dict holding the data necessary to instantiate a `DynapcnnLayer`.
         self.dcnnl_map, self.dvs_layer_info = collect_dynapcnn_layer_info(
@@ -281,7 +282,6 @@ def verify_graph_integrity(self):

        Check that:
        - Only nodes of specific classes have multiple sources or targets.
-        - There are no disconnected nodes except for `DVSLayer` instances.

        Raises
        ------
@@ -379,6 +379,26 @@ def verify_node_types(self):
                    f"{pformat(SupportedNodeTypes)}."
                )

+    def verify_no_isolated_nodes(self):
+        """Verify that there are no disconnected nodes except for `DVSLayer` instances.
+
+        Raises
+        ------
+        - InvalidGraphStructure when disconnected nodes are detected
+        """
+        for node, module in self.indx_2_module_map.items():
+            # Make sure there are no individual, unconnected nodes
+            edges_with_node = {e for e in self.edges if node in e}
+            if not edges_with_node and not isinstance(module, DVSLayer):
+                raise InvalidGraphStructure(
+                    f"There is an isolated module of type {type(module)}. Only "
+                    "`DVSLayer` instances can be completely disconnected from "
+                    "any other module. Other than that, layers for DynapCNN "
+                    "consist of groups of weight layers (`Linear` or `Conv2d`), "
+                    "spiking layers (`IAF` or `IAFSqueeze`), and optionally "
+                    "pooling layers (`SumPool2d`, `AvgPool2d`)."
+                )
+
    ####################################################### Pivate Methods #######################################################

    def _handle_dvs_input(
diff --git a/tests/test_dynapcnn/test_doorbell.py b/tests/test_dynapcnn/test_doorbell.py
index a62f96ea..4c78ac80 100644
--- a/tests/test_dynapcnn/test_doorbell.py
+++ b/tests/test_dynapcnn/test_doorbell.py
@@ -75,10 +75,8 @@ def test_same_result():


 def test_auto_config():
-    # - Should give an error with the normal layer ordering
     dynapcnn_net = DynapcnnNetwork(snn, input_shape=input_shape, discretize=True)
-    with pytest.raises(ValueError):
-        dynapcnn_net.make_config(chip_layers_ordering=[0, 1, 2, 3, 4])
+    dynapcnn_net.make_config(chip_layers_ordering=[0, 1, 2, 3, 4])

     dynapcnn_net.make_config(layer2core_map="auto")

diff --git a/tests/test_dynapcnnnetwork/test_failcases.py b/tests/test_dynapcnnnetwork/test_failcases.py
index 125b1dff..846a76ad 100644
--- a/tests/test_dynapcnnnetwork/test_failcases.py
+++ b/tests/test_dynapcnnnetwork/test_failcases.py
@@ -47,14 +47,14 @@ def test_missing_spiking_layer():
     in_shape = (2, 28, 28)
     snn = nn.Sequential(
         nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False),
-        sl.IAF(),
+        sl.IAFSqueeze(batch_size=1),
         sl.SumPool2d(2),
         nn.AvgPool2d(2),
         nn.Conv2d(8, 16, kernel_size=3, stride=1, bias=False),
-        sl.IAF(),
+        sl.IAFSqueeze(batch_size=1),
         nn.Dropout2d(),
         nn.Conv2d(16, 2, kernel_size=3, stride=1, bias=False),
-        sl.IAF(),
+        sl.IAFSqueeze(batch_size=1),
         nn.Flatten(),
         nn.Linear(8, 5),
     )
@@ -66,7 +66,7 @@ def test_incorrect_model_start():
     in_shape = (2, 28, 28)
     snn = nn.Sequential(
-        sl.IAF(),
+        sl.IAFSqueeze(batch_size=1),
         sl.SumPool2d(2),
         nn.AvgPool2d(2),
     )
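
Reviewer note: a minimal usage sketch of what this diff changes, to make the review concrete. Only `DynapcnnNetwork`, `sl.IAFSqueeze`, the now-public `pool`/`conv`/`spk` attributes, and the new `is_deployed_on_dynapcnn_device` property come from this diff; the example network and its shapes are hypothetical.

import torch.nn as nn

import sinabs.layers as sl
from sinabs.backend.dynapcnn import DynapcnnNetwork

# Hypothetical SNN; any conv/spiking-layer sequence accepted by
# DynapcnnNetwork would do (mirrors the structure used in the tests above).
snn = nn.Sequential(
    nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False),
    sl.IAFSqueeze(batch_size=1),
    sl.SumPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, stride=1, bias=False),
    sl.IAFSqueeze(batch_size=1),
)

dynapcnn_net = DynapcnnNetwork(snn, input_shape=(2, 28, 28), discretize=True)

# `pool`, `conv` and `spk` are now plain attributes rather than `_`-prefixed
# internals behind read-only properties, so layers can be inspected directly:
for idx, layer in dynapcnn_net.dynapcnn_layers.items():
    print(idx, layer.pool, layer.conv, layer.spk)

# Before deployment the new property is False and `__repr__` uses the
# index-based format; after `.to()` onto a supported chip, `__str__` and
# `__repr__` additionally report the core each layer is mapped to.
assert not dynapcnn_net.is_deployed_on_dynapcnn_device
print(repr(dynapcnn_net))

The exact printed output depends on `pformat`'s rendering of each `DynapcnnLayer`, so treat it as illustrative.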