diff --git a/src/napari_swc_editor/bindings.py b/src/napari_swc_editor/bindings.py index 554b504..7df0fba 100644 --- a/src/napari_swc_editor/bindings.py +++ b/src/napari_swc_editor/bindings.py @@ -1,4 +1,6 @@ import napari +import numpy as np +from functools import partial from .swc_io import ( add_edge, @@ -14,6 +16,8 @@ structure_id_to_symbol, symbol_to_structure_id, update_point_properties, + get_branch_from_node, + get_index_from_treenode_id, ) @@ -47,7 +51,7 @@ def add_napari_layers_from_swc_content( structure_symbol = structure_id_to_symbol(structure) shape_layer = viewer.add_shapes( - lines, shape_type="line", edge_width=radius + lines, shape_type="line", edge_width=radius, ) add_kwargs_points = { @@ -95,11 +99,50 @@ def bind_layers_with_events(point_layer, shape_layer): point_layer.events.symbol.connect(event_update_point_properties) point_layer.bind_key("l")(event_add_edge) - point_layer.bind_key("Shift-l")(event_add_edge_wo_sort) point_layer.bind_key("u")(event_remove_edge) + point_layer.bind_key("b", select_branch) + point_layer.bind_key("Shift+b", select_upstream_branch) + point_layer.bind_key("Control", linked_point) point_layer.metadata["shape_layer"] = shape_layer + +def select_upstream_branch(layer): + """Select the upstream branch of the selected point + + Parameters + ---------- + layer : napari.layers.Points + Points layer + treenode_id : int + Index of the selected point + """ + + select_branch(layer, upstream=True, downstream=False) + +def select_branch(layer, upstream=True, downstream=True): + """Select the branch of the selected point + + Parameters + ---------- + layer : napari.layers.Points + Points layer + treenode_id : int + Index of the selected point + """ + + raw_swc = layer.metadata["raw_swc"] + df = parse_swc_content(raw_swc) + + selected_branches = [] + for selected in list(layer.selected_data): + indices = get_treenode_id_from_index([selected], df)[0] + branch_indices = get_branch_from_node(indices, df, upstream, downstream) + selected_branches.extend(branch_indices) + + selected_branches = np.unique(selected_branches) + indices = get_index_from_treenode_id(selected_branches, df) + layer.selected_data = indices def linked_point(layer): @@ -109,7 +152,7 @@ def linked_point(layer): layer.metadata["Ctrl_activated"] = True yield layer.metadata["Ctrl_activated"] = False - + def event_add_points(event): @@ -117,10 +160,15 @@ def event_add_points(event): raw_swc = event.source.metadata["raw_swc"] df = parse_swc_content(raw_swc) - new_pos = event.source.data[list(event.data_indices)] - new_radius = event.source.size[list(event.data_indices)] - new_structure = event.source.symbol[list(event.data_indices)] + + # indices must be deduced because the event.data_indices has only (-1,) see + # https://github.com/napari/napari/issues/7507 + indices = [-i-1 for i in reversed(range(0, len(event.source.data)-len(df)))] + new_pos = event.source.data[indices] + new_radius = event.source.size[indices] + new_structure = event.source.symbol[indices] new_parents = -1 + # if shift is activated, the add the new edges from previous selected point if ( @@ -145,6 +193,7 @@ def event_add_points(event): event.source.metadata["swc_data"] = df if new_parents != -1: + new_lines, new_r = create_line_data_from_swc_data(df) event.source.metadata["shape_layer"].data = [] event.source.metadata["shape_layer"].add_lines( @@ -226,21 +275,7 @@ def event_update_point_properties(event): event.source.metadata["shape_layer"].add_lines(new_lines, edge_width=new_r) event.source.metadata["swc_data"] = df - -def event_add_edge_wo_sort(layer): - """Add an edge between two selected points without sorting the indices - - Parameters - ---------- - layer : napari.layers.Points - Points layer - """ - ### This function is not used in the current version of the plugin - ### Because napari does not support history of the selected points - event_add_edge(layer, sort=False) - - -def event_add_edge(layer, indices=None, sort=True): +def event_add_edge(layer, treenode_id=None, parent_treenode_id=None, sort=True): """Add an edge between two selected points Parameters @@ -255,12 +290,17 @@ def event_add_edge(layer, indices=None, sort=True): raw_swc = layer.metadata["raw_swc"] df = parse_swc_content(raw_swc) - if indices is None: - indices = get_treenode_id_from_index(list(layer.selected_data), df) - - if sort: - indices = sort_edge_indices(raw_swc, indices, df) - new_swc, new_lines, new_r, df = add_edge(raw_swc, indices, df) + if treenode_id is None: + treenode_id = get_treenode_id_from_index(list(layer.selected_data), df) + + if parent_treenode_id is None: + # if no parent is given, the parent will be the following selected points + if sort: + treenode_id = sort_edge_indices(raw_swc, treenode_id, df) + parent_treenode_id = treenode_id[:-1].astype(int) + treenode_id = treenode_id[1:].astype(int) + + new_swc, new_lines, new_r, df = add_edge(raw_swc, treenode_id, parent_treenode_id, df) layer.metadata["raw_swc"] = new_swc # when updating the shape layer directly, the previous data diff --git a/src/napari_swc_editor/swc_io.py b/src/napari_swc_editor/swc_io.py index e3d837c..196211e 100644 --- a/src/napari_swc_editor/swc_io.py +++ b/src/napari_swc_editor/swc_io.py @@ -50,6 +50,16 @@ def parse_swc_content(file_content): "parent_treenode_id", ], index_col=0, + # set the type of the columns + dtype={ + "treenode_id": int, + "structure_id": int, + "x": float, + "y": float, + "z": float, + "r": float, + "parent_treenode_id": int, + }, ) return df @@ -298,6 +308,7 @@ def add_points( new_points = new_points[ ["structure_id", "x", "y", "z", "r", "parent_treenode_id"] ] + if swc_df.size > 0: previous_max = swc_df.index.max() @@ -424,8 +435,32 @@ def get_treenode_id_from_index(iloc, df): return indices +def get_index_from_treenode_id(indices, df): + """Get the iloc index from the treenode_id -def add_edge(swc_content, indices, swc_df=None): + Parameters + ---------- + indices : int or list of int + Treenode_id of the row in the dataframe + df : pd.DataFrame + Dataframe extracted from a swc file. Should have the following columns: + - treenode_id as index + - parent_treenode_id: id of the parent node + + Returns + ------- + iloc : np.ndarray + Index of the selected treenode_id + """ + + if isinstance(indices, int): + indices = [indices] + + iloc = df.index.get_indexer(indices) + + return iloc + +def add_edge(swc_content, treenode_id, parent_treenode_id, swc_df=None): """Add an edge between two or more indices in order Parameters @@ -451,13 +486,11 @@ def add_edge(swc_content, indices, swc_df=None): Dataframe extracted from the swc file """ - assert len(indices) >= 2, "At least two indices are needed to create edges" if swc_df is None: swc_df = parse_swc_content(swc_content) - for i in range(1, len(indices)): - swc_df.loc[indices[i], "parent_treenode_id"] = indices[i - 1] + swc_df.loc[treenode_id, "parent_treenode_id"] = parent_treenode_id new_lines, new_r = create_line_data_from_swc_data(swc_df) @@ -614,3 +647,51 @@ def update_point_properties(swc_content, indices, new_properties, swc_df=None): new_swc_content = write_swc_content(swc_df, swc_content) return new_swc_content, new_lines, new_r, swc_df + +def get_branch_from_node(node_id, df, upstream=True, downstream=True): + """Get the branch of a node + + Parameters + ---------- + node_id : int + Index of the node + df : pd.DataFrame + Dataframe extracted from a swc file. Should have the following columns: + - treenode_id as index + - parent_treenode_id: id of the parent node + upstream : bool + If True, get the branch from the selected node to the leaf + Default is True + downstream : bool + If True, get the branch from the selected node to the soma + Default is True + + Returns + ------- + branch : pd.DataFrame + Branch of the node + """ + + to_explore = [] + branch = [node_id] + if downstream: + # get the branch from the selected node to the soma + node_id = df.loc[node_id, "parent_treenode_id"] + while node_id != -1: + branch.append(df.loc[node_id, "parent_treenode_id"]) + + node_id = df.loc[node_id, "parent_treenode_id"] + + if upstream: + # get the branch from the selected node to the end + node_id = branch[0] + children = df[df["parent_treenode_id"] == node_id].index.values + to_explore.extend(children) + while len(to_explore) > 0: + node_id = to_explore.pop() + children = df[df["parent_treenode_id"] == node_id].index.values + to_explore.extend(children) + branch.append(node_id) + + branch = np.array(branch).astype(int) + return branch