Skip to content

Commit

Permalink
docs: 🍱 switch out for homemade equivariance figure
Browse files Browse the repository at this point in the history
  • Loading branch information
neptunes5thmoon committed Jul 25, 2024
1 parent 43f4ce1 commit 23a89e2
Show file tree
Hide file tree
Showing 3 changed files with 992 additions and 41 deletions.
63 changes: 22 additions & 41 deletions solution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# %% tags=[]
# %% tags=["solution", "task"]
# ruff: noqa: F811
# %% [markdown] tags=[]
# # Build Your Own U-Net
Expand Down Expand Up @@ -34,13 +34,13 @@
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import unet_tests
from local import (
NucleiDataset,
show_random_dataset_image,
apply_and_show_random_image,
plot_receptive_field,
show_random_dataset_image,
)
import unet_tests

# %% tags=[]
# Make sure a GPU is available. Please call a TA if this cell fails.
Expand Down Expand Up @@ -195,10 +195,7 @@ def check_valid(self, image_size: tuple[int, int]) -> bool:

def forward(self, x):
if not self.check_valid(tuple(x.size()[-2:])):
raise RuntimeError(
"Can not downsample shape %s with factor %s"
% (x.size(), self.downsample_factor)
)
raise RuntimeError("Can not downsample shape %s with factor %s" % (x.size(), self.downsample_factor))

return self.down(x)

Expand Down Expand Up @@ -227,10 +224,7 @@ def check_valid(self, image_size: tuple[int, int]) -> bool:

def forward(self, x):
if not self.check_valid(tuple(x.size()[-2:])):
raise RuntimeError(
"Can not downsample shape %s with factor %s"
% (x.size(), self.downsample_factor)
)
raise RuntimeError("Can not downsample shape %s with factor %s" % (x.size(), self.downsample_factor))

return self.down(x)

Expand Down Expand Up @@ -352,13 +346,9 @@ def __init__(

# SOLUTION 3.1: Initialize your modules and define layers in conv pass
self.conv_pass = torch.nn.Sequential(
torch.nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, padding=padding
),
torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
torch.nn.ReLU(),
torch.nn.Conv2d(
out_channels, out_channels, kernel_size=kernel_size, padding=padding
),
torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
torch.nn.ReLU(),
)

Expand Down Expand Up @@ -414,7 +404,7 @@ def forward(self, x):


# %% tags=["task"]
def center_crop(x,y):
def center_crop(x, y):
"""Center-crop x to match spatial dimensions given by y."""

x_target_size = x.size()[:2] + y.size()[2:]
Expand All @@ -425,14 +415,15 @@ def center_crop(x,y):

return x[slices]


# Exercise placeholder: `forward` is intentionally left unimplemented (its body is
# just `...`) so that it can be filled in as TASK 4.
# NOTE(review): judging by the class and argument names, the finished version
# presumably crops `encoder_output` to the spatial size of `upsample_output` and
# concatenates the two -- confirm against the solution cell.
class CropAndConcat(torch.nn.Module):
def forward(self, encoder_output, upsample_output):
# TASK 4: Implement the forward function
...


# %% tags=["solution"]
def center_crop(x,y):
def center_crop(x, y):
"""Center-crop x to match spatial dimensions given by y."""

x_target_size = x.size()[:2] + y.size()[2:]
Expand All @@ -443,6 +434,7 @@ def center_crop(x,y):

return x[slices]


class CropAndConcat(torch.nn.Module):
def forward(self, encoder_output, upsample_output):
# SOLUTION 4: Implement the forward function
Expand Down Expand Up @@ -763,9 +755,7 @@ def __init__(
# SOLUTION 6.2A: Initialize list here
for level in range(self.depth):
fmaps_in, fmaps_out = self.compute_fmaps_encoder(level)
self.left_convs.append(
ConvBlock(fmaps_in, fmaps_out, self.kernel_size, self.padding)
)
self.left_convs.append(ConvBlock(fmaps_in, fmaps_out, self.kernel_size, self.padding))

# right convolutional passes
self.right_convs = torch.nn.ModuleList()
Expand All @@ -788,9 +778,7 @@ def __init__(
mode=self.upsample_mode,
)
self.crop_and_concat = CropAndConcat()
self.final_conv = OutputConv(
self.compute_fmaps_decoder(0)[1], self.out_channels, self.final_activation
)
self.final_conv = OutputConv(self.compute_fmaps_decoder(0)[1], self.out_channels, self.final_activation)

def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
"""Compute the number of input and output feature maps for
Expand Down Expand Up @@ -829,9 +817,7 @@ def compute_fmaps_decoder(self, level: int) -> tuple[int, int]:
"""
# SOLUTION 6.1B: Implement this function
fmaps_out = self.num_fmaps * self.fmap_inc_factor ** (level)
concat_fmaps = self.compute_fmaps_encoder(level)[
1
] # The channels that come from the skip connection
concat_fmaps = self.compute_fmaps_encoder(level)[1] # The channels that come from the skip connection
fmaps_in = concat_fmaps + self.num_fmaps * self.fmap_inc_factor ** (level + 1)

return fmaps_in, fmaps_out
Expand Down Expand Up @@ -970,7 +956,10 @@ def forward(self, x):
# If math isn't your thing, hopefully this picture helps convey the concept, shown here specifically for translations.

# %% [markdown]
# ![image](static/equivariance.png)
# ![image](static/equivariance.svg)

# %% [markdown]
# <img src="static/equivariance.svg" alt="Invariance and Equivariance" style="width: 200px;"/>

# %% [markdown]
# <div class="alert alert-warning">
Expand Down Expand Up @@ -1086,17 +1075,11 @@ def train(
# log to tensorboard
if tb_logger is not None:
step = epoch * len(loader) + batch_id
tb_logger.add_scalar(
tag="train_loss", scalar_value=loss.item(), global_step=step
)
tb_logger.add_scalar(tag="train_loss", scalar_value=loss.item(), global_step=step)
# check if we log images in this iteration
if step % log_image_interval == 0:
tb_logger.add_images(
tag="input", img_tensor=x.to("cpu"), global_step=step
)
tb_logger.add_images(
tag="target", img_tensor=y.to("cpu"), global_step=step
)
tb_logger.add_images(tag="input", img_tensor=x.to("cpu"), global_step=step)
tb_logger.add_images(tag="target", img_tensor=y.to("cpu"), global_step=step)
tb_logger.add_images(
tag="prediction",
img_tensor=prediction.to("cpu").detach(),
Expand Down Expand Up @@ -1132,9 +1115,7 @@ def train(
logger = SummaryWriter(f"unet_runs/{model_name}")

# %% tags=["solution"]
model = UNet(
depth=4, in_channels=1
) # SOLUTION 8.1: Declare your U-Net here and name it below
model = UNet(depth=4, in_channels=1) # SOLUTION 8.1: Declare your U-Net here and name it below
model_name = "my_fav_unet" # This name will be used in the tensorboard logs
logger = SummaryWriter(f"unet_runs/{model_name}")

Expand Down
Binary file removed static/equivariance.png
Binary file not shown.
970 changes: 970 additions & 0 deletions static/equivariance.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 23a89e2

Please sign in to comment.