Skip to content

Commit

Permalink
refactor add edge api and branch selection
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementCaporal committed Jan 14, 2025
1 parent c7ce908 commit 793d0f2
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 31 deletions.
94 changes: 67 additions & 27 deletions src/napari_swc_editor/bindings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import napari
import numpy as np
from functools import partial

from .swc_io import (
add_edge,
Expand All @@ -14,6 +16,8 @@
structure_id_to_symbol,
symbol_to_structure_id,
update_point_properties,
get_branch_from_node,
get_index_from_treenode_id,
)


Expand Down Expand Up @@ -47,7 +51,7 @@ def add_napari_layers_from_swc_content(
structure_symbol = structure_id_to_symbol(structure)

shape_layer = viewer.add_shapes(
lines, shape_type="line", edge_width=radius
lines, shape_type="line", edge_width=radius,
)

add_kwargs_points = {
Expand Down Expand Up @@ -95,11 +99,50 @@ def bind_layers_with_events(point_layer, shape_layer):
point_layer.events.symbol.connect(event_update_point_properties)

point_layer.bind_key("l")(event_add_edge)
point_layer.bind_key("Shift-l")(event_add_edge_wo_sort)
point_layer.bind_key("u")(event_remove_edge)
point_layer.bind_key("b", select_branch)
point_layer.bind_key("Shift+b", select_upstream_branch)

point_layer.bind_key("Control", linked_point)

point_layer.metadata["shape_layer"] = shape_layer

def select_upstream_branch(layer):
"""Select the upstream branch of the selected point
Parameters
----------
layer : napari.layers.Points
Points layer
treenode_id : int
Index of the selected point
"""

select_branch(layer, upstream=True, downstream=False)

def select_branch(layer, upstream=True, downstream=True):
"""Select the branch of the selected point
Parameters
----------
layer : napari.layers.Points
Points layer
treenode_id : int
Index of the selected point
"""

raw_swc = layer.metadata["raw_swc"]
df = parse_swc_content(raw_swc)

selected_branches = []
for selected in list(layer.selected_data):
indices = get_treenode_id_from_index([selected], df)[0]
branch_indices = get_branch_from_node(indices, df, upstream, downstream)
selected_branches.extend(branch_indices)

selected_branches = np.unique(selected_branches)
indices = get_index_from_treenode_id(selected_branches, df)
layer.selected_data = indices


def linked_point(layer):
Expand All @@ -109,18 +152,23 @@ def linked_point(layer):
layer.metadata["Ctrl_activated"] = True
yield
layer.metadata["Ctrl_activated"] = False


def event_add_points(event):

if event.action == "added":
raw_swc = event.source.metadata["raw_swc"]

df = parse_swc_content(raw_swc)
new_pos = event.source.data[list(event.data_indices)]
new_radius = event.source.size[list(event.data_indices)]
new_structure = event.source.symbol[list(event.data_indices)]

# indices must be deduced because the event.data_indices has only (-1,) see
# https://github.com/napari/napari/issues/7507
indices = [-i-1 for i in reversed(range(0, len(event.source.data)-len(df)))]
new_pos = event.source.data[indices]
new_radius = event.source.size[indices]
new_structure = event.source.symbol[indices]
new_parents = -1


# if shift is activated, the add the new edges from previous selected point
if (
Expand All @@ -145,6 +193,7 @@ def event_add_points(event):
event.source.metadata["swc_data"] = df

if new_parents != -1:

new_lines, new_r = create_line_data_from_swc_data(df)
event.source.metadata["shape_layer"].data = []
event.source.metadata["shape_layer"].add_lines(
Expand Down Expand Up @@ -226,21 +275,7 @@ def event_update_point_properties(event):
event.source.metadata["shape_layer"].add_lines(new_lines, edge_width=new_r)
event.source.metadata["swc_data"] = df


def event_add_edge_wo_sort(layer):
"""Add an edge between two selected points without sorting the indices
Parameters
----------
layer : napari.layers.Points
Points layer
"""
### This function is not used in the current version of the plugin
### Because napari does not support history of the selected points
event_add_edge(layer, sort=False)


def event_add_edge(layer, indices=None, sort=True):
def event_add_edge(layer, treenode_id=None, parent_treenode_id=None, sort=True):
"""Add an edge between two selected points
Parameters
Expand All @@ -255,12 +290,17 @@ def event_add_edge(layer, indices=None, sort=True):
raw_swc = layer.metadata["raw_swc"]
df = parse_swc_content(raw_swc)

if indices is None:
indices = get_treenode_id_from_index(list(layer.selected_data), df)

if sort:
indices = sort_edge_indices(raw_swc, indices, df)
new_swc, new_lines, new_r, df = add_edge(raw_swc, indices, df)
if treenode_id is None:
treenode_id = get_treenode_id_from_index(list(layer.selected_data), df)

if parent_treenode_id is None:
# if no parent is given, the parent will be the following selected points
if sort:
treenode_id = sort_edge_indices(raw_swc, treenode_id, df)
parent_treenode_id = treenode_id[:-1].astype(int)
treenode_id = treenode_id[1:].astype(int)

new_swc, new_lines, new_r, df = add_edge(raw_swc, treenode_id, parent_treenode_id, df)

layer.metadata["raw_swc"] = new_swc
# when updating the shape layer directly, the previous data
Expand Down
89 changes: 85 additions & 4 deletions src/napari_swc_editor/swc_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ def parse_swc_content(file_content):
"parent_treenode_id",
],
index_col=0,
# set the type of the columns
dtype={
"treenode_id": int,
"structure_id": int,
"x": float,
"y": float,
"z": float,
"r": float,
"parent_treenode_id": int,
},
)

return df
Expand Down Expand Up @@ -298,6 +308,7 @@ def add_points(
new_points = new_points[
["structure_id", "x", "y", "z", "r", "parent_treenode_id"]
]


if swc_df.size > 0:
previous_max = swc_df.index.max()
Expand Down Expand Up @@ -424,8 +435,32 @@ def get_treenode_id_from_index(iloc, df):

return indices

def get_index_from_treenode_id(indices, df):
"""Get the iloc index from the treenode_id
def add_edge(swc_content, indices, swc_df=None):
Parameters
----------
indices : int or list of int
Treenode_id of the row in the dataframe
df : pd.DataFrame
Dataframe extracted from a swc file. Should have the following columns:
- treenode_id as index
- parent_treenode_id: id of the parent node
Returns
-------
iloc : np.ndarray
Index of the selected treenode_id
"""

if isinstance(indices, int):
indices = [indices]

iloc = df.index.get_indexer(indices)

return iloc

def add_edge(swc_content, treenode_id, parent_treenode_id, swc_df=None):
"""Add an edge between two or more indices in order
Parameters
Expand All @@ -451,13 +486,11 @@ def add_edge(swc_content, indices, swc_df=None):
Dataframe extracted from the swc file
"""

assert len(indices) >= 2, "At least two indices are needed to create edges"

if swc_df is None:
swc_df = parse_swc_content(swc_content)

for i in range(1, len(indices)):
swc_df.loc[indices[i], "parent_treenode_id"] = indices[i - 1]
swc_df.loc[treenode_id, "parent_treenode_id"] = parent_treenode_id

new_lines, new_r = create_line_data_from_swc_data(swc_df)

Expand Down Expand Up @@ -614,3 +647,51 @@ def update_point_properties(swc_content, indices, new_properties, swc_df=None):
new_swc_content = write_swc_content(swc_df, swc_content)

return new_swc_content, new_lines, new_r, swc_df

def get_branch_from_node(node_id, df, upstream=True, downstream=True):
"""Get the branch of a node
Parameters
----------
node_id : int
Index of the node
df : pd.DataFrame
Dataframe extracted from a swc file. Should have the following columns:
- treenode_id as index
- parent_treenode_id: id of the parent node
upstream : bool
If True, get the branch from the selected node to the leaf
Default is True
downstream : bool
If True, get the branch from the selected node to the soma
Default is True
Returns
-------
branch : pd.DataFrame
Branch of the node
"""

to_explore = []
branch = [node_id]
if downstream:
# get the branch from the selected node to the soma
node_id = df.loc[node_id, "parent_treenode_id"]
while node_id != -1:
branch.append(df.loc[node_id, "parent_treenode_id"])

node_id = df.loc[node_id, "parent_treenode_id"]

if upstream:
# get the branch from the selected node to the end
node_id = branch[0]
children = df[df["parent_treenode_id"] == node_id].index.values
to_explore.extend(children)
while len(to_explore) > 0:
node_id = to_explore.pop()
children = df[df["parent_treenode_id"] == node_id].index.values
to_explore.extend(children)
branch.append(node_id)

branch = np.array(branch).astype(int)
return branch

0 comments on commit 793d0f2

Please sign in to comment.