Skip to content

Commit

Permalink
Further bugfixes and improved readability of dynapcnn network repr.
Browse files Browse the repository at this point in the history
  • Loading branch information
bauerfe committed Nov 8, 2024
1 parent 876962f commit 9460edc
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 29 deletions.
20 changes: 8 additions & 12 deletions sinabs/backend/dynapcnn/dynapcnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
super().__init__()

self.in_shape = in_shape
self._pool = pool
self.pool = pool
self._discretize = discretize
self._rescale_weights = rescale_weights

Expand Down Expand Up @@ -135,20 +135,16 @@ def __init__(
if self._discretize:
conv, spk = discretize_conv_spike_(conv, spk, to_int=False)

self._conv = conv
self._spk = spk
self.conv = conv
self.spk = spk

@property
def conv_layer(self):
return self._conv
return self.conv

@property
def spk_layer(self):
return self._spk

@property
def pool(self):
return self._pool
return self.spk

@property
def discretize(self):
Expand All @@ -175,7 +171,7 @@ def forward(self, x) -> List[torch.Tensor]:
x = self.conv_layer(x)
x = self.spk_layer(x)

for pool in self._pool:
for pool in self.pool:

if pool == 1:
# no pooling is applied.
Expand Down Expand Up @@ -215,7 +211,7 @@ def get_output_shape(self) -> List[Tuple[int, int, int]]:
neuron_shape = self.get_neuron_shape()
# this is the actual output shape, including pooling
output_shape = []
for pool in self._pool:
for pool in self.pool:
output_shape.append(
neuron_shape[0],
neuron_shape[1] // pool,
Expand All @@ -227,7 +223,7 @@ def summary(self) -> dict:
"""Returns a summary of the convolution's/pooling's kernel sizes and the output shape of the spiking layer."""

return {
"pool": (self._pool),
"pool": (self.pool),
"kernel": list(self.conv_layer.weight.data.shape),
"neuron": self._get_conv_output_shape(), # neuron layer output has the same shape as the convolution layer ouput.
}
Expand Down
37 changes: 33 additions & 4 deletions sinabs/backend/dynapcnn/dynapcnn_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def exit_layers(self):
self.dynapcnn_layers[i] for i in self._dynapcnn_module.get_exit_layers()
]

@property
def is_deployed_on_dynapcnn_device(self):
return (
hasattr(self, "device")
and parse_device_id(self.device)[0] in ChipFactory.supported_devices
)

@property
def layer_destination_map(self):
return self._dynapcnn_module.destination_map
Expand Down Expand Up @@ -212,10 +219,7 @@ def forward(
structure as if `return_complete` is `True`, but only with entries
where the destination is marked as final.
"""
if (
hasattr(self, "device")
and parse_device_id(self.device)[0] in ChipFactory.supported_devices
):
if self.is_deployed_on_dynapcnn_device:
return self.hw_forward(x)
else:
# Forward pass through software DynapcnnLayer instance
Expand Down Expand Up @@ -654,12 +658,37 @@ def _to_device(self, device: torch.device) -> None:

def __str__(self):
pretty_print = ""
if self.dvs_layer is not None:
pretty_print += (
"-------------------------- [ DVSLayer ] --------------------------\n"
)
pretty_print += f"{self.dvs_layer}\n\n"
for idx, layer_data in self.dynapcnn_layers.items():
pretty_print += f"----------------------- [ DynapcnnLayer {idx} ] -----------------------\n"
if self.is_deployed_on_dynapcnn_device:
pretty_print += f"Core {self.layer2core_map[idx]}\n"
pretty_print += f"{layer_data}\n\n"

return pretty_print

def __repr__(self):
if self.is_deployed_on_dynapcnn_device:
layer_info = "\n\n".join(
f"{idx} - core: {self.layer2core_map[idx]}\n{pformat(layer)}"
for idx, layer in self.dynapcnn_layers.items()
)
device_info = f" deployed on {self.device},"
else:
layer_info = "\n\n".join(
f"Index: {idx}\n{pformat(layer)}"
for idx, layer in self.dynapcnn_layers.items()
)
device_info = f" on {self.device}," if hasattr(self, "device") else ""
return (
f"DynapCNN Network{device_info} containing:\nDVS Layer: {pformat(self.dvs_layer)}"
"\n\nDynapCNN Layers:\n\n" + layer_info
)


class DynapcnnCompatibleNetwork(DynapcnnNetwork):
"""Deprecated class, use DynapcnnNetwork instead."""
Expand Down
7 changes: 6 additions & 1 deletion sinabs/backend/dynapcnn/dynapcnnnetwork_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# author : Willian Soares Girao
# contact : [email protected]

from collections import defaultdict
from pprint import pformat
from typing import Dict, List, Optional, Set, Union
from warnings import warn

Expand Down Expand Up @@ -372,3 +372,8 @@ def remap(key):
remap(node): [remap(src) for src in sources]
for node, sources in self._node_source_map.items()
}

def __repr__(self):
return f"DVS Layer: {pformat(self.dvs_layer)}\n\nDynapCNN Layers:\n" + pformat(
self.dynapcnn_layers
)
30 changes: 25 additions & 5 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def __init__(
# retrieves what the I/O shape for each node's module is.
self._nodes_io_shapes = self._get_nodes_io_shapes(dummy_input)

# Verify that graph is compatible
self.verify_graph_integrity()

####################################################### Publich Methods #######################################################

@property
Expand Down Expand Up @@ -180,11 +183,9 @@ def get_dynapcnn_network_module(
- The DynapcnnNetworkModule based on graph representation of this `GraphExtractor`
"""
# Make sure all nodes are supported
# Make sure all nodes are supported and there are no isolated nodes.
self.verify_node_types()

# Verify that graph is compatible
self.verify_graph_integrity()
self.verify_no_isolated_nodes()

# create a dict holding the data necessary to instantiate a `DynapcnnLayer`.
self.dcnnl_map, self.dvs_layer_info = collect_dynapcnn_layer_info(
Expand Down Expand Up @@ -281,7 +282,6 @@ def verify_graph_integrity(self):
Check that:
- Only nodes of specific classes have multiple sources or targets.
- There are no disconnected nodes except for `DVSLayer` instances.
Raises
------
Expand Down Expand Up @@ -379,6 +379,26 @@ def verify_node_types(self):
f"{pformat(SupportedNodeTypes)}."
)

def verify_no_isolated_nodes(self):
"""Verify that there are no disconnected nodes except for `DVSLayer` instances.
Raises
------
- InvalidGraphStructure when disconnected nodes are detected
"""
for node, module in self.indx_2_module_map.items():
# Make sure there are no individual, unconnected nodes
edges_with_node = {e for e in self.edges if node in e}
if not edges_with_node and not isinstance(module, DVSLayer):
raise InvalidGraphStructure(
f"There is an isolated module of type {type(module)}. Only "
"`DVSLayer` instances can be completely disconnected from "
"any other module. Other than that, layers for DynapCNN "
"consist of groups of weight layers (`Linear` or `Conv2d`), "
"spiking layers (`IAF` or `IAFSqueeze`), and optioanlly "
"pooling layers (`SumPool2d`, `AvgPool2d`)."
)

####################################################### Pivate Methods #######################################################

def _handle_dvs_input(
Expand Down
4 changes: 1 addition & 3 deletions tests/test_dynapcnn/test_doorbell.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,8 @@ def test_same_result():


def test_auto_config():
# - Should give an error with the normal layer ordering
dynapcnn_net = DynapcnnNetwork(snn, input_shape=input_shape, discretize=True)
with pytest.raises(ValueError):
dynapcnn_net.make_config(chip_layers_ordering=[0, 1, 2, 3, 4])
dynapcnn_net.make_config(chip_layers_ordering=[0, 1, 2, 3, 4])
dynapcnn_net.make_config(layer2core_map="auto")


Expand Down
8 changes: 4 additions & 4 deletions tests/test_dynapcnnnetwork/test_failcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def test_missing_spiking_layer():
in_shape = (2, 28, 28)
snn = nn.Sequential(
nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False),
sl.IAF(),
sl.IAFSqueeze(batch_size=1),
sl.SumPool2d(2),
nn.AvgPool2d(2),
nn.Conv2d(8, 16, kernel_size=3, stride=1, bias=False),
sl.IAF(),
sl.IAFSqueeze(batch_size=1),
nn.Dropout2d(),
nn.Conv2d(16, 2, kernel_size=3, stride=1, bias=False),
sl.IAF(),
sl.IAFSqueeze(batch_size=1),
nn.Flatten(),
nn.Linear(8, 5),
)
Expand All @@ -66,7 +66,7 @@ def test_missing_spiking_layer():
def test_incorrect_model_start():
in_shape = (2, 28, 28)
snn = nn.Sequential(
sl.IAF(),
sl.IAFSqueeze(batch_size=1),
sl.SumPool2d(2),
nn.AvgPool2d(2),
)
Expand Down

0 comments on commit 9460edc

Please sign in to comment.