Skip to content

Commit

Permalink
Update extend_readout_layer function to work with new DynapcnnNetwork
Browse files: browse the repository at this point in the history
  • Loading branch information
bauerfe committed Oct 31, 2024
1 parent b2e9214 commit 2970d66
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions sinabs/backend/dynapcnn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,32 +368,31 @@ def extend_readout_layer(model: "DynapcnnNetwork") -> "DynapcnnNetwork":
"""
  model = deepcopy(model)
  input_shape = model.input_shape
- og_readout_conv_layer = model.sequence[
- -1
- ].conv_layer # extract the conv layer from dynapcnn network
- og_weight_data = og_readout_conv_layer.weight.data
- og_bias_data = og_readout_conv_layer.bias
- og_bias = og_bias_data is not None
- # modify the out channels
- og_out_channels = og_readout_conv_layer.out_channels
- new_out_channels = (og_out_channels - 1) * 4 + 1
- og_readout_conv_layer.out_channels = new_out_channels
- # build extended weight and replace the old one
- ext_weight_shape = (new_out_channels, *og_weight_data.shape[1:])
- ext_weight_data = torch.zeros(ext_weight_shape, dtype=og_weight_data.dtype)
- for i in range(og_out_channels):
- ext_weight_data[i * 4] = og_weight_data[i]
- og_readout_conv_layer.weight.data = ext_weight_data
- # build extended bias and replace if necessary
- if og_bias:
- ext_bias_shape = (new_out_channels,)
- ext_bias_data = torch.zeros(ext_bias_shape, dtype=og_bias_data.dtype)
+ for exit_layer in model.exit_layers:
+ # extract the conv layer from dynapcnn network
+ og_readout_conv_layer = exit_layer.conv_layer
+ og_weight_data = og_readout_conv_layer.weight.data
+ og_bias_data = og_readout_conv_layer.bias
+ og_bias = og_bias_data is not None
+ # modify the out channels
+ og_out_channels = og_readout_conv_layer.out_channels
+ new_out_channels = (og_out_channels - 1) * 4 + 1
+ og_readout_conv_layer.out_channels = new_out_channels
+ # build extended weight and replace the old one
+ ext_weight_shape = (new_out_channels, *og_weight_data.shape[1:])
+ ext_weight_data = torch.zeros(ext_weight_shape, dtype=og_weight_data.dtype)
  for i in range(og_out_channels):
- ext_bias_data[i * 4] = og_bias_data[i]
- og_readout_conv_layer.bias.data = ext_bias_data
- _ = model(
- torch.zeros(size=(1, *input_shape))
- ) # run a forward pass to initialize the new weights and last IAF
+ ext_weight_data[i * 4] = og_weight_data[i]
+ og_readout_conv_layer.weight.data = ext_weight_data
+ # build extended bias and replace if necessary
+ if og_bias:
+ ext_bias_shape = (new_out_channels,)
+ ext_bias_data = torch.zeros(ext_bias_shape, dtype=og_bias_data.dtype)
+ for i in range(og_out_channels):
+ ext_bias_data[i * 4] = og_bias_data[i]
+ og_readout_conv_layer.bias.data = ext_bias_data
+ # run a forward pass to initialize the new weights and last IAF
+ model(torch.zeros(size=(1, *input_shape)))
  return model


Expand Down

0 comments on commit 2970d66

Please sign in to comment.