Skip to content

Commit

Permalink
style: 🎨 run black
Browse files Browse the repository at this point in the history
  • Loading branch information
neptunes5thmoon committed Aug 14, 2024
1 parent 15cbfc7 commit 031ea98
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 58 deletions.
92 changes: 40 additions & 52 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def __init__(
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = "same"
padding: str = "same",
):
"""A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU.
Expand Down Expand Up @@ -1179,12 +1179,12 @@ def train(
#
# Congratulations! You trained your first UNet that you implemented all by yourself!
#
# We will keep using this U-Net throughout the rest of the exercises. Whenever you see an import like `import dlmbl-unet` or
# We will keep using this U-Net throughout the rest of the exercises. Whenever you see an import like `import dlmbl-unet` or
# `from dlmbl-unet import UNet` it will be importing from [this repository](https://github.com/dlmbl/dlmbl-unet) which contains the solution to this notebook as a package (including the bonus exercises so don't peak just yet if you wanna solve the bonus too).
# </div>

# %% [markdown] tags=[]
# ## Bonus 1: 3D UNet
# ## Bonus 1: 3D UNet
# The UNet you implemented so far only works for 2D images, but in microscopy we often have 3D data that also needs to be processed as such, i.e. for some tasks it is important that the network's receptive field is 3D. So in this bonus exercise we will change our implementation to make the number of dimensions configurable.
#

Expand All @@ -1194,6 +1194,7 @@ def train(
# To make the same class usable for 2D and 3D data we will add an argument `ndim` to each building block.
# </div>


# %% tags=["task"]
class Downsample(torch.nn.Module):
def __init__(self, downsample_factor: int, ndim: int = 2):
Expand All @@ -1208,7 +1209,6 @@ def __init__(self, downsample_factor: int, ndim: int = 2):
# TASK 10A: Initialize the maxpool module
# Define what the downop should be based on `ndim`.
self.down = ... # YOUR CODE HERE


def check_valid(self, image_size: tuple[int, ...]) -> bool:
"""Check if the downsample factor evenly divides each image dimension.
Expand Down Expand Up @@ -1250,7 +1250,7 @@ def __init__(
NxN square kernel.
padding (str): The type of convolution padding to use. Either "same" or "valid".
Defaults to "same".
ndim (int): Number of dimensions for the convolution operation. Use 2 for 2D
ndim (int): Number of dimensions for the convolution operation. Use 2 for 2D
convolutions and 3 for 3D convolutions. Defaults to 2.
"""
super().__init__()
Expand All @@ -1261,7 +1261,7 @@ def __init__(
msg = "Only allowing odd kernel sizes."
raise ValueError(msg)

# TASK 10C: Initialize your modules and define layers.
# TASK 10C: Initialize your modules and define layers.
# Use the convolution module matching `ndim`.
# YOUR CODE HERE

Expand Down Expand Up @@ -1298,10 +1298,10 @@ def __init__(
convolutions and 3 for 3D convolutions. Defaults to 2.
"""
super().__init__()
if ndim not in (2,3):
if ndim not in (2, 3):
msg = f"Invalid number of dimensions: {ndim=}. Options are 2 or 3."
raise ValueError(msg)

# TASK 10E: Define the convolution submodule.
# Use the convolution module matching `ndim`.
# YOUR CODE HERE
Expand All @@ -1328,12 +1328,12 @@ def __init__(
kernel_size: int = 3,
padding: str = "same",
upsample_mode: str = "nearest",
ndim: int =2,
ndim: int = 2,
):
"""A U-Net for 2D or 3D input that expects tensors shaped like:
``(batch, channels, height, width)`` or ``(batch, channels, depth, height, width)``,
``(batch, channels, height, width)`` or ``(batch, channels, depth, height, width)``,
respectively.
Args:
depth:
The number of levels in the U-Net. 2 is the smallest that really
Expand Down Expand Up @@ -1385,17 +1385,17 @@ def __init__(
# left convolutional passes
self.left_convs = torch.nn.ModuleList()
# TASK 10G: Initialize list here
# After you implemented the conv pass you can copy this from TASK 6.2A,
# After you implemented the conv pass you can copy this from TASK 6.2A,
# but make sure to pass the ndim argument

# right convolutional passes
self.right_convs = torch.nn.ModuleList()
# TASK 10H: Initialize list here
# After you implemented the conv pass you can copy this from TASK 6.2B,
# After you implemented the conv pass you can copy this from TASK 6.2B,
# but make sure to pass the ndim argument

# TASK 10I: Initialize other modules here
# Same here, copy over from TASK 6.3, but make sure to add the ndim argument
# Same here, copy over from TASK 6.3, but make sure to add the ndim argument
# as needed.

def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
Expand Down Expand Up @@ -1467,12 +1467,8 @@ def __init__(self, downsample_factor: int, ndim: int = 2):
self.downsample_factor = downsample_factor
# SOLUTION 10A: Initialize the maxpool module
# Define what the downop should be based on `ndim`.
downops = {
2: torch.nn.MaxPool2d,
3: torch.nn.MaxPool3d
}
downops = {2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}
self.down = downops[ndim](downsample_factor)


def check_valid(self, image_size: tuple[int, ...]) -> bool:
"""Check if the downsample factor evenly divides each image dimension.
Expand Down Expand Up @@ -1518,7 +1514,7 @@ def __init__(
NxN square kernel.
padding (str): The type of convolution padding to use. Either "same" or "valid".
Defaults to "same".
ndim (int): Number of dimensions for the convolution operation. Use 2 for 2D
ndim (int): Number of dimensions for the convolution operation. Use 2 for 2D
convolutions and 3 for 3D convolutions. Defaults to 2.
"""
super().__init__()
Expand All @@ -1529,13 +1525,10 @@ def __init__(
msg = "Only allowing odd kernel sizes."
raise ValueError(msg)

# SOLUTION 10C: Initialize your modules and define layers.
# SOLUTION 10C: Initialize your modules and define layers.
# Use the convolution module matching `ndim`.
# YOUR CODE HERE
convops = {
2: torch.nn.Conv2d,
3: torch.nn.Conv3d
}
convops = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
self.conv_pass = torch.nn.Sequential(
convops[ndim](
in_channels, out_channels, kernel_size=kernel_size, padding=padding
Expand Down Expand Up @@ -1580,21 +1573,15 @@ def __init__(
convolutions and 3 for 3D convolutions. Defaults to 2.
"""
super().__init__()
if ndim not in (2,3):
if ndim not in (2, 3):
msg = f"Invalid number of dimensions: {ndim=}. Options are 2 or 3."
raise ValueError(msg)
# SOLUTION 10E: Define the convolution submodule.
# Use the convolution module matching `ndim`.
convops = {
2: torch.nn.Conv2d,
3: torch.nn.Conv3d
}
self.final_conv = convops[ndim](
in_channels, out_channels, 1, padding=0
)

convops = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
self.final_conv = convops[ndim](in_channels, out_channels, 1, padding=0)

self.activation = activation


def forward(self, x):
# SOLUTION 10F: Implement the forward function
Expand All @@ -1618,12 +1605,12 @@ def __init__(
kernel_size: int = 3,
padding: str = "same",
upsample_mode: str = "nearest",
ndim: int =2,
ndim: int = 2,
):
"""A U-Net for 2D or 3D input that expects tensors shaped like:
``(batch, channels, height, width)`` or ``(batch, channels, depth, height, width)``,
``(batch, channels, height, width)`` or ``(batch, channels, depth, height, width)``,
respectively.
Args:
depth:
The number of levels in the U-Net. 2 is the smallest that really
Expand Down Expand Up @@ -1675,32 +1662,30 @@ def __init__(
# left convolutional passes
self.left_convs = torch.nn.ModuleList()
# SOLUTION 10G: Initialize list here
# After you implemented the conv pass you can copy this from TASK 6.2A,
# After you implemented the conv pass you can copy this from TASK 6.2A,
# but make sure to pass the ndim argument
for level in range(self.depth):
fmaps_in, fmaps_out = self.compute_fmaps_encoder(level)
self.left_convs.append(
ConvBlock(fmaps_in, fmaps_out, self.kernel_size, self.padding, ndim=ndim)
ConvBlock(
fmaps_in, fmaps_out, self.kernel_size, self.padding, ndim=ndim
)
)
# right convolutional passes
self.right_convs = torch.nn.ModuleList()
# SOLUTION 10H: Initialize list here
# After you implemented the conv pass you can copy this from TASK 6.2B,
# After you implemented the conv pass you can copy this from TASK 6.2B,
# but make sure to pass the ndim argument
for level in range(self.depth - 1):
fmaps_in, fmaps_out = self.compute_fmaps_decoder(level)
self.right_convs.append(
ConvBlock(
fmaps_in,
fmaps_out,
self.kernel_size,
self.padding,
ndim=ndim
fmaps_in, fmaps_out, self.kernel_size, self.padding, ndim=ndim
)
)

# SOLUTION 10I: Initialize other modules here
# Same here, copy over from TASK 6.3, but make sure to add the ndim argument
# Same here, copy over from TASK 6.3, but make sure to add the ndim argument
# as needed.
self.downsample = Downsample(self.downsample_factor, ndim=ndim)
self.upsample = torch.nn.Upsample(
Expand All @@ -1709,7 +1694,10 @@ def __init__(
)
self.crop_and_concat = CropAndConcat()
self.final_conv = OutputConv(
self.compute_fmaps_decoder(0)[1], self.out_channels, self.final_activation, ndim=ndim
self.compute_fmaps_decoder(0)[1],
self.out_channels,
self.final_activation,
ndim=ndim,
)

def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
Expand Down Expand Up @@ -1785,14 +1773,14 @@ def forward(self, x):
concat = self.crop_and_concat(convolution_outputs[i], upsampled)
conv_output = self.right_convs[i](concat)
layer_input = conv_output

# SOLUTION 10O: Apply the final convolution and return the output
# Copy from TASK 6.4D
return self.final_conv(layer_input)


# %% [markdown] tags=[]
# Run the 3d test for your implementation of the UNet.
# Run the 3d test for your implementation of the UNet.

# %% tags=[]
unet_tests.TestUNet(UNet).run3d()
15 changes: 9 additions & 6 deletions unet_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def run(self):
print("TESTS PASSED")



class TestConvBlock:
def __init__(self, conv_module):
self.conv_module = conv_module
Expand Down Expand Up @@ -158,13 +157,13 @@ def test_shape_valid_3d(self) -> None:
downsample_factor=3,
kernel_size=5,
padding="valid",
ndim=3
ndim=3,
)
msg = "The output shape of your UNet is incorrect for valid padding in 3D."
assert unetvalid(torch.ones((2,2,140,140,140))).shape == torch.Size(
assert unetvalid(torch.ones((2, 2, 140, 140, 140))).shape == torch.Size(
(2, 1, 4, 4, 4)
), msg

def test_shape_same(self) -> None:
unetsame = self.unetmodule(
depth=4,
Expand All @@ -180,6 +179,7 @@ def test_shape_same(self) -> None:
assert unetsame(torch.ones((2, 2, 243, 243))).shape == torch.Size(
(2, 7, 243, 243)
), msg

def test_shape_same_3d(self) -> None:
unetsame = self.unetmodule(
depth=3,
Expand All @@ -190,10 +190,13 @@ def test_shape_same_3d(self) -> None:
downsample_factor=3,
kernel_size=5,
padding="same",
ndim=3
ndim=3,
)
msg = "The output shape of your Unet is incorrect for same padding in 3D."
assert unetsame(torch.ones((2,2,27,27,27))).shape == torch.Size((2,1,27,27,27)), msg
assert unetsame(torch.ones((2, 2, 27, 27, 27))).shape == torch.Size(
(2, 1, 27, 27, 27)
), msg

def run(self):
self.test_fmaps()
self.test_shape_valid()
Expand Down

0 comments on commit 031ea98

Please sign in to comment.