Skip to content

Commit

Permalink
Update failing unit tests in test_large_net
Browse files Browse the repository at this point in the history
  • Loading branch information
bauerfe committed Oct 31, 2024
1 parent 2970d66 commit 3d4793f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
4 changes: 4 additions & 0 deletions sinabs/backend/dynapcnn/dynapcnn_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def dynapcnn_layers(self):
@property
def dynapcnn_module(self):
return self._dynapcnn_module

@property
def exit_layers(self):
return [self.dynapcnn_layers[i] for i in self._dynapcnn_module.get_exit_layers()]

@property
def layer_destination_map(self):
Expand Down
29 changes: 22 additions & 7 deletions tests/test_dynapcnn/test_large_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ def test_same_result():
assert torch.equal(dynapcnn_out.squeeze(), snn_out.squeeze())


# TODO: Define new test with actual network that is too large. Probably have it as fail case in test_dynapcnnnetwork
def test_too_large():
with pytest.raises(ValueError):
# - Should give an error with the normal layer ordering
dynapcnn_net.make_config(chip_layers_ordering=range(9))
pass


def test_auto_config():
Expand All @@ -100,10 +99,25 @@ def test_auto_config():


def test_was_copied():
# - Make sure that layers of different models are distinct objects
for lyr_snn, lyr_dynapcnn in zip(snn.spiking_model, dynapcnn_net.sequence):
assert lyr_snn is not lyr_dynapcnn
from nirtorch.utils import sanitize_name

# - Make sure that layers of different models are distinct objects
snn_layers = {sanitize_name(name): lyr for name, lyr in snn.named_modules()}
idx_2_name_map = {
idx: sanitize_name(name) for name, idx in dynapcnn_net.name_2_indx_map.items()
}
for idx, lyr_info in dynapcnn_net._graph_extractor.dcnnl_map.items():
conv_lyr_dynapcnn = dynapcnn_net.dynapcnn_layers[idx].conv_layer
conv_node_idx = lyr_info["conv"]["node_id"]
conv_name = idx_2_name_map[conv_node_idx]
conv_lyr_snn = snn_layers[conv_name]
assert conv_lyr_dynapcnn is not conv_lyr_snn

spk_lyr_dynapcnn = dynapcnn_net.dynapcnn_layers[idx].spk_layer
spk_node_idx = lyr_info["neuron"]["node_id"]
spk_name = idx_2_name_map[spk_node_idx]
spk_lyr_snn = snn_layers[spk_name]
assert spk_lyr_dynapcnn is not spk_lyr_snn

def test_make_config():
dynapcnn_net = DynapcnnNetwork(
Expand Down Expand Up @@ -162,6 +176,7 @@ def test_extended_readout_layer(out_channels: int):
)
extended_net = extend_readout_layer(dynapcnn_net)

converted_channels = extended_net.sequence[-1].conv_layer.out_channels
assert len(exit_layers := extended_net.exit_layers) == 1
converted_channels = exit_layers[0].conv_layer.out_channels

assert (out_channels - 1) * 4 + 1 == converted_channels

0 comments on commit 3d4793f

Please sign in to comment.