@@ -254,9 +258,10 @@ def __getitem__(self, idx):
image = self.transform(image)
torch.manual_seed(seed)
mask = self.transform(mask)
-
+
# use the compute_sdt function to get the sdt
sdt = ...
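+        # sanity check: the SDT target should have the same spatial shape as the mask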
+ assert sdt.shape == mask.shape
if self.img_transform is not None:
image = self.img_transform(image)
if self.return_mask is True:
@@ -264,6 +269,7 @@ def __getitem__(self, idx):
else:
return image, sdt.unsqueeze(0)
+
# %% tags=["solution"]
class SDTDataset(Dataset):
"""A PyTorch dataset to load cell images and nuclei masks."""
@@ -314,6 +320,7 @@ def __getitem__(self, idx):
torch.manual_seed(seed)
mask = self.transform(mask)
sdt = self.create_sdt_target(mask)
+ assert sdt.shape == mask.shape
if self.img_transform is not None:
image = self.img_transform(image)
if self.return_mask is True:
@@ -341,7 +348,7 @@ def create_sdt_target(self, mask):
idx = np.random.randint(len(train_data)) # take a random sample
img, sdt = train_data[idx] # get the image and the nuclei masks
print(img.shape, sdt.shape)
-plot_two(img[1], sdt[0], label="SDT")
+plot_two(img, sdt[0], label="SDT")
# %% [markdown]
#
@@ -379,7 +386,7 @@ def create_sdt_target(self, mask):
# %% tags=["solution"]
unet = UNet(
- depth=2,
+ depth=3,
in_channels=2,
out_channels=1,
final_activation=torch.nn.Tanh(),
@@ -418,7 +425,7 @@ def create_sdt_target(self, mask):
image = np.squeeze(image.cpu())
sdt = np.squeeze(sdt.cpu().numpy())
pred = np.squeeze(pred.cpu().detach().numpy())
-plot_three(image[1], sdt, pred)
+plot_three(image, sdt, pred)
# %% [markdown]
@@ -449,13 +456,13 @@ def create_sdt_target(self, mask):
def find_local_maxima(distance_transform, min_dist_between_points):
-
# Hint: Use `maximum_filter` to perform a maximum filter convolution on the distance_transform
seeds, number_of_seeds = ...
return seeds, number_of_seeds
+
# %% tags=["solution"]
from scipy.ndimage import label, maximum_filter
@@ -656,14 +663,11 @@ def get_inner_mask(pred, threshold):
pred = np.squeeze(pred.cpu().detach().numpy())
# feel free to try different thresholds
- thresh = threshold_otsu(pred)
+ thresh = ...
# get boundary mask
- inner_mask = get_inner_mask(pred, threshold=thresh)
-
- pred_labels = watershed_from_boundary_distance(
- pred, inner_mask, id_offset=0, min_seed_distance=20
- )
+ inner_mask = ...
+ pred_labels = ...
precision, recall, accuracy = evaluate(gt_labels, pred_labels)
precision_list.append(precision)
recall_list.append(recall)
@@ -701,11 +705,14 @@ def get_inner_mask(pred, threshold):
pred = np.squeeze(pred.cpu().detach().numpy())
# feel free to try different thresholds
- thresh = ...
+ thresh = threshold_otsu(pred)
# get boundary mask
- inner_mask = ...
- pred_labels = ...
+ inner_mask = get_inner_mask(pred, threshold=thresh)
+
+ pred_labels = watershed_from_boundary_distance(
+ pred, inner_mask, id_offset=0, min_seed_distance=20
+ )
precision, recall, accuracy = evaluate(gt_labels, pred_labels)
precision_list.append(precision)
recall_list.append(recall)
@@ -715,6 +722,7 @@ def get_inner_mask(pred, threshold):
print(f"Mean Recall is {np.mean(recall_list):.3f}")
print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}")
+
# %% [markdown]
#
#
@@ -728,7 +736,7 @@ def get_inner_mask(pred, threshold):
# Here, we show the (affinity in x + affinity in y) in the bottom right image.
# %% [markdown]
-# ![image](static/05_instance_affinity.png)
+# ![image](static/figure3/instance_affinity.png)
# %% [markdown]
# Similar to the pipeline used for SDTs, we first need to modify the dataset to produce affinities.
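+
+# %% [markdown]
+# As a brief aside, the affinity between a pixel and its neighbor at a given offset is 1 when both
+# pixels carry the same non-zero label and 0 otherwise. A minimal numpy sketch for a toy label
+# array (assuming the offsets [0, 1] for x and [1, 0] for y, with 0 treated as background):
+
+# %%
+toy_labels = np.array([[1, 1, 2], [1, 1, 2], [3, 3, 3]])
+aff_x = (toy_labels[:, :-1] == toy_labels[:, 1:]) & (toy_labels[:, 1:] > 0)
+aff_y = (toy_labels[:-1, :] == toy_labels[1:, :]) & (toy_labels[1:, :] > 0)
+print(aff_x.astype(int))  # zeros mark boundaries between neighboring columns
+print(aff_y.astype(int))  # zeros mark boundaries between neighboring rows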
@@ -741,7 +749,15 @@ def get_inner_mask(pred, threshold):
class AffinityDataset(Dataset):
"""A PyTorch dataset to load cell images and nuclei masks"""
- def __init__(self, root_dir, transform=None, img_transform=None, return_mask=False):
+ def __init__(
+ self,
+ root_dir,
+ transform=None,
+ img_transform=None,
+ return_mask=False,
+ weights: bool = False,
+ ):
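+        # when True, __getitem__ also returns a per-pixel class-balancing weight map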
+ self.weights = weights
self.root_dir = root_dir # the directory with all the training samples
self.num_samples = len(os.listdir(self.root_dir)) // 3 # list the samples
self.return_mask = return_mask
@@ -788,13 +804,35 @@ def __getitem__(self, idx):
aff_mask = self.create_aff_target(mask)
if self.img_transform is not None:
image = self.img_transform(image)
- if self.return_mask is True:
- return image, mask, aff_mask
+
+ if self.weights:
+ weight = torch.zeros_like(aff_mask)
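+            # inverse class-frequency weights per channel, clipped to [0.1, 10], so that the rarer
+            # affinity value contributes roughly as much to the loss as the more common one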
+ for channel in range(weight.shape[0]):
+ weight[channel][aff_mask[channel] == 0] = np.clip(
+ weight[channel].numel()
+ / 2
+                    / (weight[channel].numel() - aff_mask[channel].sum()),
+ 0.1,
+ 10.0,
+ )
+ weight[channel][aff_mask[channel] == 1] = np.clip(
+                    weight[channel].numel() / 2 / aff_mask[channel].sum(), 0.1, 10.0
+ )
+
+ if self.return_mask is True:
+ return image, mask, aff_mask, weight
+ else:
+ return image, aff_mask, weight
else:
- return image, aff_mask
+ if self.return_mask is True:
+ return image, mask, aff_mask
+ else:
+ return image, aff_mask
def create_aff_target(self, mask):
- aff_target_array = compute_affinities(np.asarray(mask), [[0, 1], [1, 0]])
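+        # direct-neighbor affinities ([0, 1], [1, 0]) plus long-range 5-pixel affinities
+        # ([0, 5], [5, 0]) for the mutex watershed below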
+ aff_target_array = compute_affinities(
+ np.asarray(mask), [[0, 1], [1, 0], [0, 5], [5, 0]]
+ )
aff_target = torch.from_numpy(aff_target_array)
return aff_target.float()
@@ -804,13 +842,14 @@ def create_aff_target(self, mask):
# %%
# Initialize the datasets
-train_data = AffinityDataset("tissuenet_data/train", v2.RandomCrop(256))
+train_data = AffinityDataset("tissuenet_data/train", v2.RandomCrop(256), weights=True)
train_loader = DataLoader(
train_data, batch_size=5, shuffle=True, num_workers=NUM_THREADS
)
idx = np.random.randint(len(train_data)) # take a random sample
-img, affinity = train_data[idx] # get the image and the nuclei masks
-plot_two(img[1], affinity[0+2] + affinity[1+2], label="AFFINITY")
+img, affinity, weight = train_data[idx] # get the image and the nuclei masks
+plot_two(img, affinity, label="AFFINITY")
+
# %% [markdown]
#
@@ -833,11 +872,11 @@ def create_aff_target(self, mask):
# %% tags=["solution"]
unet = UNet(
- depth=2,
+ depth=4,
in_channels=2,
- out_channels=2,
+ out_channels=4,
final_activation=torch.nn.Sigmoid(),
- num_fmaps=4,
+ num_fmaps=16,
fmap_inc_factor=3,
downsample_factor=2,
padding="same",
@@ -846,11 +885,14 @@ def create_aff_target(self, mask):
learning_rate = 1e-4
# choose a loss function
-loss = torch.nn.MSELoss()
+loss = torch.nn.MSELoss(reduction="none")
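+# reduction="none" keeps the per-pixel loss so it can be multiplied by the class-balancing
+# weights returned by the dataset before averaging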
optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate)
-plot_three(image[1], mask[0] + mask[1], pred[0 + 2] + pred[1 + 2], label="Affinity")
+val_data = AffinityDataset("tissuenet_data/test", v2.RandomCrop(256))
+val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=NUM_THREADS)
+
+# %%
for epoch in range(NUM_EPOCHS):
train(
unet,
@@ -871,20 +913,35 @@ def create_aff_target(self, mask):
unet.eval()
idx = np.random.randint(len(val_data)) # take a random sample
-image, mask = val_data[idx] # get the image and the nuclei masks
+image, affs = val_data[idx] # get the image and the nuclei masks
image = image.to(device)
pred = torch.squeeze(unet(torch.unsqueeze(image, dim=0)))
-
image = image.cpu()
-mask = mask.cpu().numpy()
+affs = affs.cpu().numpy()
pred = pred.cpu().detach().numpy()
-plot_three(image[1], mask[0] + mask[1], pred[0] + pred[1], label="Affinity")
+bias_short = -0.9
+bias_long = -0.95
+
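+# shifting the predicted affinities by a negative bias leaves only confident edges positive
+# (attractive) for the agglomeration, while everything below the bias magnitude becomes
+# negative (repulsive); the long-range offsets get the stricter bias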
+pred_labels = mws.agglom(
+ np.array(
+ [
+ pred[0] + bias_short,
+ pred[1] + bias_short,
+ pred[2] + bias_long,
+ pred[3] + bias_long,
+ ]
+ ).astype(np.float64),
+ [[0, 1], [1, 0], [0, 5], [5, 0]],
+)
+
+plot_four(image, affs, pred, pred_labels, label="Affinity")
# %% [markdown]
# Let's also evaluate the model performance.
# %%
+
val_dataset = AffinityDataset("tissuenet_data/test", return_mask=True)
val_loader = DataLoader(
val_dataset, batch_size=1, shuffle=False, num_workers=NUM_THREADS
@@ -911,17 +968,29 @@ def create_aff_target(self, mask):
pred = np.squeeze(pred.cpu().detach().numpy())
- # feel free to try different thresholds
- thresh = threshold_otsu(pred)
+ # # feel free to try different thresholds
+ # thresh = threshold_otsu(pred)
- # get boundary mask
- inner_mask = 0.5 * (pred[0] + pred[1]) > thresh
+ # # get boundary mask
+ # inner_mask = 0.5 * (pred[0] + pred[1]) > thresh
- boundary_distances = distance_transform_edt(inner_mask)
+ # boundary_distances = distance_transform_edt(inner_mask)
- pred_labels = watershed_from_boundary_distance(
- boundary_distances, inner_mask, id_offset=0, min_seed_distance=20
+ # pred_labels = watershed_from_boundary_distance(
+ # boundary_distances, inner_mask, id_offset=0, min_seed_distance=20
+ # )
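+    # the offsets passed to mws.agglom must match the neighborhood used in create_aff_target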
+ pred_labels = mws.agglom(
+ np.array(
+ [
+            pred[0] + bias_short,
+            pred[1] + bias_short,
+            pred[2] + bias_long,
+            pred[3] + bias_long,
+ ]
+ ).astype(np.float64),
+ [[0, 1], [1, 0], [0, 5], [5, 0]],
)
+
precision, recall, accuracy = evaluate(gt_labels, pred_labels)
precision_list.append(precision)
recall_list.append(recall)