diff --git a/solution.py b/solution.py index d902ea0..b407a02 100644 --- a/solution.py +++ b/solution.py @@ -55,6 +55,7 @@ from dlmbl_unet import UNet from tqdm import tqdm import tifffile +import mwatershed as mws from skimage.filters import threshold_otsu @@ -71,7 +72,7 @@ # %% # Create a custom label color map for showing instances np.random.seed(1) -colors = [[0,0,0]] + [list(np.random.choice(range(256), size=3)) for _ in range(254)] +colors = [[0, 0, 0]] + [list(np.random.choice(range(256), size=3)) for _ in range(254)] label_cmap = ListedColormap(colors) # %% [markdown] @@ -87,11 +88,12 @@ #
- As an example, here, you see the SDT (right) of the target mask (middle), below. # %% [markdown] -# ![image](static/04_instance_sdt.png) +# ![image](static/figure2/04_instance_sdt.png) # # %% + def compute_sdt(labels: np.ndarray, scale: int = 5): """Function to compute a signed distance transform.""" dims = len(labels.shape) @@ -132,6 +134,7 @@ def compute_sdt(labels: np.ndarray, scale: int = 5): distances[labels == 0] *= -1 return distances + # %% [markdown] #
# Task 1.1: Explain the `compute_sdt` from the cell above. @@ -166,6 +169,7 @@ def compute_sdt(labels: np.ndarray, scale: int = 5): #
Note that the output of the signed distance transform is not binary, a significant difference from semantic segmentation # %% # Visualize the signed distance transform using the function you wrote above. + root_dir = "tissuenet_data/train" # the directory with all the training samples samples = os.listdir(root_dir) idx = np.random.randint(len(samples) // 3) # take a random sample. @@ -174,7 +178,7 @@ def compute_sdt(labels: np.ndarray, scale: int = 5): os.path.join(root_dir, f"img_{idx}_cyto_masks.tif") ) # get the image sdt = compute_sdt(label) -plot_two(img[1], sdt, label="SDT") +plot_two(img, sdt, label="SDT") # %% [markdown] #
@@ -254,9 +258,10 @@ def __getitem__(self, idx): image = self.transform(image) torch.manual_seed(seed) mask = self.transform(mask) - + # use the compute_sdt function to get the sdt sdt = ... + assert sdt.shape == mask.shape if self.img_transform is not None: image = self.img_transform(image) if self.return_mask is True: @@ -264,6 +269,7 @@ def __getitem__(self, idx): else: return image, sdt.unsqueeze(0) + # %% tags=["solution"] class SDTDataset(Dataset): """A PyTorch dataset to load cell images and nuclei masks.""" @@ -314,6 +320,7 @@ def __getitem__(self, idx): torch.manual_seed(seed) mask = self.transform(mask) sdt = self.create_sdt_target(mask) + assert sdt.shape == mask.shape if self.img_transform is not None: image = self.img_transform(image) if self.return_mask is True: @@ -341,7 +348,7 @@ def create_sdt_target(self, mask): idx = np.random.randint(len(train_data)) # take a random sample img, sdt = train_data[idx] # get the image and the nuclei masks print(img.shape, sdt.shape) -plot_two(img[1], sdt[0], label="SDT") +plot_two(img, sdt[0], label="SDT") # %% [markdown] #
@@ -379,7 +386,7 @@ def create_sdt_target(self, mask): # %% tags=["solution"] unet = UNet( - depth=2, + depth=3, in_channels=2, out_channels=1, final_activation=torch.nn.Tanh(), @@ -418,7 +425,7 @@ def create_sdt_target(self, mask): image = np.squeeze(image.cpu()) sdt = np.squeeze(sdt.cpu().numpy()) pred = np.squeeze(pred.cpu().detach().numpy()) -plot_three(image[1], sdt, pred) +plot_three(image, sdt, pred) # %% [markdown] @@ -449,13 +456,13 @@ def create_sdt_target(self, mask): def find_local_maxima(distance_transform, min_dist_between_points): - # Hint: Use `maximum_filter` to perform a maximum filter convolution on the distance_transform seeds, number_of_seeds = ... return seeds, number_of_seeds + # %% tags=["solution"] from scipy.ndimage import label, maximum_filter @@ -656,14 +663,11 @@ def get_inner_mask(pred, threshold): pred = np.squeeze(pred.cpu().detach().numpy()) # feel free to try different thresholds - thresh = threshold_otsu(pred) + thresh = ... # get boundary mask - inner_mask = get_inner_mask(pred, threshold=thresh) - - pred_labels = watershed_from_boundary_distance( - pred, inner_mask, id_offset=0, min_seed_distance=20 - ) + inner_mask = ... + pred_labels = ... precision, recall, accuracy = evaluate(gt_labels, pred_labels) precision_list.append(precision) recall_list.append(recall) @@ -701,11 +705,14 @@ def get_inner_mask(pred, threshold): pred = np.squeeze(pred.cpu().detach().numpy()) # feel free to try different thresholds - thresh = ... + thresh = threshold_otsu(pred) # get boundary mask - inner_mask = ... - pred_labels = ... + inner_mask = get_inner_mask(pred, threshold=thresh) + + pred_labels = watershed_from_boundary_distance( + pred, inner_mask, id_offset=0, min_seed_distance=20 + ) precision, recall, accuracy = evaluate(gt_labels, pred_labels) precision_list.append(precision) recall_list.append(recall) @@ -715,6 +722,7 @@ def get_inner_mask(pred, threshold): print(f"Mean Recall is {np.mean(recall_list):.3f}") print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}") + # %% [markdown] #
# @@ -728,7 +736,7 @@ def get_inner_mask(pred, threshold): # Here, we show the (affinity in x + affinity in y) in the bottom right image. # %% [markdown] -# ![image](static/05_instance_affinity.png) +# ![image](static/figure3/instance_affinity.png) # %% [markdown] # Similar to the pipeline used for SDTs, we first need to modify the dataset to produce affinities. @@ -741,7 +749,15 @@ def get_inner_mask(pred, threshold): class AffinityDataset(Dataset): """A PyTorch dataset to load cell images and nuclei masks""" - def __init__(self, root_dir, transform=None, img_transform=None, return_mask=False): + def __init__( + self, + root_dir, + transform=None, + img_transform=None, + return_mask=False, + weights: bool = False, + ): + self.weights = weights self.root_dir = root_dir # the directory with all the training samples self.num_samples = len(os.listdir(self.root_dir)) // 3 # list the samples self.return_mask = return_mask @@ -788,13 +804,35 @@ def __getitem__(self, idx): aff_mask = self.create_aff_target(mask) if self.img_transform is not None: image = self.img_transform(image) - if self.return_mask is True: - return image, mask, aff_mask + + if self.weights: + weight = torch.zeros_like(aff_mask) + for channel in range(weight.shape[0]): + weight[channel][aff_mask[channel] == 0] = np.clip( + weight[channel].numel() + / 2 + / (weight[channel].numel() - weight[channel].sum()), + 0.1, + 10.0, + ) + weight[channel][aff_mask[channel] == 1] = np.clip( + weight[channel].numel() / 2 / weight[channel].sum(), 0.1, 10.0 + ) + + if self.return_mask is True: + return image, mask, aff_mask, weight + else: + return image, aff_mask, weight else: - return image, aff_mask + if self.return_mask is True: + return image, mask, aff_mask + else: + return image, aff_mask def create_aff_target(self, mask): - aff_target_array = compute_affinities(np.asarray(mask), [[0, 1], [1, 0]]) + aff_target_array = compute_affinities( + np.asarray(mask), [[0, 1], [1, 0], [0, 5], [5, 0]] + ) aff_target = torch.from_numpy(aff_target_array) return aff_target.float() @@ -804,13 +842,14 @@ def create_aff_target(self, mask): # %% # Initialize the datasets -train_data = AffinityDataset("tissuenet_data/train", v2.RandomCrop(256)) +train_data = AffinityDataset("tissuenet_data/train", v2.RandomCrop(256), weights=True) train_loader = DataLoader( train_data, batch_size=5, shuffle=True, num_workers=NUM_THREADS ) idx = np.random.randint(len(train_data)) # take a random sample -img, affinity = train_data[idx] # get the image and the nuclei masks -plot_two(img[1], affinity[0+2] + affinity[1+2], label="AFFINITY") +img, affinity, weight = train_data[idx] # get the image and the nuclei masks +plot_two(img, affinity, label="AFFINITY") + # %% [markdown] #
@@ -833,11 +872,11 @@ def create_aff_target(self, mask): # %% tags=["solution"] unet = UNet( - depth=2, + depth=4, in_channels=2, - out_channels=2, + out_channels=4, final_activation=torch.nn.Sigmoid(), - num_fmaps=4, + num_fmaps=16, fmap_inc_factor=3, downsample_factor=2, padding="same", @@ -846,11 +885,14 @@ def create_aff_target(self, mask): learning_rate = 1e-4 # choose a loss function -loss = torch.nn.MSELoss() +loss = torch.nn.MSELoss(reduce=False) optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate) -plot_three(image[1], mask[0] + mask[1], pred[0 + 2] + pred[1 + 2], label="Affinity") +val_data = AffinityDataset("tissuenet_data/test", v2.RandomCrop(256)) +val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=8) + +# %% for epoch in range(NUM_EPOCHS): train( unet, @@ -871,20 +913,35 @@ def create_aff_target(self, mask): unet.eval() idx = np.random.randint(len(val_data)) # take a random sample -image, mask = val_data[idx] # get the image and the nuclei masks +image, affs = val_data[idx] # get the image and the nuclei masks image = image.to(device) pred = torch.squeeze(unet(torch.unsqueeze(image, dim=0))) - image = image.cpu() -mask = mask.cpu().numpy() +affs = affs.cpu().numpy() pred = pred.cpu().detach().numpy() -plot_three(image[1], mask[0] + mask[1], pred[0] + pred[1], label="Affinity") +bias_short = -0.9 +bias_long = -0.95 + +pred_labels = mws.agglom( + np.array( + [ + pred[0] + bias_short, + pred[1] + bias_short, + pred[2] + bias_long, + pred[3] + bias_long, + ] + ).astype(np.float64), + [[0, 1], [1, 0], [0, 5], [5, 0]], +) + +plot_four(image, affs, pred, pred_labels, label="Affinity") # %% [markdown] # Let's also evaluate the model performance. # %% + val_dataset = AffinityDataset("tissuenet_data/test", return_mask=True) val_loader = DataLoader( val_dataset, batch_size=1, shuffle=False, num_workers=NUM_THREADS @@ -911,17 +968,29 @@ def create_aff_target(self, mask): pred = np.squeeze(pred.cpu().detach().numpy()) - # feel free to try different thresholds - thresh = threshold_otsu(pred) + # # feel free to try different thresholds + # thresh = threshold_otsu(pred) - # get boundary mask - inner_mask = 0.5 * (pred[0] + pred[1]) > thresh + # # get boundary mask + # inner_mask = 0.5 * (pred[0] + pred[1]) > thresh - boundary_distances = distance_transform_edt(inner_mask) + # boundary_distances = distance_transform_edt(inner_mask) - pred_labels = watershed_from_boundary_distance( - boundary_distances, inner_mask, id_offset=0, min_seed_distance=20 + # pred_labels = watershed_from_boundary_distance( + # boundary_distances, inner_mask, id_offset=0, min_seed_distance=20 + # ) + pred_labels = mws.agglom( + np.array( + [ + pred[0] - bias_short, + pred[1] - bias_short, + pred[2] - bias_long, + pred[3] - bias_long, + ] + ).astype(np.float64), + [[0, 1], [1, 0], [0, 5], [5, 0]], ) + precision, recall, accuracy = evaluate(gt_labels, pred_labels) precision_list.append(precision) recall_list.append(recall)