Update Task 1 with documentation links and tests
cmalinmayor committed Jul 30, 2024
1 parent 2260d96 commit f9c05fe
Showing 2 changed files with 50 additions and 32 deletions.
1 change: 1 addition & 0 deletions setup.sh
@@ -25,6 +25,7 @@ pip install plotly
pip install matplotlib
pip install ipywidgets
pip install nbformat
pip install pandas

# Make environment discoverable by Jupyter
pip install ipykernel
81 changes: 49 additions & 32 deletions solution.py
@@ -58,6 +58,7 @@
# %%
# %load_ext autoreload
# %autoreload 2
# TODO: remove

# %%
import time
@@ -84,6 +85,7 @@
from motile_toolbox.visualization import to_napari_tracks_layer
from napari.layers import Tracks
from csv import DictReader
import pandas as pd

from tqdm.auto import tqdm

@@ -95,31 +97,60 @@
# %% [markdown]
# For this exercise we will be working with a fluorescence microscopy time-lapse of breast cancer cells with stained nuclei (SiR-DNA). It is similar to the dataset at https://zenodo.org/record/4034976#.YwZRCJPP1qt. The raw data, pre-computed segmentations, and detection probabilities are saved in a zarr, and the ground truth tracks are saved in a csv. The segmentation was generated with a pre-trained StarDist model, so there may be some segmentation errors which can affect the tracking process. The detection probabilities also come from StarDist, and are downsampled by 2 in x and y compared to the detections and raw data.

# %% [markdown]
# Here we load the raw image data, segmentation, and probabilities from the zarr, and view them in napari.

# %%
data_path = "data/breast_cancer_fluo.zarr"
data_root = zarr.open(data_path, 'r')
image_data = data_root["raw"][:]
segmentation = data_root["seg_relabeled"][:]
probabilities = data_root["probs"][:]

# %% [markdown]
# Let's use [napari](https://napari.org/tutorials/fundamentals/getting_started.html) to visualize the data. Napari is a wonderful viewer for imaging data that you can interact with in python, even directly out of jupyter notebooks. If you've never used napari, you might want to take a few minutes to go through [this tutorial](https://napari.org/stable/tutorials/fundamentals/viewer.html).

# %%
viewer = napari.Viewer()
viewer.add_image(image_data, name="raw")
viewer.add_labels(segmentation, name="seg")
viewer.add_image(probabilities, name="probs", scale=(1, 2, 2))


# %% [markdown]
# ## Task 1: Read in the ground truth graph
#
# In addition to the image data and segmentations, we also have a ground truth tracking solution.
# The ground truth tracks are stored in a CSV with five columns: id, time, x, y, and parent_id.
#
# Each row in the CSV represents a detection at location (time, x, y) with the given id.
# If the parent_id is not -1, it represents the id of the parent detection in the previous time frame.
# For cell tracking, tracks can usually be stored in this format, because there is no merging.
# With merging, a more complicated data structure would be needed.
#
# Note that there are no ground truth segmentations - each detection is just a point representing the center of a cell.
#

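To make the five-column format concrete, here is a tiny hand-made illustration (the ids, coordinates, and parent links below are invented for this sketch, not taken from the real ground truth csv). Parsing it with `csv.DictReader` shows the string-to-number casting the task asks for:

```python
import csv
import io

# Invented rows illustrating the format: parent_id == -1 starts a new track,
# and id 3 points back to its parent detection id 2 in the previous frame.
example_csv = """id,time,x,y,parent_id
1,0,100.5,200.0,-1
2,1,101.0,202.5,1
3,2,98.0,205.0,2
"""

for row in csv.DictReader(io.StringIO(example_csv)):
    # DictReader yields strings, so every field must be cast explicitly
    _id = int(row["id"])
    t = int(row["time"])
    pos = [float(row["x"]), float(row["y"])]
    parent = int(row["parent_id"])
    print(_id, t, pos, parent)
```

Note that only the first detection of each track has `parent_id == -1`; every later detection names exactly one parent, which is why a simple parent pointer per row is enough when tracks never merge.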
# %% [markdown]
#
# <div class="alert alert-block alert-info"><h3>Task 1: Read in the ground truth graph</h3>
#
# For this task, you will read in the csv and store the tracks as a `networkx` DiGraph. Take a look at the documentation for the DiGraph <a href="https://networkx.org/documentation/stable/reference/classes/digraph.html">here</a> to learn how to create a graph, add nodes and edges with attributes, and access those nodes and edges.
#
# Here are the requirements for the graph:
# <ol>
# <li>Each row in the CSV becomes a node in the graph</li>
# <li>The node id is an integer specified by the "id" column in the csv</li>
# <li>Each node has an integer "time" attribute specified by the "time" column in the csv</li>
# <li>Each node has a list[float] "pos" attribute containing the ["x", "y"] values from the csv</li>
# <li>If the parent_id is not -1, then there is an edge in the graph from "parent_id" to "id"</li>
# </ol>
#
# You can read the CSV using basic python file io, csv.DictReader, pandas, or any other tool you are comfortable with. If you are not using pandas, remember to cast the values you read in from strings to integers or floats.
# </div>
#

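The requirements above can be sketched as follows. This is only one possible approach (using pandas, which the setup installs), shown on a tiny synthetic table rather than the real csv; the helper name `tracks_from_dataframe` and the demo values are invented for illustration, and the solution cell below reads the actual file:

```python
import networkx as nx
import pandas as pd

def tracks_from_dataframe(df: pd.DataFrame) -> nx.DiGraph:
    """Build a tracking graph per the task requirements (illustrative sketch)."""
    graph = nx.DiGraph()
    for row in df.itertuples():
        # requirement 1-4: one node per row, integer id, integer "time",
        # and a list[float] "pos" holding [x, y]
        graph.add_node(int(row.id), time=int(row.time),
                       pos=[float(row.x), float(row.y)])
    for row in df.itertuples():
        # requirement 5: parent_id == -1 starts a track; otherwise add an
        # edge from the parent detection to this one
        if int(row.parent_id) != -1:
            graph.add_edge(int(row.parent_id), int(row.id))
    return graph

# synthetic example: one division, node 1 splitting into nodes 2 and 3
demo = pd.DataFrame({
    "id": [1, 2, 3],
    "time": [0, 1, 1],
    "x": [10.0, 9.0, 11.0],
    "y": [5.0, 4.0, 6.0],
    "parent_id": [-1, 1, 1],
})
demo_graph = tracks_from_dataframe(demo)
```

Adding all nodes in a first pass and all edges in a second avoids relying on parents appearing earlier in the file than their children.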
# %% tags=["task"]
def read_gt_tracks():
    gt_tracks = nx.DiGraph()
    ### YOUR CODE HERE ###
@@ -130,9 +161,9 @@ def read_gt_tracks():

# %% tags=["solution"]
def read_gt_tracks():
    gt_tracks = nx.DiGraph()
    with open("data/breast_cancer_fluo_gt_tracks.csv") as f:
        reader = DictReader(f)
        for row in reader:
            _id = int(row["id"])
            attrs = {
@@ -147,33 +178,19 @@ def read_gt_tracks():

gt_tracks = read_gt_tracks()


# %% [markdown]
# <div class="alert alert-block alert-danger"><h3>Napari in a jupyter notebook:</h3>
#
# - To have napari working in a jupyter notebook, you need to use up-to-date versions of napari, pyqt and pyqt5, as is the case in the conda environments provided together with this exercise.
# - When you are coding and debugging, close the napari viewer with `viewer.close()` to avoid problems with the two event loops of napari and jupyter.
# - **If a cell is not executed (empty square brackets on the left of a cell) despite you running it, running it a second time right after will usually work.**
# </div>

# %%
viewer = napari.viewer.current_viewer()
if viewer:
    viewer.close()
viewer = napari.Viewer()
viewer.add_image(image_data, name="raw")
viewer.add_labels(segmentation, name="seg")
viewer.add_image(probabilities, name="probs", scale=(1, 2, 2))
tracks_layer = to_napari_tracks_layer(gt_tracks, frame_key="time", location_key="pos", name="gt_tracks")
viewer.add_layer(tracks_layer)

# %%
# viewer = napari.viewer.current_viewer()
# if viewer:
# viewer.close()

# run this cell to test your implementation
assert gt_tracks.number_of_nodes() == 5490, f"Found {gt_tracks.number_of_nodes()} nodes, expected 5490"
assert gt_tracks.number_of_edges() == 5120, f"Found {gt_tracks.number_of_edges()} edges, expected 5120"
for node, data in gt_tracks.nodes(data=True):
    assert type(node) == int, f"Node id {node} has type {type(node)}, expected 'int'"
    assert "time" in data, f"'time' attribute missing for node {node}"
    assert type(data["time"]) == int, f"'time' attribute has type {type(data['time'])}, expected 'int'"
    assert "pos" in data, f"'pos' attribute missing for node {node}"
    assert type(data["pos"]) == list, f"'pos' attribute has type {type(data['pos'])}, expected 'list'"
    assert len(data["pos"]) == 2, f"'pos' attribute has length {len(data['pos'])}, expected 2"
    assert type(data["pos"][0]) == float, f"'pos' attribute element 0 has type {type(data['pos'][0])}, expected 'float'"
print("Your graph passed all the tests!")

# %% [markdown]
# ## Build a candidate graph from the detections
