From 8f0b500323708fc02f2b25b4c2f4f5d9059a556c Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Thu, 7 Nov 2024 17:17:08 +0100 Subject: [PATCH 01/25] first implementation of custom, manually assigned groups of nodes, stored as lists while interacting and on node attributes for saving and switching between runs --- .../application_menus/menu_widget.py | 1 + .../views/tree_view/tree_view_mode_widget.py | 10 +- .../data_views/views/tree_view/tree_widget.py | 89 +++++- .../data_views/views_coordinator/__init__.py | 1 + .../views_coordinator/collection_widget.py | 287 ++++++++++++++++++ .../views_coordinator/collections.py | 48 +++ .../views_coordinator/tracks_viewer.py | 8 +- 7 files changed, 429 insertions(+), 15 deletions(-) create mode 100644 src/motile_plugin/data_views/views_coordinator/collection_widget.py create mode 100644 src/motile_plugin/data_views/views_coordinator/collections.py diff --git a/src/motile_plugin/application_menus/menu_widget.py b/src/motile_plugin/application_menus/menu_widget.py index c814dfb..de59bf4 100644 --- a/src/motile_plugin/application_menus/menu_widget.py +++ b/src/motile_plugin/application_menus/menu_widget.py @@ -22,6 +22,7 @@ def __init__(self, viewer: napari.Viewer): tabwidget.addTab(motile_widget, "Track with Motile") tabwidget.addTab(editing_widget, "Edit Tracks") tabwidget.addTab(tracks_viewer.tracks_list, "Results List") + tabwidget.addTab(tracks_viewer.collection_widget, "Collections") layout = QVBoxLayout() layout.addWidget(tabwidget) diff --git a/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py index 552b77d..f62f8f6 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py @@ -22,15 +22,18 @@ def __init__(self): display_box = QGroupBox("Display [Q]") display_layout = QHBoxLayout() button_group = QButtonGroup() - self.show_all_radio = QRadioButton("All cells") + self.show_all_radio = QRadioButton("All") self.show_all_radio.setChecked(True) self.show_all_radio.clicked.connect(lambda: self._set_mode("all")) - self.show_lineage_radio = QRadioButton("Current lineage(s)") + self.show_lineage_radio = QRadioButton("Selected lineage(s)") self.show_lineage_radio.clicked.connect(lambda: self._set_mode("lineage")) + self.show_group_radio = QRadioButton("Group") + self.show_group_radio.clicked.connect(lambda: self._set_mode("group")) button_group.addButton(self.show_all_radio) button_group.addButton(self.show_lineage_radio) display_layout.addWidget(self.show_all_radio) display_layout.addWidget(self.show_lineage_radio) + display_layout.addWidget(self.show_group_radio) display_box.setLayout(display_layout) display_box.setMaximumWidth(250) display_box.setMaximumHeight(60) @@ -44,6 +47,9 @@ def _toggle_display_mode(self, event=None) -> None: """Toggle display mode""" if self.mode == "lineage": + self._set_mode("group") + self.show_all_group.setChecked(True) + elif self.mode == "group": self._set_mode("all") self.show_all_radio.setChecked(True) else: diff --git a/src/motile_plugin/data_views/views/tree_view/tree_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_widget.py index fc3a590..a829902 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_widget.py @@ -162,7 +162,7 @@ def set_view( if feature == "tree": self.getAxis("bottom").setStyle(showValues=False) self.setLabel("bottom", text="") - else: # should this actually ever happen? + else: self.getAxis("bottom").setStyle(showValues=True) self.setLabel("bottom", text="Object size in calibrated units") self.autoRange() @@ -311,6 +311,7 @@ def set_selection(self, selected_nodes: list[Any], feature: str) -> None: t_values = [] for node_id in selected_nodes: node_df = self.track_df.loc[self.track_df["node_id"] == node_id] + x_axis_value = None if not node_df.empty: x_axis_value = node_df[axis_label].values[0] t = node_df["t"].values[0] @@ -324,14 +325,17 @@ def set_selection(self, selected_nodes: list[Any], feature: str) -> None: outlines[index] = pg.mkPen(color="c", width=2) # Center point if a single node is selected, center range if multiple nodes are selected - if len(selected_nodes) == 1: + if len(selected_nodes) == 1 and x_axis_value is not None: self._center_view(x_axis_value, t) else: - min_x = np.min(x_values) - max_x = np.max(x_values) - min_t = np.min(t_values) - max_t = np.max(t_values) - self._center_range(min_x, max_x, min_t, max_t) + if ( + len(x_values) > 0 + ): # only center if the selected objects are actually in the view (might not be the case if you are in group mode) + min_x = np.min(x_values) + max_x = np.max(x_values) + min_t = np.min(t_values) + max_t = np.max(t_values) + self._center_range(min_x, max_x, min_t, max_t) self.g.scatter.setPen(outlines) self.g.scatter.setSize(size) @@ -405,6 +409,7 @@ def __init__(self, viewer: napari.Viewer): super().__init__() self.track_df = pd.DataFrame() # all tracks self.lineage_df = pd.DataFrame() # the currently viewed subset of lineages + self.group_df = pd.DataFrame() # the currently viewed group self.graph = None self.mode = "all" # options: "all", "lineage" self.feature = "tree" # options: "tree", "area" @@ -413,6 +418,9 @@ def __init__(self, viewer: napari.Viewer): self.tracks_viewer = TracksViewer.get_instance(viewer) self.selected_nodes = self.tracks_viewer.selected_nodes self.selected_nodes.list_updated.connect(self._update_selected) + self.tracks_viewer.collection_widget.group_changed.connect( + self._update_selected + ) self.tracks_viewer.tracks_updated.connect(self._update_track_data) # Construct the tree view pyqtgraph widget @@ -538,7 +546,7 @@ def keyReleaseEvent(self, ev): def _update_selected(self): """Called whenever the selection list is updated. Only re-computes the full graph information when the new selection is not in the - lineage df (and in lineage mode) + lineage or group df (and in lineage/group mode) """ if self.mode == "lineage" and any( @@ -552,6 +560,14 @@ def _update_selected(self): self.feature, self.selected_nodes, ) + elif self.mode == "group": + self._update_group_df() + self.tree_widget.update( + self.group_df, + self.view_direction, + self.feature, + self.selected_nodes, + ) else: self.tree_widget.set_selection(self.selected_nodes, self.feature) @@ -608,7 +624,15 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: self.selected_nodes, reset_view=reset_view, ) - + elif self.mode == "group": + self._update_group_df() + self.tree_widget.update( + self.group_df, + self.view_direction, + self.feature, + self.selected_nodes, + reset_view=reset_view, + ) else: self.tree_widget.update( self.track_df, @@ -625,8 +649,8 @@ def _set_mode(self, mode: str) -> None: Args: mode (str): The mode to set the view to. Options are "all" or "lineage" """ - if mode not in ["all", "lineage"]: - raise ValueError(f"Mode must be 'all' or 'lineage', got {mode}") + if mode not in ["all", "lineage", "group"]: + raise ValueError(f"Mode must be 'all', 'lineage' or 'group', got {mode}") self.mode = mode if mode == "all": @@ -635,6 +659,13 @@ def _set_mode(self, mode: str) -> None: else: self.view_direction = "horizontal" df = self.track_df + elif mode == "group": + if self.feature == "tree": + self.view_direction = "vertical" + else: + self.view_direction = "horizontal" + self._update_group_df() + df = self.group_df elif mode == "lineage": self.view_direction = "horizontal" self._update_lineage_df() @@ -655,7 +686,7 @@ def _set_feature(self, feature: str) -> None: raise ValueError(f"Feature must be 'tree' or 'area', got {feature}") self.feature = feature - if feature == "tree" and self.mode == "all": + if feature == "tree" and (self.mode == "all" or self.mode == "group"): self.view_direction = "vertical" else: self.view_direction = "horizontal" @@ -665,6 +696,8 @@ def _set_feature(self, feature: str) -> None: df = self.track_df if self.mode == "lineage": df = self.lineage_df + if self.mode == "group": + df = self.group_df self.navigation_widget.feature = self.feature self.tree_widget.update( @@ -697,3 +730,35 @@ def _update_lineage_df(self) -> None: self.lineage_df["x_axis_pos"].rank(method="dense").astype(int) - 1 ) self.navigation_widget.lineage_df = self.lineage_df + + def _update_group_df(self) -> None: + """Subset dataframe to include only nodes belonging to the current group/collection""" + + visible = [] + for ( + node_id + ) in self.tracks_viewer.collection_widget.selected_collection.collection: + if node_id in self.track_df["node_id"].tolist(): + visible += extract_lineage_tree(self.graph, node_id) + else: + self.tracks_viewer.collection_widget.selected_collection.collection._list.remove( + node_id + ) + self.group_df = self.track_df[ + self.track_df["node_id"].isin(visible) + ].reset_index() + self.group_df["x_axis_pos"] = ( + self.group_df["x_axis_pos"].rank(method="dense").astype(int) - 1 + ) + + if not self.group_df.empty: + # change the opacity for the nodes that are part of the lineage but strictly not of the group + self.group_df["color"] = self.group_df.apply( + lambda row: [*row["color"][:3], 62.0] + if row["node_id"] + not in self.tracks_viewer.collection_widget.selected_collection.collection._list + else row["color"], + axis=1, + ) + self.group_df["color"] = self.group_df["color"].apply(np.array) + self.navigation_widget.group_df = self.group_df diff --git a/src/motile_plugin/data_views/views_coordinator/__init__.py b/src/motile_plugin/data_views/views_coordinator/__init__.py index e69de29..ff8b08b 100644 --- a/src/motile_plugin/data_views/views_coordinator/__init__.py +++ b/src/motile_plugin/data_views/views_coordinator/__init__.py @@ -0,0 +1 @@ +from .collections import Collection # noqa diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py new file mode 100644 index 0000000..5f1cd3b --- /dev/null +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING + +from motile_toolbox.candidate_graph.graph_attributes import NodeAttr +from napari._qt.qt_resources import QColoredSVGIcon +from qtpy.QtCore import Signal +from qtpy.QtWidgets import ( + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QListWidget, + QListWidgetItem, + QPushButton, + QVBoxLayout, + QWidget, +) + +from motile_plugin.data_views.views.tree_view.tree_widget_utils import ( + extract_lineage_tree, +) + +from . import Collection + +if TYPE_CHECKING: + from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer + + +class CollectionButton(QWidget): + def __init__(self, name: str): + super().__init__() + self.name = QLabel(name) + self.name.setFixedHeight(20) + self.collection = Collection() + delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") + self.delete = QPushButton(icon=delete_icon) + self.delete.setFixedSize(20, 20) + layout = QHBoxLayout() + layout.setSpacing(1) + layout.addWidget(self.name) + layout.addWidget(self.delete) + self.setLayout(layout) + + def sizeHint(self): + hint = super().sizeHint() + hint.setHeight(30) + return hint + + +class CollectionWidget(QGroupBox): + """Widget for holding in-memory Collections (groups). Emits a signal whenever + a collection is selected in the list, to update the viewing properties + """ + + group_changed = Signal() + + def __init__(self, tracks_viewer: TracksViewer): + super().__init__(title="Collections") + + self.tracks_viewer = tracks_viewer + + self.collection_list = QListWidget() + self.collection_list.setSelectionMode(1) # single selection + self.collection_list.itemSelectionChanged.connect(self._selection_changed) + self.selected_collection = None + + # edit layout + edit_widget = QGroupBox("Edit") + edit_layout = QVBoxLayout() + + add_layout = QHBoxLayout() + add_node = QPushButton("Add node(s)") + add_node.clicked.connect(self.add_node) + add_track = QPushButton("Add track(s)") + add_track.clicked.connect(self.add_track) + add_lineage = QPushButton("Add lineage(s)") + add_lineage.clicked.connect(self.add_lineage) + add_layout.addWidget(add_node) + add_layout.addWidget(add_track) + add_layout.addWidget(add_lineage) + + remove_layout = QHBoxLayout() + remove_node = QPushButton("Remove node(s)") + remove_node.clicked.connect(self.remove_node) + remove_track = QPushButton("Remove track(s)") + remove_track.clicked.connect(self.remove_track) + remove_lineage = QPushButton("Remove lineage(s)") + remove_lineage.clicked.connect(self.remove_lineage) + remove_layout.addWidget(remove_node) + remove_layout.addWidget(remove_track) + remove_layout.addWidget(remove_lineage) + + edit_layout.addLayout(add_layout) + edit_layout.addLayout(remove_layout) + edit_widget.setLayout(edit_layout) + + # adding a new group + new_group_layout = QHBoxLayout() + new_group_layout.addWidget(QLabel("New group:")) + self.group_name = QLineEdit("new group") + new_group_layout.addWidget(self.group_name) + new_group_button = QPushButton("Create") + new_group_button.clicked.connect(self.new_group) + new_group_layout.addWidget(new_group_button) + + # combine widgets + layout = QVBoxLayout() + layout.addWidget(self.collection_list) + layout.addWidget(edit_widget) + layout.addLayout(new_group_layout) + self.setLayout(layout) + + def retrieve_existing_groups(self): + # first clear the entire list + self.collection_list.clear() + + # check for existing groups in the node attributes + group_dict = {} + for node, data in self.tracks_viewer.tracks.graph.nodes(data=True): + groups = data.get("group") + if groups: # Only add if 'group' attribute is present and not None + for group in groups: + if group not in group_dict: + group_dict[group] = [] + self.add_group(group, select=False) + group_dict[group].append(node) + + # populate the lists based on the nodes that were assigned to the different groups + for i in range(self.collection_list.count()): + self.collection_list.setCurrentRow(i) + self.selected_collection.collection.add( + group_dict[self.selected_collection.name.text()] + ) + + def _selection_changed(self): + selected = self.collection_list.selectedItems() + if selected: + self.selected_collection = self.collection_list.itemWidget(selected[0]) + self.group_changed.emit() + + def add_node(self): + if self.selected_collection is not None: + self.selected_collection.collection.add(self.tracks_viewer.selected_nodes) + for node_id in self.tracks_viewer.selected_nodes: + if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: + self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] + if ( + self.selected_collection.name.text() + not in self.tracks_viewer.tracks.graph.nodes[node_id]["group"] + ): + self.tracks_viewer.tracks.graph.nodes[node_id]["group"].append( + self.selected_collection.name.text() + ) + self.group_changed.emit() + + def add_track(self): + if self.selected_collection is not None: + for node_id in self.tracks_viewer.selected_nodes: + track_id = self.tracks_viewer.tracks._get_node_attr( + node_id, NodeAttr.TRACK_ID.value + ) + track = list( + { + node + for node, data in self.tracks_viewer.tracks.graph.nodes( + data=True + ) + if data.get("track_id") == track_id + } + ) + self.selected_collection.collection.add(track) + for node_id in track: + if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: + self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] + if ( + self.selected_collection.name.text() + not in self.tracks_viewer.tracks.graph.nodes[node_id]["group"] + ): + self.tracks_viewer.tracks.graph.nodes[node_id]["group"].append( + self.selected_collection.name.text() + ) + self.group_changed.emit() + + def add_lineage(self): + if self.selected_collection is not None: + for node_id in self.tracks_viewer.selected_nodes: + lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) + self.selected_collection.collection.add(lineage) + for node_id in lineage: + if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: + self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] + if ( + self.selected_collection.name.text() + not in self.tracks_viewer.tracks.graph.nodes[node_id]["group"] + ): + self.tracks_viewer.tracks.graph.nodes[node_id]["group"].append( + self.selected_collection.name.text() + ) + self.group_changed.emit() + + def remove_node(self): + if self.selected_collection is not None: + self.selected_collection.collection.remove( + self.tracks_viewer.selected_nodes + ) + for node_id in self.tracks_viewer.selected_nodes: + self.tracks_viewer.tracks.graph.nodes[node_id]["group"].remove( + self.selected_collection.name.text() + ) + self.group_changed.emit() + + def remove_track(self): + if self.selected_collection is not None: + for node_id in self.tracks_viewer.selected_nodes: + track_id = self.tracks_viewer.tracks._get_node_attr( + node_id, NodeAttr.TRACK_ID.value + ) + track = list( + { + node + for node, data in self.tracks_viewer.tracks.graph.nodes( + data=True + ) + if data.get("track_id") == track_id + } + ) + self.selected_collection.collection.remove(track) + for node_id in track: + self.tracks_viewer.tracks.graph.nodes[node_id]["group"].remove( + self.selected_collection.name.text() + ) + self.group_changed.emit() + + def remove_lineage(self): + if self.selected_collection is not None: + for node_id in self.tracks_viewer.selected_nodes: + lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) + self.selected_collection.collection.remove(lineage) + for node_id in lineage: + self.tracks_viewer.tracks.graph.nodes[node_id]["group"].remove( + self.selected_collection.name.text() + ) + self.group_changed.emit() + + def add_group(self, name: str, select=True): + """Create a new custom group""" + + names = [ + self.collection_list.itemWidget(self.collection_list.item(i)).name.text() + for i in range(self.collection_list.count()) + ] + while name in names: + name = name + "_1" + item = QListWidgetItem(self.collection_list) + group_row = CollectionButton(name) + self.collection_list.setItemWidget(item, group_row) + item.setSizeHint(group_row.minimumSizeHint()) + self.collection_list.addItem(item) + group_row.delete.clicked.connect(partial(self.remove_group, item)) + if select: + self.collection_list.setCurrentRow(len(self.collection_list) - 1) + + def remove_group(self, item: QListWidgetItem): + """Remove a collection object from the list. You must pass the list item that + represents the collection, not the collection object itself. + + Args: + item (QListWidgetItem): The list item to remove. This list item + contains the CollectionButton that represents a set of tracks. + """ + row = self.collection_list.indexFromItem(item).row() + group_name = self.collection_list.itemWidget(item).name.text() + self.collection_list.takeItem(row) + + # also delete the group from the node attributes + for _, data in self.tracks_viewer.tracks.graph.nodes(data=True): + groups = data.get("group") + + if groups and group_name in groups: + groups.remove(group_name) # Remove the group from the list + + def new_group(self): + """Create a new group""" + + self.add_group(name=self.group_name.text(), select=True) diff --git a/src/motile_plugin/data_views/views_coordinator/collections.py b/src/motile_plugin/data_views/views_coordinator/collections.py new file mode 100644 index 0000000..8e2e831 --- /dev/null +++ b/src/motile_plugin/data_views/views_coordinator/collections.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from psygnal import Signal +from PyQt5.QtCore import QObject + + +class Collection(QObject): + """A collection of nodes that sends a signal on every update. + Stores a list of node ids only.""" + + list_updated = Signal() + + def __init__(self): + super().__init__() + self._list = [] + + def add(self, items: list, append: bool | None = True): + """Add nodes from a list and emit a single signal""" + + if append: + for item in items: + if item in self._list: + continue + else: + self._list.append(item) + + else: + self._list = items + + self.list_updated.emit() + + def remove(self, items: list): + """Remove nodes from a list and emit a single signal""" + + self._list = [item for item in self._list if item not in items] + + self.list_updated.emit() + + def reset(self): + """Empty list and emit update signal""" + self._list = [] + self.list_updated.emit() + + def __getitem__(self, index): + return self._list[index] + + def __len__(self): + return len(self._list) diff --git a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py index 6a3bada..bcaca38 100644 --- a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py +++ b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py @@ -15,6 +15,7 @@ ) from motile_plugin.utils.relabel_segmentation import relabel_segmentation +from .collection_widget import CollectionWidget from .node_selection_list import NodeSelectionList from .tracks_list import TracksList @@ -66,6 +67,8 @@ def __init__( self.tracks_list = TracksList() self.tracks_list.view_tracks.connect(self.update_tracks) + self.collection_widget = CollectionWidget(self) + self.set_keybinds() def set_keybinds(self): @@ -89,7 +92,7 @@ def _refresh(self, node: str | None = None, refresh_view: bool = False) -> None: if len(self.selected_nodes) > 0 and any( not self.tracks.graph.has_node(node) for node in self.selected_nodes ): - self.selected_nodes.reset() + self.selected_nodes._list = [] self.tracking_layers._refresh() @@ -126,6 +129,9 @@ def update_tracks(self, tracks: Tracks, name: str) -> None: if isinstance(layer, (napari.layers.Labels | napari.layers.Points)): layer.visible = False + # retrieve existing groups + self.collection_widget.retrieve_existing_groups() + self.set_display_mode("all") self.tracking_layers.set_tracks(tracks, name) self.selected_nodes.reset() From 26070827a20a05494282656e2f23b8f25922dac5 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Fri, 8 Nov 2024 16:43:32 +0100 Subject: [PATCH 02/25] implement an experimental sync button to adjust the tree widget content based on the field of view in the napari viewer --- .../views/tree_view/tree_view_mode_widget.py | 2 +- .../views/tree_view/tree_view_sync_widget.py | 44 ++++ .../data_views/views/tree_view/tree_widget.py | 222 +++++++++++++++--- .../views_coordinator/tracks_viewer.py | 21 ++ 4 files changed, 252 insertions(+), 37 deletions(-) create mode 100644 src/motile_plugin/data_views/views/tree_view/tree_view_sync_widget.py diff --git a/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py index f62f8f6..8777785 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py @@ -35,7 +35,7 @@ def __init__(self): display_layout.addWidget(self.show_lineage_radio) display_layout.addWidget(self.show_group_radio) display_box.setLayout(display_layout) - display_box.setMaximumWidth(250) + display_box.setMaximumWidth(300) display_box.setMaximumHeight(60) layout = QVBoxLayout() diff --git a/src/motile_plugin/data_views/views/tree_view/tree_view_sync_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_view_sync_widget.py new file mode 100644 index 0000000..f67d0e4 --- /dev/null +++ b/src/motile_plugin/data_views/views/tree_view/tree_view_sync_widget.py @@ -0,0 +1,44 @@ +from qtpy.QtWidgets import ( + QGroupBox, + QHBoxLayout, + QPushButton, + QVBoxLayout, + QWidget, +) + + +class SyncWidget(QWidget): + """Widget to switch between viewing all nodes versus nodes of one or more lineages in the tree widget""" + + def __init__(self): + super().__init__() + + sync_box = QGroupBox("Sync/Desync views") + sync_layout = QHBoxLayout() + self.sync_button = SyncButton() + sync_layout.addWidget(self.sync_button) + sync_box.setLayout(sync_layout) + layout = QVBoxLayout() + layout.addWidget(sync_box) + sync_box.setMaximumWidth(150) + sync_box.setMaximumHeight(60) + + self.setLayout(layout) + + +class SyncButton(QPushButton): + def __init__(self): + super().__init__() + + self.setCheckable(True) + self.setText("🔗") # Initial icon as Unicode and text + self.clicked.connect(self.toggle_state) + self.setFixedHeight(25) + + def toggle_state(self): + """Set text and icon depending on toggle state""" + + if self.isChecked(): + self.setText("❌") # Replace with your chosen broken link symbol + else: + self.setText("🔗 ") diff --git a/src/motile_plugin/data_views/views/tree_view/tree_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_widget.py index a829902..0e7f3bc 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_widget.py @@ -6,9 +6,10 @@ import numpy as np import pandas as pd import pyqtgraph as pg +from napari.utils.events.event import Event from psygnal import Signal from pyqtgraph.Qt import QtCore -from qtpy.QtCore import Qt +from qtpy.QtCore import Qt, QTimer from qtpy.QtGui import QColor, QKeyEvent, QMouseEvent from qtpy.QtWidgets import ( QHBoxLayout, @@ -22,6 +23,7 @@ from .navigation_widget import NavigationWidget from .tree_view_feature_widget import TreeViewFeatureWidget from .tree_view_mode_widget import TreeViewModeWidget +from .tree_view_sync_widget import SyncWidget from .tree_widget_utils import ( extract_lineage_tree, extract_sorted_tracks, @@ -407,14 +409,28 @@ class TreeWidget(QWidget): def __init__(self, viewer: napari.Viewer): super().__init__() + self.viewer = viewer self.track_df = pd.DataFrame() # all tracks self.lineage_df = pd.DataFrame() # the currently viewed subset of lineages self.group_df = pd.DataFrame() # the currently viewed group + self.sync_df = ( + None # the objects synced with the field of view in napari viewer + ) self.graph = None self.mode = "all" # options: "all", "lineage" self.feature = "tree" # options: "tree", "area" self.view_direction = "vertical" # options: "horizontal", "vertical" + # Set debounce timer for syncing views + self.debounce_timer = QTimer() + self.debounce_timer.setSingleShot( + True + ) # Ensure the timer only fires once per activation + self.debounce_timer.timeout.connect( + self.sync_views + ) # Connect the timer to the sync function + + # Connect to tracks_viewer self.tracks_viewer = TracksViewer.get_instance(viewer) self.selected_nodes = self.tracks_viewer.selected_nodes self.selected_nodes.list_updated.connect(self._update_selected) @@ -447,32 +463,160 @@ def __init__(self, viewer: napari.Viewer): self.feature, ) + # Add widget for syncing views + self.sync = False + self.sync_widget = SyncWidget() + self.sync_widget.sync_button.setEnabled(False) + self.sync_widget.sync_button.clicked.connect(self.set_sync) + # Construct a toolbar and set main layout panel_layout = QHBoxLayout() panel_layout.addWidget(self.mode_widget) panel_layout.addWidget(self.feature_widget) panel_layout.addWidget(self.navigation_widget) + panel_layout.addWidget(self.sync_widget) panel_layout.setSpacing(0) panel_layout.setContentsMargins(0, 0, 0, 0) panel = QWidget() panel.setLayout(panel_layout) - panel.setMaximumWidth(820) + panel.setMaximumWidth(1000) panel.setMaximumHeight(78) # Make a collapsible for TreeView widgets - collapsable_widget = QCollapsible("Show/Hide Tree View Controls") - collapsable_widget.layout().setContentsMargins(0, 0, 0, 0) - collapsable_widget.layout().setSpacing(0) - collapsable_widget.addWidget(panel) - collapsable_widget.collapse(animate=False) + collapsible_widget = QCollapsible("Show/Hide Tree View Controls") + collapsible_widget.layout().setContentsMargins(0, 0, 0, 0) + collapsible_widget.layout().setSpacing(0) + collapsible_widget.addWidget(panel) + collapsible_widget.collapse(animate=False) - layout.addWidget(collapsable_widget) + layout.addWidget(collapsible_widget) layout.addWidget(self.tree_widget) layout.setSpacing(0) self.setLayout(layout) self._update_track_data(reset_view=True) + def set_sync(self) -> None: + """Lock/unlock syncing between napari views and tree view""" + + self.sync = not self.sync + + if self.sync: + self.viewer.camera.events.center.connect( + self.on_camera_center_change + ) # connect debounce timer start to the camera event + self._set_mode("all") + self.tracks_viewer.display_mode_updated.connect(self._set_mode) + else: + self.viewer.camera.events.center.disconnect(self.on_camera_center_change) + self.tracks_viewer.display_mode_updated.disconnect(self._set_mode) + if self.mode == "all": + self.tree_widget.update( + self.track_df, + self.view_direction, + self.feature, + self.selected_nodes, + reset_view=False, + ) + elif self.mode == "lineage": + self.tree_widget.update( + self.lineage_df, + self.view_direction, + self.feature, + self.selected_nodes, + reset_view=False, + ) + elif self.mode == "group": + self.tree_widget.update( + self.group_df, + self.view_direction, + self.feature, + self.selected_nodes, + reset_view=False, + ) + + def on_camera_center_change(self, event: Event | None = None): + """Start or restart the debounce timer with a delay (200 ms)""" + + self.debounce_timer.start(200) + + def sync_views(self, force_update: bool | None = None) -> None: + """Sync the data in the tree plot with the data in the field of view of the napari viewer""" + + if self.sync_df is None: + prev_visible = self.track_df["node_id"].tolist() + else: + prev_visible = self.sync_df["node_id"].tolist() + + corner_coordinates = ( + self.tracks_viewer.tracking_layers.points_layer.corner_pixels + ) + dims_displayed = self.viewer.dims.displayed + + # self.viewer.dims.displayed_order + x_dim = dims_displayed[-1] + y_dim = dims_displayed[-2] + + # find corner pixels for the displayed axes + _min_x = corner_coordinates[0][x_dim] + _max_x = corner_coordinates[1][x_dim] + _min_y = corner_coordinates[0][y_dim] + _max_y = corner_coordinates[1][y_dim] + + if self.mode == "all": + visible_nodes = self.track_df[ + (self.track_df["x"] >= _min_x) + & (self.track_df["x"] <= _max_x) + & (self.track_df["y"] >= _min_y) + & (self.track_df["y"] <= _max_y) + ]["node_id"].tolist() + elif self.mode == "lineage": + visible_nodes = self.lineage_df[ + (self.lineage_df["x"] >= _min_x) + & (self.lineage_df["x"] <= _max_x) + & (self.lineage_df["y"] >= _min_y) + & (self.lineage_df["y"] <= _max_y) + ]["node_id"].tolist() + elif self.mode == "group": + visible_nodes = self.group_df[ + (self.group_df["x"] >= _min_x) + & (self.group_df["x"] <= _max_x) + & (self.group_df["y"] >= _min_y) + & (self.group_df["y"] <= _max_y) + ]["node_id"].tolist() + + visible = [] + for node_id in visible_nodes: + visible += extract_lineage_tree(self.graph, node_id) + + if ( + set(visible) != set(prev_visible) or force_update + ): # only call update function if the list of visible nodes has changed + if self.mode == "all": + self.sync_df = self.track_df[ + self.track_df["node_id"].isin(visible) + ].reset_index() + elif self.mode == "lineage": + self.sync_df = self.lineage_df[ + self.lineage_df["node_id"].isin(visible) + ].reset_index() + elif self.mode == "group": + self.sync_df = self.group_df[ + self.group_df["node_id"].isin(visible) + ].reset_index() + + self.sync_df["x_axis_pos"] = ( + self.sync_df["x_axis_pos"].rank(method="dense").astype(int) - 1 + ) + + self.tree_widget.update( + self.sync_df, + self.view_direction, + self.feature, + self.selected_nodes, + reset_view=True, + ) + def keyPressEvent(self, event: QKeyEvent) -> None: """Handle key press events.""" key_map = { @@ -579,6 +723,7 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: if self.tracks_viewer.tracks is None: self.track_df = pd.DataFrame() self.graph = None + self.sync_widget.sync_button.setEnabled(False) else: if reset_view: self.track_df = extract_sorted_tracks( @@ -591,6 +736,7 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: self.track_df, ) self.graph = self.tracks_viewer.tracks.graph + self.sync_widget.sync_button.setEnabled(True) # check whether we have area measurements and therefore should activate the area # button @@ -614,28 +760,21 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: self.navigation_widget.track_df = self.track_df self.navigation_widget.lineage_df = self.lineage_df - # check which view to set + # get the dataframe of currently displayed data, then update plot with or without filtering by the field of view (if sync is on) if self.mode == "lineage": self._update_lineage_df() - self.tree_widget.update( - self.lineage_df, - self.view_direction, - self.feature, - self.selected_nodes, - reset_view=reset_view, - ) + df = self.lineage_df elif self.mode == "group": self._update_group_df() - self.tree_widget.update( - self.group_df, - self.view_direction, - self.feature, - self.selected_nodes, - reset_view=reset_view, - ) + df = self.group_df + else: + df = self.track_df + + if self.sync: + self.sync_views(force_update=True) else: self.tree_widget.update( - self.track_df, + df, self.view_direction, self.feature, self.selected_nodes, @@ -671,9 +810,19 @@ def _set_mode(self, mode: str) -> None: self._update_lineage_df() df = self.lineage_df self.navigation_widget.view_direction = self.view_direction - self.tree_widget.update( - df, self.view_direction, self.feature, self.selected_nodes - ) + + if self.sync: + self.tracks_viewer.display_mode_updated.disconnect( + self._set_mode + ) # disconnect first to prevent infinite loop + self.tracks_viewer.set_display_mode(mode) + self.tracks_viewer.display_mode_updated.connect(self._set_mode) + self.sync_views() # update the tree plot with the nodes in the dataframe that are also in the napari view + + else: # update the tree plot with all nodes in the dataframe + self.tree_widget.update( + df, self.view_direction, self.feature, self.selected_nodes + ) def _set_feature(self, feature: str) -> None: """Set the feature mode to 'tree' or 'area'. For this the view is always @@ -735,15 +884,16 @@ def _update_group_df(self) -> None: """Subset dataframe to include only nodes belonging to the current group/collection""" visible = [] - for ( - node_id - ) in self.tracks_viewer.collection_widget.selected_collection.collection: - if node_id in self.track_df["node_id"].tolist(): - visible += extract_lineage_tree(self.graph, node_id) - else: - self.tracks_viewer.collection_widget.selected_collection.collection._list.remove( - node_id - ) + if self.tracks_viewer.collection_widget.selected_collection is not None: + for ( + node_id + ) in self.tracks_viewer.collection_widget.selected_collection.collection: + if node_id in self.track_df["node_id"].tolist(): + visible += extract_lineage_tree(self.graph, node_id) + else: + self.tracks_viewer.collection_widget.selected_collection.collection._list.remove( + node_id + ) self.group_df = self.track_df[ self.track_df["node_id"].isin(visible) ].reset_index() diff --git a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py index bcaca38..63001fd 100644 --- a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py +++ b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py @@ -29,6 +29,7 @@ class TracksViewer: """ tracks_updated = Signal(Optional[bool]) + display_mode_updated = Signal(str) @classmethod def get_instance(cls, viewer=None): @@ -141,6 +142,8 @@ def toggle_display_mode(self, event=None) -> None: """Toggle the display mode between available options""" if self.mode == "lineage": + self.set_display_mode("group") + elif self.mode == "group": self.set_display_mode("all") else: self.set_display_mode("lineage") @@ -152,6 +155,9 @@ def set_display_mode(self, mode: str) -> None: if mode == "lineage": self.mode = "lineage" self.viewer.text_overlay.text = "Toggle Display [Q]\n Lineage" + elif mode == "group": + self.mode = "group" + self.viewer.text_overlay.text = "Toggle Display [Q]\n Group" else: self.mode = "all" self.viewer.text_overlay.text = "Toggle Display [Q]\n All" @@ -159,6 +165,7 @@ def set_display_mode(self, mode: str) -> None: self.viewer.text_overlay.visible = True visible = self.filter_visible_nodes() self.tracking_layers.update_visible(visible) + self.display_mode_updated.emit(mode) def filter_visible_nodes(self) -> list[int]: """Construct a list of track_ids that should be displayed""" @@ -187,6 +194,20 @@ def filter_visible_nodes(self) -> list[int]: for node in self.visible } ) + elif self.mode == "group": + self.group_visible = [] + if self.collection_widget.selected_collection is not None: + for node_id in self.collection_widget.selected_collection.collection: + self.group_visible += extract_lineage_tree( + self.tracks.graph, node_id + ) + + return list( + { + self.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value] + for node in self.group_visible + } + ) else: return "all" From 8495740ab3f99fe7a66c139bb1cf8a5b1ccab80c Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Fri, 8 Nov 2024 16:58:11 +0100 Subject: [PATCH 03/25] clean up --- .../views_coordinator/collection_widget.py | 20 ++++++++++++++++++- .../views_coordinator/collections.py | 17 ++++------------ 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py index 5f1cd3b..02b4c5c 100644 --- a/src/motile_plugin/data_views/views_coordinator/collection_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -29,6 +29,8 @@ class CollectionButton(QWidget): + """Widget holding a name and delete icon for listing in the QListWidget. Also contains an initially empty instance of a Collection to which nodes can be assigned""" + def __init__(self, name: str): super().__init__() self.name = QLabel(name) @@ -113,6 +115,8 @@ def __init__(self, tracks_viewer: TracksViewer): self.setLayout(layout) def retrieve_existing_groups(self): + """Create collections based on the node attributes. Nodes assigned to a group should have that group in their 'group' attribute""" + # first clear the entire list self.collection_list.clear() @@ -135,12 +139,16 @@ def retrieve_existing_groups(self): ) def _selection_changed(self): + """Update the currently selected collection and send update signal""" + selected = self.collection_list.selectedItems() if selected: self.selected_collection = self.collection_list.itemWidget(selected[0]) self.group_changed.emit() def add_node(self): + """Add individual nodes to the selected collection and send update signal""" + if self.selected_collection is not None: self.selected_collection.collection.add(self.tracks_viewer.selected_nodes) for node_id in self.tracks_viewer.selected_nodes: @@ -156,6 +164,8 @@ def add_node(self): self.group_changed.emit() def add_track(self): + """Add tracks by track_ids to the selected collection and send update signal""" + if self.selected_collection is not None: for node_id in self.tracks_viewer.selected_nodes: track_id = self.tracks_viewer.tracks._get_node_attr( @@ -184,6 +194,8 @@ def add_track(self): self.group_changed.emit() def add_lineage(self): + """Add lineages to the selected collection and send update signal""" + if self.selected_collection is not None: for node_id in self.tracks_viewer.selected_nodes: lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) @@ -201,6 +213,8 @@ def add_lineage(self): self.group_changed.emit() def remove_node(self): + """Remove individual nodes from the selected collection and send update signal""" + if self.selected_collection is not None: self.selected_collection.collection.remove( self.tracks_viewer.selected_nodes @@ -212,6 +226,8 @@ def remove_node(self): self.group_changed.emit() def remove_track(self): + """Remove tracks by track id from the selected collection and send update signal""" + if self.selected_collection is not None: for node_id in self.tracks_viewer.selected_nodes: track_id = self.tracks_viewer.tracks._get_node_attr( @@ -234,6 +250,8 @@ def remove_track(self): self.group_changed.emit() def remove_lineage(self): + """Remove lineages from the selected collection and send update signal""" + if self.selected_collection is not None: for node_id in self.tracks_viewer.selected_nodes: lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) @@ -268,7 +286,7 @@ def remove_group(self, item: QListWidgetItem): Args: item (QListWidgetItem): The list item to remove. This list item - contains the CollectionButton that represents a set of tracks. + contains the CollectionButton that represents a set of node_ids. """ row = self.collection_list.indexFromItem(item).row() group_name = self.collection_list.itemWidget(item).name.text() diff --git a/src/motile_plugin/data_views/views_coordinator/collections.py b/src/motile_plugin/data_views/views_coordinator/collections.py index 8e2e831..b5c46c0 100644 --- a/src/motile_plugin/data_views/views_coordinator/collections.py +++ b/src/motile_plugin/data_views/views_coordinator/collections.py @@ -1,21 +1,17 @@ from __future__ import annotations -from psygnal import Signal from PyQt5.QtCore import QObject class Collection(QObject): - """A collection of nodes that sends a signal on every update. - Stores a list of node ids only.""" - - list_updated = Signal() + """A collection of node ids belonging to a group""" def __init__(self): super().__init__() self._list = [] def add(self, items: list, append: bool | None = True): - """Add nodes from a list and emit a single signal""" + """Add nodes from a list""" if append: for item in items: @@ -27,19 +23,14 @@ def add(self, items: list, append: bool | None = True): else: self._list = items - self.list_updated.emit() - def remove(self, items: list): - """Remove nodes from a list and emit a single signal""" + """Remove nodes from a list""" self._list = [item for item in self._list if item not in items] - self.list_updated.emit() - def reset(self): - """Empty list and emit update signal""" + """Empty list""" self._list = [] - self.list_updated.emit() def __getitem__(self, index): return self._list[index] From 5d7049b59d59722738bcdc590590141147774b4d Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Fri, 8 Nov 2024 17:26:07 +0100 Subject: [PATCH 04/25] fix sync update issue, and make sure the mode radio button is set to all when starting the sync mode --- src/motile_plugin/data_views/views/tree_view/tree_widget.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/motile_plugin/data_views/views/tree_view/tree_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_widget.py index 0e7f3bc..d6e5c87 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_widget.py @@ -505,11 +505,15 @@ def set_sync(self) -> None: self.viewer.camera.events.center.connect( self.on_camera_center_change ) # connect debounce timer start to the camera event - self._set_mode("all") self.tracks_viewer.display_mode_updated.connect(self._set_mode) + self._set_mode("all") + self.mode_widget.show_all_radio.setChecked(True) + else: self.viewer.camera.events.center.disconnect(self.on_camera_center_change) self.tracks_viewer.display_mode_updated.disconnect(self._set_mode) + self.sync_df = None + if self.mode == "all": self.tree_widget.update( self.track_df, From ada38ed81a3b48ac11947c112f09b04af59d194e Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Tue, 12 Nov 2024 14:45:47 +0100 Subject: [PATCH 05/25] implement a filter widget that highlights nodes fulfilling selection criteria and make it possible to assign filtered nodes to a new group --- .../application_menus/menu_widget.py | 1 + .../data_views/views/layers/track_points.py | 19 +- .../views/layers/tracks_layer_group.py | 2 +- .../data_views/views/tree_view/tree_widget.py | 56 ++- .../views_coordinator/collection_widget.py | 170 +++++-- .../views_coordinator/filter_widget.py | 433 ++++++++++++++++++ .../views_coordinator/tracks_viewer.py | 14 + 7 files changed, 642 insertions(+), 53 deletions(-) create mode 100644 src/motile_plugin/data_views/views_coordinator/filter_widget.py diff --git a/src/motile_plugin/application_menus/menu_widget.py b/src/motile_plugin/application_menus/menu_widget.py index de59bf4..58da275 100644 --- a/src/motile_plugin/application_menus/menu_widget.py +++ b/src/motile_plugin/application_menus/menu_widget.py @@ -23,6 +23,7 @@ def __init__(self, viewer: napari.Viewer): tabwidget.addTab(editing_widget, "Edit Tracks") tabwidget.addTab(tracks_viewer.tracks_list, "Results List") tabwidget.addTab(tracks_viewer.collection_widget, "Collections") + tabwidget.addTab(tracks_viewer.filter_widget, "Filters") layout = QVBoxLayout() layout.addWidget(tabwidget) diff --git a/src/motile_plugin/data_views/views/layers/track_points.py b/src/motile_plugin/data_views/views/layers/track_points.py index 39e9100..2aa5279 100644 --- a/src/motile_plugin/data_views/views/layers/track_points.py +++ b/src/motile_plugin/data_views/views/layers/track_points.py @@ -187,11 +187,11 @@ def get_symbols(self, tracks: Tracks, symbolmap: dict[NodeType, str]) -> list[st symbols = [symbolmap[statemap[degree]] for _, degree in tracks.graph.out_degree] return symbols - def update_point_outline(self, visible: list[int] | str) -> None: + def update_point_visibility(self, visible: list[int] | str) -> None: """Update the outline color of the selected points and visibility according to display mode Args: - visible (list[int] | str): A list of track ids, or "all" + visible (list[int] | str | None): A list of track ids, "all" """ # filter out the non-selected tracks if in lineage mode if visible == "all": @@ -203,9 +203,23 @@ def update_point_outline(self, visible: list[int] | str) -> None: self.shown[:] = False self.shown[indices] = True + self.update_point_outline() + + def update_point_outline(self) -> None: # set border color for selected item self.border_color = [1, 1, 1, 1] self.size = 5 + + for node in self.tracks_viewer.filtered_nodes: + index = self.node_index_dict[node] + self.border_color[index] = ( + self.tracks_viewer.filter_color[0], + self.tracks_viewer.filter_color[1], + self.tracks_viewer.filter_color[2], + 1, + ) + self.size[index] = 5 + for node in self.tracks_viewer.selected_nodes: index = self.node_index_dict[node] self.border_color[index] = ( @@ -215,4 +229,5 @@ def update_point_outline(self, visible: list[int] | str) -> None: 1, ) self.size[index] = 7 + self.refresh() diff --git a/src/motile_plugin/data_views/views/layers/tracks_layer_group.py b/src/motile_plugin/data_views/views/layers/tracks_layer_group.py index 188ecee..87d6866 100644 --- a/src/motile_plugin/data_views/views/layers/tracks_layer_group.py +++ b/src/motile_plugin/data_views/views/layers/tracks_layer_group.py @@ -99,7 +99,7 @@ def update_visible(self, visible: list[int]): if self.seg_layer is not None: self.seg_layer.update_label_colormap(visible) if self.points_layer is not None: - self.points_layer.update_point_outline(visible) + self.points_layer.update_point_visibility(visible) if self.tracks_layer is not None: self.tracks_layer.update_track_visibility(visible) diff --git a/src/motile_plugin/data_views/views/tree_view/tree_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_widget.py index d6e5c87..1110a7c 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_widget.py @@ -125,6 +125,8 @@ def update( view_direction: str, feature: str, selected_nodes: list[Any], + filtered_nodes: set[Any], + color: tuple[float], reset_view: bool | None = False, ): """Update the entire view, including the data, view direction, and @@ -139,7 +141,7 @@ def update( self.set_data(track_df, feature) self._update_viewed_data(view_direction) # this can be expensive self.set_view(view_direction, feature, reset_view) - self.set_selection(selected_nodes, feature) + self.set_selection(selected_nodes, filtered_nodes, color, feature) def set_view( self, view_direction: str, feature: str, reset_view: bool | None = False @@ -285,7 +287,13 @@ def _create_pyqtgraph_content(self, track_df: pd.DataFrame, feature: str) -> Non [pg.mkPen(QColor(150, 150, 150)) for i in range(len(self._pos))] ) - def set_selection(self, selected_nodes: list[Any], feature: str) -> None: + def set_selection( + self, + selected_nodes: list[Any], + filtered_nodes: set[Any], + color: tuple, + feature: str, + ) -> None: """Set the provided list of nodes to be selected. Increases the size and highlights the outline with blue. Also centers the view if the first selected node is not visible in the current canvas. @@ -308,6 +316,14 @@ def set_selection(self, selected_nodes: list[Any], feature: str) -> None: "area" if feature == "area" else "x_axis_pos" ) # check what is currently being shown, to know how to scale the view + if len(filtered_nodes) > 0: + color = [c * 255 for c in color] + for node_id in filtered_nodes: + node_df = self.track_df.loc[self.track_df["node_id"] == node_id] + if not node_df.empty: + index = self.node_ids.index(node_id) + outlines[index] = pg.mkPen(color=color, width=2) + if len(selected_nodes) > 0: x_values = [] t_values = [] @@ -438,6 +454,7 @@ def __init__(self, viewer: napari.Viewer): self._update_selected ) self.tracks_viewer.tracks_updated.connect(self._update_track_data) + self.tracks_viewer.filter_widget.apply_filter.connect(self._update_selected) # Construct the tree view pyqtgraph widget layout = QVBoxLayout() @@ -520,6 +537,8 @@ def set_sync(self) -> None: self.view_direction, self.feature, self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, reset_view=False, ) elif self.mode == "lineage": @@ -528,6 +547,8 @@ def set_sync(self) -> None: self.view_direction, self.feature, self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, reset_view=False, ) elif self.mode == "group": @@ -536,6 +557,8 @@ def set_sync(self) -> None: self.view_direction, self.feature, self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, reset_view=False, ) @@ -618,6 +641,8 @@ def sync_views(self, force_update: bool | None = None) -> None: self.view_direction, self.feature, self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, reset_view=True, ) @@ -707,6 +732,8 @@ def _update_selected(self): self.view_direction, self.feature, self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, ) elif self.mode == "group": self._update_group_df() @@ -715,9 +742,16 @@ def _update_selected(self): self.view_direction, self.feature, self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, ) else: - self.tree_widget.set_selection(self.selected_nodes, self.feature) + self.tree_widget.set_selection( + self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, + self.feature, + ) def _update_track_data(self, reset_view: bool | None = None) -> None: """Called when the TracksViewer emits the tracks_updated signal, indicating @@ -782,6 +816,8 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: self.view_direction, self.feature, self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, reset_view=reset_view, ) @@ -825,7 +861,12 @@ def _set_mode(self, mode: str) -> None: else: # update the tree plot with all nodes in the dataframe self.tree_widget.update( - df, self.view_direction, self.feature, self.selected_nodes + df, + self.view_direction, + self.feature, + self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, ) def _set_feature(self, feature: str) -> None: @@ -854,7 +895,12 @@ def _set_feature(self, feature: str) -> None: self.navigation_widget.feature = self.feature self.tree_widget.update( - df, self.view_direction, self.feature, self.selected_nodes + df, + self.view_direction, + self.feature, + self.selected_nodes, + self.tracks_viewer.filtered_nodes, + self.tracks_viewer.filter_color, ) def _update_lineage_df(self) -> None: diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py index 02b4c5c..f238fac 100644 --- a/src/motile_plugin/data_views/views_coordinator/collection_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from motile_toolbox.candidate_graph.graph_attributes import NodeAttr from napari._qt.qt_resources import QColoredSVGIcon @@ -62,59 +62,123 @@ def __init__(self, tracks_viewer: TracksViewer): super().__init__(title="Collections") self.tracks_viewer = tracks_viewer + self.tracks_viewer.selected_nodes.list_updated.connect(self._update_buttons) + + self.group_changed.connect(self._update_buttons) self.collection_list = QListWidget() self.collection_list.setSelectionMode(1) # single selection self.collection_list.itemSelectionChanged.connect(self._selection_changed) self.selected_collection = None + # Select nodes in group + selection_layout = QHBoxLayout() + self.select_btn = QPushButton("Select nodes in group") + self.select_btn.clicked.connect(self._select_nodes) + self.deselect_btn = QPushButton("Deselect") + self.deselect_btn.clicked.connect(self.tracks_viewer.selected_nodes.reset) + selection_layout.addWidget(self.select_btn) + selection_layout.addWidget(self.deselect_btn) + # edit layout edit_widget = QGroupBox("Edit") edit_layout = QVBoxLayout() add_layout = QHBoxLayout() - add_node = QPushButton("Add node(s)") - add_node.clicked.connect(self.add_node) - add_track = QPushButton("Add track(s)") - add_track.clicked.connect(self.add_track) - add_lineage = QPushButton("Add lineage(s)") - add_lineage.clicked.connect(self.add_lineage) - add_layout.addWidget(add_node) - add_layout.addWidget(add_track) - add_layout.addWidget(add_lineage) + self.add_nodes_btn = QPushButton("Add node(s)") + self.add_nodes_btn.clicked.connect(lambda: self.add_nodes(None)) + self.add_track_btn = QPushButton("Add track(s)") + self.add_track_btn.clicked.connect(self._add_track) + self.add_lineage_btn = QPushButton("Add lineage(s)") + self.add_lineage_btn.clicked.connect(self._add_lineage) + add_layout.addWidget(self.add_nodes_btn) + add_layout.addWidget(self.add_track_btn) + add_layout.addWidget(self.add_lineage_btn) remove_layout = QHBoxLayout() - remove_node = QPushButton("Remove node(s)") - remove_node.clicked.connect(self.remove_node) - remove_track = QPushButton("Remove track(s)") - remove_track.clicked.connect(self.remove_track) - remove_lineage = QPushButton("Remove lineage(s)") - remove_lineage.clicked.connect(self.remove_lineage) - remove_layout.addWidget(remove_node) - remove_layout.addWidget(remove_track) - remove_layout.addWidget(remove_lineage) + self.remove_node_btn = QPushButton("Remove node(s)") + self.remove_node_btn.clicked.connect(self._remove_node) + self.remove_track_btn = QPushButton("Remove track(s)") + self.remove_track_btn.clicked.connect(self._remove_track) + self.remove_lineage_btn = QPushButton("Remove lineage(s)") + self.remove_lineage_btn.clicked.connect(self._remove_lineage) + remove_layout.addWidget(self.remove_node_btn) + remove_layout.addWidget(self.remove_track_btn) + remove_layout.addWidget(self.remove_lineage_btn) edit_layout.addLayout(add_layout) edit_layout.addLayout(remove_layout) edit_widget.setLayout(edit_layout) # adding a new group + new_group_box = QGroupBox("New Group") new_group_layout = QHBoxLayout() - new_group_layout.addWidget(QLabel("New group:")) self.group_name = QLineEdit("new group") new_group_layout.addWidget(self.group_name) - new_group_button = QPushButton("Create") - new_group_button.clicked.connect(self.new_group) - new_group_layout.addWidget(new_group_button) + self.new_group_button = QPushButton("Create") + self.new_group_button.clicked.connect( + lambda: self.add_group(name=None, select=True) + ) + new_group_layout.addWidget(self.new_group_button) + new_group_box.setLayout(new_group_layout) # combine widgets layout = QVBoxLayout() layout.addWidget(self.collection_list) + layout.addLayout(selection_layout) layout.addWidget(edit_widget) - layout.addLayout(new_group_layout) + layout.addWidget(new_group_box) self.setLayout(layout) - def retrieve_existing_groups(self): + self._update_buttons() + + def _update_buttons(self) -> None: + """Enable or disable selection and edit buttons depending on whether a group is selected, nodes are selected, and whether the group contains any nodes""" + + selected = self.collection_list.selectedItems() + if selected and len(self.tracks_viewer.selected_nodes) > 0: + self.add_nodes_btn.setEnabled(True) + self.add_track_btn.setEnabled(True) + self.add_lineage_btn.setEnabled(True) + self.remove_node_btn.setEnabled(True) + self.remove_track_btn.setEnabled(True) + self.remove_lineage_btn.setEnabled(True) + else: + self.add_nodes_btn.setEnabled(False) + self.add_track_btn.setEnabled(False) + self.add_lineage_btn.setEnabled(False) + self.remove_node_btn.setEnabled(False) + self.remove_track_btn.setEnabled(False) + self.remove_lineage_btn.setEnabled(False) + self.select_btn.setEnabled(False) + + if selected: + if len(self.selected_collection.collection) > 0: + self.select_btn.setEnabled(True) + else: + self.select_btn.setEnabled(False) + + if len(self.tracks_viewer.selected_nodes) > 0: + self.deselect_btn.setEnabled(True) + else: + self.deselect_btn.setEnabled(False) + + if self.tracks_viewer.tracks is not None: + self.new_group_button.setEnabled(True) + else: + self.new_group_button.setEnabled(False) + + def _select_nodes(self) -> None: + """Select all nodes in the collection""" + + selected = self.collection_list.selectedItems() + if selected: + self.selected_collection = self.collection_list.itemWidget(selected[0]) + self.tracks_viewer.selected_nodes.add_list( + self.selected_collection.collection._list, append=False + ) + + def retrieve_existing_groups(self) -> None: """Create collections based on the node attributes. Nodes assigned to a group should have that group in their 'group' attribute""" # first clear the entire list @@ -128,7 +192,7 @@ def retrieve_existing_groups(self): for group in groups: if group not in group_dict: group_dict[group] = [] - self.add_group(group, select=False) + self.add_group(name=group, select=False) group_dict[group].append(node) # populate the lists based on the nodes that were assigned to the different groups @@ -138,7 +202,9 @@ def retrieve_existing_groups(self): group_dict[self.selected_collection.name.text()] ) - def _selection_changed(self): + self._update_buttons() + + def _selection_changed(self) -> None: """Update the currently selected collection and send update signal""" selected = self.collection_list.selectedItems() @@ -146,12 +212,22 @@ def _selection_changed(self): self.selected_collection = self.collection_list.itemWidget(selected[0]) self.group_changed.emit() - def add_node(self): - """Add individual nodes to the selected collection and send update signal""" + self._update_buttons() + + def add_nodes(self, nodes: list[Any] | None = None) -> None: + """Add individual nodes to the selected collection and send update signal + + Args: + nodes (list, optional): A list of nodes to add to this group. If not provided, the nodes are taken from the current selection in tracks_viewer.selected_nodes. + """ if self.selected_collection is not None: - self.selected_collection.collection.add(self.tracks_viewer.selected_nodes) - for node_id in self.tracks_viewer.selected_nodes: + if nodes is None: + nodes = ( + self.tracks_viewer.selected_nodes._list + ) # take the nodes that are currently selected + self.selected_collection.collection.add(nodes) + for node_id in nodes: if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] if ( @@ -161,9 +237,10 @@ def add_node(self): self.tracks_viewer.tracks.graph.nodes[node_id]["group"].append( self.selected_collection.name.text() ) + self.group_changed.emit() - def add_track(self): + def _add_track(self) -> None: """Add tracks by track_ids to the selected collection and send update signal""" if self.selected_collection is not None: @@ -193,7 +270,7 @@ def add_track(self): ) self.group_changed.emit() - def add_lineage(self): + def _add_lineage(self) -> None: """Add lineages to the selected collection and send update signal""" if self.selected_collection is not None: @@ -212,7 +289,7 @@ def add_lineage(self): ) self.group_changed.emit() - def remove_node(self): + def _remove_node(self) -> None: """Remove individual nodes from the selected collection and send update signal""" if self.selected_collection is not None: @@ -225,7 +302,7 @@ def remove_node(self): ) self.group_changed.emit() - def remove_track(self): + def _remove_track(self) -> None: """Remove tracks by track id from the selected collection and send update signal""" if self.selected_collection is not None: @@ -249,7 +326,7 @@ def remove_track(self): ) self.group_changed.emit() - def remove_lineage(self): + def _remove_lineage(self) -> None: """Remove lineages from the selected collection and send update signal""" if self.selected_collection is not None: @@ -262,8 +339,16 @@ def remove_lineage(self): ) self.group_changed.emit() - def add_group(self, name: str, select=True): - """Create a new custom group""" + def add_group(self, name: str | None = None, select: bool = True) -> None: + """Create a new custom group + + Args: + name (str, optional): the name to give to this group. If not provided, the name in the self.group_name QLineEdit widget is used. + select (bool, optional): whether or not to make this group the selected item in the QListWidget. Defaults to True + """ + + if name is None: + name = self.group_name.text() names = [ self.collection_list.itemWidget(self.collection_list.item(i)).name.text() @@ -276,11 +361,11 @@ def add_group(self, name: str, select=True): self.collection_list.setItemWidget(item, group_row) item.setSizeHint(group_row.minimumSizeHint()) self.collection_list.addItem(item) - group_row.delete.clicked.connect(partial(self.remove_group, item)) + group_row.delete.clicked.connect(partial(self._remove_group, item)) if select: self.collection_list.setCurrentRow(len(self.collection_list) - 1) - def remove_group(self, item: QListWidgetItem): + def _remove_group(self, item: QListWidgetItem) -> None: """Remove a collection object from the list. You must pass the list item that represents the collection, not the collection object itself. @@ -298,8 +383,3 @@ def remove_group(self, item: QListWidgetItem): if groups and group_name in groups: groups.remove(group_name) # Remove the group from the list - - def new_group(self): - """Create a new group""" - - self.add_group(name=self.group_name.text(), select=True) diff --git a/src/motile_plugin/data_views/views_coordinator/filter_widget.py b/src/motile_plugin/data_views/views_coordinator/filter_widget.py new file mode 100644 index 0000000..6d31549 --- /dev/null +++ b/src/motile_plugin/data_views/views_coordinator/filter_widget.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +import operator +from functools import partial +from typing import TYPE_CHECKING + +import networkx as nx +import numpy as np +from napari._qt.qt_resources import QColoredSVGIcon +from qtpy.QtCore import Signal +from qtpy.QtGui import QColor +from qtpy.QtWidgets import ( + QAbstractItemView, + QComboBox, + QDoubleSpinBox, + QGroupBox, + QHBoxLayout, + QInputDialog, + QLabel, + QLineEdit, + QListWidget, + QListWidgetItem, + QPushButton, + QSpinBox, + QVBoxLayout, + QWidget, +) + +if TYPE_CHECKING: + from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer + + +class Rule(QWidget): + """Widget holding the rule""" + + def __init__(self, graph: nx.diGraph): + super().__init__() + + self.graph = graph + + self.items = set() + for _, attrs in self.graph.nodes(data=True): + self.items.update(attrs.keys()) + if "pos" in self.items: + self.items.remove("pos") + + self.signs = ["<", "<=", ">", ">=", "=", "\u2260"] + + self.logic = ["AND", "OR", "NOT", "XOR"] + + self.item_dropdown = QComboBox() + self.item_dropdown.addItems(self.items) + self.item_dropdown.currentIndexChanged.connect(self._set_sign_value_widget) + + self.sign_dropdown = QComboBox() + self.sign_dropdown.addItems(self.signs) + + self.logic_dropdown = QComboBox() + self.logic_dropdown.addItems(self.logic) + + # Placeholder for the dynamic value widget + self.value_widget = QWidget() + self.value_layout = QVBoxLayout() + self.value_widget.setLayout(self.value_layout) + + delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") + self.delete = QPushButton(icon=delete_icon) + self.delete.setFixedSize(20, 20) + + layout = QHBoxLayout() + layout.addWidget(self.delete) + layout.addWidget(self.item_dropdown) + layout.addWidget(self.sign_dropdown) + layout.addWidget(self.logic_dropdown) + layout.addWidget(self.value_widget) + + self.setLayout(layout) + + # Initialize value widget + self._set_sign_value_widget() + + def _set_sign_value_widget(self): + # Remove the existing value widget from the layout + self.layout().removeWidget(self.value_widget) + self.value_widget.deleteLater() + + # Get the selected item (attribute name) + selected_item = self.item_dropdown.currentText() + + # Determine the attribute type by checking the nodes + attr_type = None + for _, attrs in self.graph.nodes(data=True): + if selected_item in attrs: + attr_type = type(attrs[selected_item]) + break + + # Set up the value widget based on the attribute type + if attr_type == str: + # Use a dropdown for string attributes + unique_values = { + attrs[selected_item] + for _, attrs in self.graph.nodes(data=True) + if selected_item in attrs and isinstance(attrs[selected_item], str) + } + self.value_widget = QComboBox() + self.value_widget.addItems(unique_values) + + # update the signs, cannot use >. >=, <=, < for string comparison + self.sign_dropdown.clear() + signs = ["=", "\u2260"] + self.sign_dropdown.addItems(signs) + + elif attr_type in (int, float, np.float64): + # Use a spin box for numeric attributes (int or float) + self.value_widget = QSpinBox() if attr_type == int else QDoubleSpinBox() + self.value_widget.setMinimum(0) + self.value_widget.setMaximum(100000000) + + # all signs should be allowed + self.sign_dropdown.clear() + self.sign_dropdown.addItems(self.signs) + + elif attr_type in (list, tuple): + self.value_widget = QLineEdit("Type your value here") + # update the signs, cannot use >. >=, <=, < for string comparison + self.sign_dropdown.clear() + signs = ["=", "\u2260"] + self.sign_dropdown.addItems(signs) + + else: + # Fallback if attribute type is not a string of number + self.value_widget = QLabel("No valid attribute type") + + # Add the new value widget to the layout + self.layout().addWidget(self.value_widget) + + def _update_items(self, graph): + # To update the items list when the graph changes + self.items = set() + for _, attrs in graph.nodes(data=True): + self.items.update(attrs.keys()) + self.item_dropdown.clear() + self.item_dropdown.addItems(self.items) + + +class Filter(QWidget): + """Widget holding a name and delete icon for listing in the QListWidget. Also contains an initially empty instance of a Collection to which nodes can be assigned""" + + filter_updated = Signal() + + def __init__(self, tracks_viewer, item: QListWidgetItem): + super().__init__() + + self.tracks_viewer = tracks_viewer + self.item = item + + # rule list widget + self.rule_list = QListWidget() + self.rule_list.setSelectionMode(QAbstractItemView.NoSelection) # no selection + self.setStyleSheet(""" + QListWidget::item:selected { + background-color: #262931; + } + """) + + self.color = QComboBox() + self.color.addItems(["red", "green", "blue", "magenta", "yellow", "orange"]) + self.color.currentIndexChanged.connect(self.update_filter) + + self.item.setBackground(QColor("#262931")) + add_button = QPushButton("Add rule") + add_button.clicked.connect(self.add_rule) + update_button = QPushButton("Apply") + update_button.clicked.connect(self.update_filter) + delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") + self.delete = QPushButton(icon=delete_icon) + self.delete.setFixedSize(40, 40) + self.rule_list = QListWidget() + layout = QHBoxLayout() + layout.setSpacing(1) + + self.rule_list.setFixedHeight(200) + self.setFixedHeight(220) + + settings_layout = QVBoxLayout() + settings_layout.addWidget(self.color) + settings_layout.addWidget(add_button) + settings_layout.addWidget(update_button) + settings_layout.addWidget(self.delete) + + layout.addWidget(self.rule_list) + layout.addLayout(settings_layout) + self.setLayout(layout) + + def sizeHint(self): + hint = super().sizeHint() + hint.setHeight(30) + return hint + + def update_filter(self): + if not self.item.isSelected(): + self.item.setSelected(True) + self.filter_updated.emit() + + def add_rule(self): + """Create a new custom group""" + + item = QListWidgetItem(self.rule_list) + group_row = Rule(self.tracks_viewer.tracks.graph) + self.rule_list.setItemWidget(item, group_row) + item.setSizeHint(group_row.minimumSizeHint()) + item.setBackground(QColor("#262931")) + self.rule_list.addItem(item) + group_row.delete.clicked.connect(partial(self.remove_rule, item)) + + def remove_rule(self, item: QListWidgetItem): + """Remove a collection object from the list. You must pass the list item that + represents the collection, not the collection object itself. + + Args: + item (QListWidgetItem): The list item to remove. This list item + contains the CollectionButton that represents a set of node_ids. + """ + row = self.rule_list.indexFromItem(item).row() + self.rule_list.takeItem(row) + + +class FilterWidget(QGroupBox): + """Widget for construction filters to 'soft' select nodes that meet certain criteria. A new group can be created based on the filtered nodes. Sends a signal when a new filter is selected or when the user clicks the 'apply' button to update the filter criteria.""" + + apply_filter = Signal() + + def __init__(self, tracks_viewer: TracksViewer): + super().__init__(title="Filters") + + self.tracks_viewer = tracks_viewer + + # filter list widget + self.filter_list = QListWidget() + self.filter_list.setSelectionMode(1) # single selection + self.filter_list.itemSelectionChanged.connect(self._selection_changed) + + # edit widget + edit_widget = QWidget() + edit_layout = QHBoxLayout() + self.add_filter_btn = QPushButton("Add new filter") + self.add_filter_btn.clicked.connect(self._add_filter) + self.clear_selection_btn = QPushButton("Deactivate filter") + self.clear_selection_btn.clicked.connect(self.filter_list.clearSelection) + self.create_group_btn = QPushButton("Create group") + self.create_group_btn.clicked.connect(self._create_group) + edit_layout.addWidget(self.clear_selection_btn) + edit_layout.addWidget(self.create_group_btn) + edit_layout.addWidget(self.add_filter_btn) + edit_widget.setLayout(edit_layout) + + # combine widgets + layout = QVBoxLayout() + layout.addWidget(self.filter_list) + layout.addWidget(edit_widget) + self.setLayout(layout) + + self.update_buttons() + + def update_buttons(self) -> None: + """Activate or deactivates the buttons depending on whether a filter is selected or can be created""" + + if self.tracks_viewer.tracks is not None: + self.add_filter_btn.setEnabled(True) + else: + self.add_filter_btn.setEnabled(False) + + selected = self.filter_list.selectedItems() + if selected: + self.clear_selection_btn.setEnabled(True) + else: + self.clear_selection_btn.setEnabled(False) + + if len(self.tracks_viewer.filtered_nodes) > 0 and selected: + self.create_group_btn.setEnabled(True) + else: + self.create_group_btn.setEnabled(False) + + def _selection_changed(self) -> None: + """Check whether a filter is selected, and if so call the function to apply it""" + + selected = self.filter_list.selectedItems() + if selected: + self.selected_filter = self.filter_list.itemWidget(selected[0]) + + color = self.selected_filter.color.currentText() + color = QColor(color) + color.setAlpha(150) + + self.setStyleSheet(f""" + QListWidget::item:selected {{ + background-color: {color.name()}; + }} + """) + + rgb_color = color.getRgb()[:3] + rgb_color = [c / 255 for c in rgb_color] + self.tracks_viewer.filter_color = rgb_color + self._filter_nodes() + else: + self.tracks_viewer.filtered_nodes = {} + self.apply_filter.emit() + + self.update_buttons() + + def _filter_nodes(self) -> None: + """Assign a filtered set of nodes to the tracks_viewer based on the criteria of the selected filter""" + + if self.selected_filter.rule_list.count() == 0: + result_set = set() + + else: + OP_MAP = { + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, + "=": operator.eq, + "\u2260": operator.ne, + } + + result_set = set(self.tracks_viewer.tracks.graph.nodes) + + for i in range(self.selected_filter.rule_list.count()): + item = self.selected_filter.rule_list.item(i) # Get the QListWidgetItem + rule = self.selected_filter.rule_list.itemWidget( + item + ) # Get the Rule widget + + # Extract information from each rule + item_value = rule.item_dropdown.currentText() + sign_value = rule.sign_dropdown.currentText() + logic_value = rule.logic_dropdown.currentText() + + if isinstance(rule.value_widget, QSpinBox | QDoubleSpinBox): + value = rule.value_widget.value() + elif isinstance(rule.value_widget, QLineEdit): + value = rule.value_widget.text() + else: + value = rule.value_widget.currentText() + + # Define a set for nodes that satisfy this rule + current_set = set() + + # Get the correct operator function + compare_op = OP_MAP.get(sign_value) + + # Iterate over nodes and apply the rule condition + for node, attrs in self.tracks_viewer.tracks.graph.nodes(data=True): + node_attr_value = attrs.get(item_value) + + try: + if ( + type(node_attr_value) in (tuple, list) + or node_attr_value is None + ): # not all nodes may have a value for given attribute, e.g. not all nodes may belong to a group, and therefore do not have a 'group' attribute + if node_attr_value is None: + node_attr_value = [] + node_attr_value = [str(v) for v in node_attr_value] + if sign_value == "=": + condition = ( + value in node_attr_value + ) # we consider the condition to be true if the requested value is present in the list (even if other values are also present in the list). + else: + condition = value not in node_attr_value + elif isinstance(value, str): + condition = compare_op(str(node_attr_value), value) + else: + node_attr_value = float(node_attr_value) + value = float(value) + condition = compare_op(node_attr_value, value) + + # If the condition is true, add the node to the current set + if condition: + current_set.add(node) + except ( + ValueError, + TypeError, + ): # If there's a type mismatch or conversion issue, skip this node + continue + + # Apply logic to chain the result sets + if logic_value == "AND": + result_set &= current_set + elif logic_value == "OR": + result_set |= current_set + elif logic_value == "NOT": + result_set -= current_set + elif logic_value == "XOR": + result_set ^= current_set + + self.tracks_viewer.filtered_nodes = result_set + self.apply_filter.emit() + + def _create_group(self) -> None: + """Add a new group collection based on the current filter""" + + name, ok_pressed = QInputDialog.getText( + None, "Enter group name", "Please enter a group name:", text="New Group" + ) + if ok_pressed and name: + self.tracks_viewer.collection_widget.add_group(name, select=True) + self.tracks_viewer.collection_widget.add_nodes( + self.tracks_viewer.filtered_nodes + ) + + def _add_filter(self) -> None: + """Create a new empty filter""" + + item = QListWidgetItem(self.filter_list) + filter_row = Filter(self.tracks_viewer, item) + filter_row.filter_updated.connect(self._selection_changed) + self.filter_list.setItemWidget(item, filter_row) + item.setSizeHint(filter_row.minimumSizeHint()) + self.filter_list.addItem(item) + filter_row.delete.clicked.connect(partial(self._remove_filter, item)) + self.filter_list.setCurrentRow(len(self.filter_list) - 1) + + def _remove_filter(self, item: QListWidgetItem) -> None: + """Remove a filter from the list. You must pass the list item that + represents the filter, not the filter object itself. + + Args: + item (QListWidgetItem): The list item to remove. This list item + contains the CollectionButton that represents a set of node_ids. + """ + row = self.filter_list.indexFromItem(item).row() + self.filter_list.takeItem(row) diff --git a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py index 63001fd..aea74c9 100644 --- a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py +++ b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py @@ -16,6 +16,7 @@ from motile_plugin.utils.relabel_segmentation import relabel_segmentation from .collection_widget import CollectionWidget +from .filter_widget import FilterWidget from .node_selection_list import NodeSelectionList from .tracks_list import TracksList @@ -65,13 +66,22 @@ def __init__( self.selected_nodes = NodeSelectionList() self.selected_nodes.list_updated.connect(self.update_selection) + self.filtered_nodes = {} + self.filter_color = (1, 1, 1, 0) + self.tracks_list = TracksList() self.tracks_list.view_tracks.connect(self.update_tracks) self.collection_widget = CollectionWidget(self) + self.filter_widget = FilterWidget(self) + self.filter_widget.apply_filter.connect(self.apply_filter) + self.set_keybinds() + def apply_filter(self): + self.tracking_layers.points_layer.update_point_outline() + def set_keybinds(self): # TODO: separate and document keybinds (and maybe allow user to choose) self.viewer.bind_key("q")(self.toggle_display_mode) @@ -133,6 +143,10 @@ def update_tracks(self, tracks: Tracks, name: str) -> None: # retrieve existing groups self.collection_widget.retrieve_existing_groups() + # clear filters and update buttons + self.filter_widget.filter_list.clearSelection() + self.filter_widget.update_buttons() + self.set_display_mode("all") self.tracking_layers.set_tracks(tracks, name) self.selected_nodes.reset() From d28238cc699ecbea4a9be557e0a00bf17ca27e9f Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Tue, 12 Nov 2024 15:34:58 +0100 Subject: [PATCH 06/25] clean up --- .../views_coordinator/filter_widget.py | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/src/motile_plugin/data_views/views_coordinator/filter_widget.py b/src/motile_plugin/data_views/views_coordinator/filter_widget.py index 6d31549..040ab4b 100644 --- a/src/motile_plugin/data_views/views_coordinator/filter_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/filter_widget.py @@ -31,30 +31,33 @@ class Rule(QWidget): - """Widget holding the rule""" + """Widget for constructing a condition to filter by""" def __init__(self, graph: nx.diGraph): super().__init__() self.graph = graph + # assign node attributes as items to filter by self.items = set() for _, attrs in self.graph.nodes(data=True): self.items.update(attrs.keys()) - if "pos" in self.items: + if ( + "pos" in self.items + ): # not using position information right now, consider splitting in x, y, (z) coordinates for filtering? self.items.remove("pos") - self.signs = ["<", "<=", ">", ">=", "=", "\u2260"] - - self.logic = ["AND", "OR", "NOT", "XOR"] - self.item_dropdown = QComboBox() self.item_dropdown.addItems(self.items) self.item_dropdown.currentIndexChanged.connect(self._set_sign_value_widget) + # create a dropdown with signs for comparisons + self.signs = ["<", "<=", ">", ">=", "=", "\u2260"] self.sign_dropdown = QComboBox() self.sign_dropdown.addItems(self.signs) + # create a dropdown with different logical operators for combining multiple conditions + self.logic = ["AND", "OR", "NOT", "XOR"] self.logic_dropdown = QComboBox() self.logic_dropdown.addItems(self.logic) @@ -62,24 +65,27 @@ def __init__(self, graph: nx.diGraph): self.value_widget = QWidget() self.value_layout = QVBoxLayout() self.value_widget.setLayout(self.value_layout) + self._set_sign_value_widget() # Initialize value widget + # Create a delete button for removing the rule delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") self.delete = QPushButton(icon=delete_icon) self.delete.setFixedSize(20, 20) + # Combine widgets and assign layout layout = QHBoxLayout() layout.addWidget(self.delete) layout.addWidget(self.item_dropdown) layout.addWidget(self.sign_dropdown) layout.addWidget(self.logic_dropdown) layout.addWidget(self.value_widget) - self.setLayout(layout) - # Initialize value widget - self._set_sign_value_widget() + def _set_sign_value_widget(self) -> None: + """Replaces self.value_widget with a new widget of the appropriate type: combobox for string types, spinboxes for numerical values, QLineEdit for values in tuples or lists. + Also assigns the correct signs to self.sign_dropdown, as >, >=, <=, and < cannot be used for string comparisons. + """ - def _set_sign_value_widget(self): # Remove the existing value widget from the layout self.layout().removeWidget(self.value_widget) self.value_widget.deleteLater() @@ -105,7 +111,7 @@ def _set_sign_value_widget(self): self.value_widget = QComboBox() self.value_widget.addItems(unique_values) - # update the signs, cannot use >. >=, <=, < for string comparison + # update the signs, cannot use >, >=, <=, < for string comparison self.sign_dropdown.clear() signs = ["=", "\u2260"] self.sign_dropdown.addItems(signs) @@ -134,17 +140,9 @@ def _set_sign_value_widget(self): # Add the new value widget to the layout self.layout().addWidget(self.value_widget) - def _update_items(self, graph): - # To update the items list when the graph changes - self.items = set() - for _, attrs in graph.nodes(data=True): - self.items.update(attrs.keys()) - self.item_dropdown.clear() - self.item_dropdown.addItems(self.items) - class Filter(QWidget): - """Widget holding a name and delete icon for listing in the QListWidget. Also contains an initially empty instance of a Collection to which nodes can be assigned""" + """Filter widget containing a single filter, composed of multiple Rule widgets formulating conditions to filter by""" filter_updated = Signal() @@ -157,21 +155,25 @@ def __init__(self, tracks_viewer, item: QListWidgetItem): # rule list widget self.rule_list = QListWidget() self.rule_list.setSelectionMode(QAbstractItemView.NoSelection) # no selection + self.item.setBackground(QColor("#262931")) self.setStyleSheet(""" QListWidget::item:selected { background-color: #262931; } """) + self.rule_list.setFixedHeight(200) + self.setFixedHeight(220) + # available colors self.color = QComboBox() self.color.addItems(["red", "green", "blue", "magenta", "yellow", "orange"]) - self.color.currentIndexChanged.connect(self.update_filter) + self.color.currentIndexChanged.connect(self._update_filter) - self.item.setBackground(QColor("#262931")) + # adding, removing, updating rules add_button = QPushButton("Add rule") - add_button.clicked.connect(self.add_rule) + add_button.clicked.connect(self._add_rule) update_button = QPushButton("Apply") - update_button.clicked.connect(self.update_filter) + update_button.clicked.connect(self._update_filter) delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") self.delete = QPushButton(icon=delete_icon) self.delete.setFixedSize(40, 40) @@ -179,9 +181,7 @@ def __init__(self, tracks_viewer, item: QListWidgetItem): layout = QHBoxLayout() layout.setSpacing(1) - self.rule_list.setFixedHeight(200) - self.setFixedHeight(220) - + # combine settings widgets settings_layout = QVBoxLayout() settings_layout.addWidget(self.color) settings_layout.addWidget(add_button) @@ -192,17 +192,17 @@ def __init__(self, tracks_viewer, item: QListWidgetItem): layout.addLayout(settings_layout) self.setLayout(layout) - def sizeHint(self): + def sizeHint(self) -> None: hint = super().sizeHint() hint.setHeight(30) return hint - def update_filter(self): + def _update_filter(self) -> None: if not self.item.isSelected(): self.item.setSelected(True) self.filter_updated.emit() - def add_rule(self): + def _add_rule(self) -> None: """Create a new custom group""" item = QListWidgetItem(self.rule_list) @@ -211,15 +211,15 @@ def add_rule(self): item.setSizeHint(group_row.minimumSizeHint()) item.setBackground(QColor("#262931")) self.rule_list.addItem(item) - group_row.delete.clicked.connect(partial(self.remove_rule, item)) + group_row.delete.clicked.connect(partial(self._remove_rule, item)) - def remove_rule(self, item: QListWidgetItem): - """Remove a collection object from the list. You must pass the list item that - represents the collection, not the collection object itself. + def _remove_rule(self, item: QListWidgetItem) -> None: + """Remove a rule from the list. You must pass the list item that + represents the rule, not the rule object itself. Args: item (QListWidgetItem): The list item to remove. This list item - contains the CollectionButton that represents a set of node_ids. + contains the Rule that represents a set of node_ids. """ row = self.rule_list.indexFromItem(item).row() self.rule_list.takeItem(row) From 610f87ad9ec48cdac6b3f32c4c428d0ed290e756 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Tue, 12 Nov 2024 15:56:08 +0100 Subject: [PATCH 07/25] add node count to groups, fix error triggered by removing nodes from collection that were not part of it --- .../views_coordinator/collection_widget.py | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py index f238fac..11f3a8d 100644 --- a/src/motile_plugin/data_views/views_coordinator/collection_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -37,11 +37,13 @@ def __init__(self, name: str): self.name.setFixedHeight(20) self.collection = Collection() delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") + self.node_count = QLabel(f"{len(self.collection)} nodes") self.delete = QPushButton(icon=delete_icon) self.delete.setFixedSize(20, 20) layout = QHBoxLayout() layout.setSpacing(1) layout.addWidget(self.name) + layout.addWidget(self.node_count) layout.addWidget(self.delete) self.setLayout(layout) @@ -50,6 +52,9 @@ def sizeHint(self): hint.setHeight(30) return hint + def update_node_count(self): + self.node_count.setText(f"{len(self.collection)} nodes") + class CollectionWidget(QGroupBox): """Widget for holding in-memory Collections (groups). Emits a signal whenever @@ -62,9 +67,11 @@ def __init__(self, tracks_viewer: TracksViewer): super().__init__(title="Collections") self.tracks_viewer = tracks_viewer - self.tracks_viewer.selected_nodes.list_updated.connect(self._update_buttons) + self.tracks_viewer.selected_nodes.list_updated.connect( + self._update_buttons_and_node_count + ) - self.group_changed.connect(self._update_buttons) + self.group_changed.connect(self._update_buttons_and_node_count) self.collection_list = QListWidget() self.collection_list.setSelectionMode(1) # single selection @@ -130,9 +137,9 @@ def __init__(self, tracks_viewer: TracksViewer): layout.addWidget(new_group_box) self.setLayout(layout) - self._update_buttons() + self._update_buttons_and_node_count() - def _update_buttons(self) -> None: + def _update_buttons_and_node_count(self) -> None: """Enable or disable selection and edit buttons depending on whether a group is selected, nodes are selected, and whether the group contains any nodes""" selected = self.collection_list.selectedItems() @@ -153,6 +160,10 @@ def _update_buttons(self) -> None: self.select_btn.setEnabled(False) if selected: + self.collection_list.itemWidget( + selected[0] + ).update_node_count() # update the node count + if len(self.selected_collection.collection) > 0: self.select_btn.setEnabled(True) else: @@ -202,7 +213,7 @@ def retrieve_existing_groups(self) -> None: group_dict[self.selected_collection.name.text()] ) - self._update_buttons() + self._update_buttons_and_node_count() def _selection_changed(self) -> None: """Update the currently selected collection and send update signal""" @@ -212,7 +223,7 @@ def _selection_changed(self) -> None: self.selected_collection = self.collection_list.itemWidget(selected[0]) self.group_changed.emit() - self._update_buttons() + self._update_buttons_and_node_count() def add_nodes(self, nodes: list[Any] | None = None) -> None: """Add individual nodes to the selected collection and send update signal @@ -293,13 +304,22 @@ def _remove_node(self) -> None: """Remove individual nodes from the selected collection and send update signal""" if self.selected_collection is not None: + # remove from the collection self.selected_collection.collection.remove( self.tracks_viewer.selected_nodes ) + + # remove from the node attribute for node_id in self.tracks_viewer.selected_nodes: - self.tracks_viewer.tracks.graph.nodes[node_id]["group"].remove( - self.selected_collection.name.text() - ) + node_attrs = self.tracks_viewer.tracks.graph.nodes[node_id] + group_list = node_attrs.get("group") + if ( + isinstance(group_list, list) + and self.selected_collection.name.text() in group_list + ): + # Remove the value if it exists in the list + group_list.remove(self.selected_collection.name.text()) + self.group_changed.emit() def _remove_track(self) -> None: From ff8cdd7e4a5020cf9a87869ffc028829a3b9e6b0 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Tue, 12 Nov 2024 17:17:56 +0100 Subject: [PATCH 08/25] fix bug, can only set value widget after adding it to the layout --- .../data_views/views_coordinator/filter_widget.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/motile_plugin/data_views/views_coordinator/filter_widget.py b/src/motile_plugin/data_views/views_coordinator/filter_widget.py index 040ab4b..8a8f0ba 100644 --- a/src/motile_plugin/data_views/views_coordinator/filter_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/filter_widget.py @@ -65,7 +65,6 @@ def __init__(self, graph: nx.diGraph): self.value_widget = QWidget() self.value_layout = QVBoxLayout() self.value_widget.setLayout(self.value_layout) - self._set_sign_value_widget() # Initialize value widget # Create a delete button for removing the rule delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") @@ -81,6 +80,9 @@ def __init__(self, graph: nx.diGraph): layout.addWidget(self.value_widget) self.setLayout(layout) + # Initialize value widget + self._set_sign_value_widget() + def _set_sign_value_widget(self) -> None: """Replaces self.value_widget with a new widget of the appropriate type: combobox for string types, spinboxes for numerical values, QLineEdit for values in tuples or lists. Also assigns the correct signs to self.sign_dropdown, as >, >=, <=, and < cannot be used for string comparisons. From 74f6d13866588dd9ed97c01c1d18f3495029afe1 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 13 Nov 2024 09:20:46 +0100 Subject: [PATCH 09/25] fix layout in filter widget --- .../data_views/views_coordinator/filter_widget.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/motile_plugin/data_views/views_coordinator/filter_widget.py b/src/motile_plugin/data_views/views_coordinator/filter_widget.py index 8a8f0ba..f00b5d9 100644 --- a/src/motile_plugin/data_views/views_coordinator/filter_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/filter_widget.py @@ -163,8 +163,6 @@ def __init__(self, tracks_viewer, item: QListWidgetItem): background-color: #262931; } """) - self.rule_list.setFixedHeight(200) - self.setFixedHeight(220) # available colors self.color = QComboBox() @@ -194,6 +192,10 @@ def __init__(self, tracks_viewer, item: QListWidgetItem): layout.addLayout(settings_layout) self.setLayout(layout) + # Set fixed size (keep this at the end of the init!) + self.rule_list.setFixedHeight(200) + self.setFixedHeight(220) + def sizeHint(self) -> None: hint = super().sizeHint() hint.setHeight(30) From b6561a7d709ea34d5e9dc2916b5a9e495443de94 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 13 Nov 2024 09:37:38 +0100 Subject: [PATCH 10/25] enforce updating the node count for all groups when loading back in --- .../data_views/views_coordinator/collection_widget.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py index 11f3a8d..6f41194 100644 --- a/src/motile_plugin/data_views/views_coordinator/collection_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -212,6 +212,7 @@ def retrieve_existing_groups(self) -> None: self.selected_collection.collection.add( group_dict[self.selected_collection.name.text()] ) + self.selected_collection.update_node_count() # enforce updating the node count for all elements self._update_buttons_and_node_count() From a06210a9f00cf9842e931fc036dc10012b8eee98 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Mon, 2 Dec 2024 16:58:55 +0100 Subject: [PATCH 11/25] wip: use histograms for selecting a range of data when filtering. Use track_df (now on tracks_viewer) instead of graph for filtering --- .../data_views/views/tree_view/tree_widget.py | 50 +- .../views_coordinator/filter_widget.py | 448 ++++++++++++------ 2 files changed, 315 insertions(+), 183 deletions(-) diff --git a/src/motile_plugin/data_views/views/tree_view/tree_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_widget.py index 1110a7c..d5d3899 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_widget.py @@ -426,7 +426,6 @@ class TreeWidget(QWidget): def __init__(self, viewer: napari.Viewer): super().__init__() self.viewer = viewer - self.track_df = pd.DataFrame() # all tracks self.lineage_df = pd.DataFrame() # the currently viewed subset of lineages self.group_df = pd.DataFrame() # the currently viewed group self.sync_df = ( @@ -448,6 +447,7 @@ def __init__(self, viewer: napari.Viewer): # Connect to tracks_viewer self.tracks_viewer = TracksViewer.get_instance(viewer) + self.tracks_viewer.track_df = pd.DataFrame() # all tracks self.selected_nodes = self.tracks_viewer.selected_nodes self.selected_nodes.list_updated.connect(self._update_selected) self.tracks_viewer.collection_widget.group_changed.connect( @@ -473,7 +473,7 @@ def __init__(self, viewer: napari.Viewer): # Add navigation widget self.navigation_widget = NavigationWidget( - self.track_df, + self.tracks_viewer.track_df, self.lineage_df, self.view_direction, self.selected_nodes, @@ -533,7 +533,7 @@ def set_sync(self) -> None: if self.mode == "all": self.tree_widget.update( - self.track_df, + self.tracks_viewer.track_df, self.view_direction, self.feature, self.selected_nodes, @@ -571,7 +571,7 @@ def sync_views(self, force_update: bool | None = None) -> None: """Sync the data in the tree plot with the data in the field of view of the napari viewer""" if self.sync_df is None: - prev_visible = self.track_df["node_id"].tolist() + prev_visible = self.tracks_viewer.track_df["node_id"].tolist() else: prev_visible = self.sync_df["node_id"].tolist() @@ -591,11 +591,11 @@ def sync_views(self, force_update: bool | None = None) -> None: _max_y = corner_coordinates[1][y_dim] if self.mode == "all": - visible_nodes = self.track_df[ - (self.track_df["x"] >= _min_x) - & (self.track_df["x"] <= _max_x) - & (self.track_df["y"] >= _min_y) - & (self.track_df["y"] <= _max_y) + visible_nodes = self.tracks_viewer.track_df[ + (self.tracks_viewer.track_df["x"] >= _min_x) + & (self.tracks_viewer.track_df["x"] <= _max_x) + & (self.tracks_viewer.track_df["y"] >= _min_y) + & (self.tracks_viewer.track_df["y"] <= _max_y) ]["node_id"].tolist() elif self.mode == "lineage": visible_nodes = self.lineage_df[ @@ -620,8 +620,8 @@ def sync_views(self, force_update: bool | None = None) -> None: set(visible) != set(prev_visible) or force_update ): # only call update function if the list of visible nodes has changed if self.mode == "all": - self.sync_df = self.track_df[ - self.track_df["node_id"].isin(visible) + self.sync_df = self.tracks_viewer.track_df[ + self.tracks_viewer.track_df["node_id"].isin(visible) ].reset_index() elif self.mode == "lineage": self.sync_df = self.lineage_df[ @@ -759,26 +759,26 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: """ if self.tracks_viewer.tracks is None: - self.track_df = pd.DataFrame() + self.tracks_viewer.track_df = pd.DataFrame() self.graph = None self.sync_widget.sync_button.setEnabled(False) else: if reset_view: - self.track_df = extract_sorted_tracks( + self.tracks_viewer.track_df = extract_sorted_tracks( self.tracks_viewer.tracks, self.tracks_viewer.colormap ) else: - self.track_df = extract_sorted_tracks( + self.tracks_viewer.track_df = extract_sorted_tracks( self.tracks_viewer.tracks, self.tracks_viewer.colormap, - self.track_df, + self.tracks_viewer.track_df, ) self.graph = self.tracks_viewer.tracks.graph self.sync_widget.sync_button.setEnabled(True) # check whether we have area measurements and therefore should activate the area # button - if "area" not in self.track_df.columns: + if "area" not in self.tracks_viewer.track_df.columns: if self.feature_widget.feature == "area": self.feature_widget._toggle_feature_mode() self.feature_widget.show_area_radio.setEnabled(False) @@ -795,7 +795,7 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: self.feature_widget.show_tree_radio.setChecked(True) # also update the navigation widget - self.navigation_widget.track_df = self.track_df + self.navigation_widget.track_df = self.tracks_viewer.track_df self.navigation_widget.lineage_df = self.lineage_df # get the dataframe of currently displayed data, then update plot with or without filtering by the field of view (if sync is on) @@ -806,7 +806,7 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: self._update_group_df() df = self.group_df else: - df = self.track_df + df = self.tracks_viewer.track_df if self.sync: self.sync_views(force_update=True) @@ -837,7 +837,7 @@ def _set_mode(self, mode: str) -> None: self.view_direction = "vertical" else: self.view_direction = "horizontal" - df = self.track_df + df = self.tracks_viewer.track_df elif mode == "group": if self.feature == "tree": self.view_direction = "vertical" @@ -887,7 +887,7 @@ def _set_feature(self, feature: str) -> None: self.navigation_widget.view_direction = self.view_direction if self.mode == "all": - df = self.track_df + df = self.tracks_viewer.track_df if self.mode == "lineage": df = self.lineage_df if self.mode == "group": @@ -922,8 +922,8 @@ def _update_lineage_df(self) -> None: visible = [] for node_id in self.selected_nodes: visible += extract_lineage_tree(self.graph, node_id) - self.lineage_df = self.track_df[ - self.track_df["node_id"].isin(visible) + self.lineage_df = self.tracks_viewer.track_df[ + self.tracks_viewer.track_df["node_id"].isin(visible) ].reset_index() self.lineage_df["x_axis_pos"] = ( self.lineage_df["x_axis_pos"].rank(method="dense").astype(int) - 1 @@ -938,14 +938,14 @@ def _update_group_df(self) -> None: for ( node_id ) in self.tracks_viewer.collection_widget.selected_collection.collection: - if node_id in self.track_df["node_id"].tolist(): + if node_id in self.tracks_viewer.track_df["node_id"].tolist(): visible += extract_lineage_tree(self.graph, node_id) else: self.tracks_viewer.collection_widget.selected_collection.collection._list.remove( node_id ) - self.group_df = self.track_df[ - self.track_df["node_id"].isin(visible) + self.group_df = self.tracks_viewer.track_df[ + self.tracks_viewer.track_df["node_id"].isin(visible) ].reset_index() self.group_df["x_axis_pos"] = ( self.group_df["x_axis_pos"].rank(method="dense").astype(int) - 1 diff --git a/src/motile_plugin/data_views/views_coordinator/filter_widget.py b/src/motile_plugin/data_views/views_coordinator/filter_widget.py index f00b5d9..1e75d27 100644 --- a/src/motile_plugin/data_views/views_coordinator/filter_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/filter_widget.py @@ -1,26 +1,26 @@ from __future__ import annotations -import operator from functools import partial from typing import TYPE_CHECKING -import networkx as nx import numpy as np +import pandas as pd from napari._qt.qt_resources import QColoredSVGIcon from qtpy.QtCore import Signal from qtpy.QtGui import QColor from qtpy.QtWidgets import ( QAbstractItemView, + QCheckBox, QComboBox, QDoubleSpinBox, QGroupBox, QHBoxLayout, QInputDialog, QLabel, - QLineEdit, QListWidget, QListWidgetItem, QPushButton, + QScrollArea, QSpinBox, QVBoxLayout, QWidget, @@ -29,42 +29,224 @@ if TYPE_CHECKING: from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer +import pyqtgraph as pg +from qtpy.QtCore import Qt + + +class HistogramRangeSliderWidget(QWidget): + update_rule = Signal() + + def __init__(self, df: pd.DataFrame, fill_color: QColor): + super().__init__() + + self.data = df.dropna() + self.fill_color = fill_color + self.data_min = self.data.min() + self.data_max = self.data.max() + self.dtype = self.data.dtypes + + # Create a PlotWidget for the histogram + self.plot_widget = pg.PlotWidget() + self.plot_item = self.plot_widget.plotItem + + # Customize the axes + self.plot_item.getAxis("bottom").setTicks([]) + self.plot_item.getAxis("left").setTicks([]) + + # Set ViewBox limits to restrict panning and zooming + self.plot_item.getViewBox().setMouseEnabled(x=False, y=False) + self.plot_item.getViewBox().setLimits( + xMin=self.data_min, + xMax=self.data_max, # Restrict horizontal panning + yMin=0, # Prevent panning below 0 (optional, for clarity) + ) + + # Create the histogram data + hist, edges = np.histogram(self.data, bins="rice") + bar_graph = pg.BarGraphItem( + x=edges[:-1], height=hist, width=np.diff(edges), brush="white" + ) + self.plot_item.addItem(bar_graph) + + # Add a movable region (range slider) + slider_range = self.data_max - self.data_min + if self.dtype == int: + slider_range = int(slider_range) + self.region = pg.LinearRegionItem( + values=[ + self.data_min + 0.25 * slider_range, + self.data_max - 0.25 * slider_range, + ] + ) + self.region.setZValue(10) # Ensure it is on top of the bars + self.region.sigRegionChangeFinished.connect(self.update_range_label) + + # Customize the region's appearance + self.fill_color.setAlpha(100) + self.region.setBrush(pg.mkBrush(self.fill_color)) + self.fill_color.setAlpha(120) + self.region.setHoverBrush(pg.mkBrush(self.fill_color)) + self.plot_item.addItem(self.region) + + # Add QSpinBoxes to view and adjust the region range + if self.dtype == int: + self.min_spinbox = QSpinBox() + self.max_spinbox = QSpinBox() + self.min_spinbox.setValue(self.data_min + int(0.25 * slider_range)) + self.max_spinbox.setValue(self.data_max - int(0.25 * slider_range)) + + else: + self.min_spinbox = QDoubleSpinBox() + self.max_spinbox = QDoubleSpinBox() + self.min_spinbox.setValue(self.data_min + 0.25 * slider_range) + self.max_spinbox.setValue(self.data_max - 0.25 * slider_range) + + self.min_spinbox.setMinimum(self.data_min) + self.min_spinbox.setMaximum(self.data_max) + self.max_spinbox.setMinimum(self.data_min) + self.max_spinbox.setMaximum(self.data_max) + + self.min_spinbox.valueChanged.connect(self.update_min_value) + self.max_spinbox.valueChanged.connect(self.update_max_value) + + min_spin_box_layout = QHBoxLayout() + min_spin_box_layout.addWidget(QLabel("Min value")) + min_spin_box_layout.addWidget(self.min_spinbox) + max_spin_box_layout = QHBoxLayout() + max_spin_box_layout.addWidget(QLabel("Max value")) + max_spin_box_layout.addWidget(self.max_spinbox) + spinbox_layout = QVBoxLayout() + spinbox_layout.addLayout(min_spin_box_layout) + spinbox_layout.addLayout(max_spin_box_layout) + + # combine layouts + layout = QVBoxLayout() + layout.addLayout(spinbox_layout) + layout.addWidget(self.plot_widget) + layout.setContentsMargins(1, 1, 1, 2) + layout.setSpacing(0) + self.setLayout(layout) + + def update_range_label(self): + """Update the displayed range based on the region slider.""" + region = self.region.getRegion() + + if self.dtype == int: + self.min_spinbox.setValue(int(region[0])) + self.max_spinbox.setValue(int(region[1])) + else: + self.min_spinbox.setValue(region[0]) + self.max_spinbox.setValue(region[1]) + + def update_min_value(self): + """Update the region by the value in the min_spinbox""" + + min_value = self.min_spinbox.value() + max_value = self.region.getRegion()[1] + if self.dtype == int: + min_value = int(min_value) + max_value = int(max_value) + + self.region.setRegion([min_value, max_value]) + self.update_rule.emit() + + def update_max_value(self): + """Update the region by the value in the max_spinbox""" + + min_value = self.region.getRegion()[0] + max_value = self.max_spinbox.value() + + if self.dtype == int: + min_value = int(min_value) + max_value = int(max_value) + + self.region.setRegion([min_value, max_value]) + self.update_rule.emit() + + def update_color(self, color: QColor): + """Update the color of the region""" + + self.fill_color = color + self.fill_color.setAlpha(100) + self.region.setBrush(pg.mkBrush(self.fill_color)) + self.fill_color.setAlpha(120) + self.region.setHoverBrush(pg.mkBrush(self.fill_color)) + + +class MultipleChoiceWidget(QWidget): + update_rule = Signal() + + def __init__(self, dataframe, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Get the unique values from the dataframe column + self.unique_values = dataframe.dropna().unique() + + # Create a scroll area for the checkboxes + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_widget = QWidget() + scroll_area.setWidget(scroll_widget) + scroll_layout = QVBoxLayout(scroll_widget) + scroll_layout.setContentsMargins(0, 0, 0, 0) + scroll_layout.setSpacing(0) + + # Create a checkbox for each unique value + self.checkboxes = [] + for value in self.unique_values: + checkbox = QCheckBox(str(value)) + checkbox.stateChanged.connect(self._update) + self.checkboxes.append(checkbox) + scroll_layout.addWidget(checkbox) + + # Add the scroll area to the layout + layout = QVBoxLayout() + layout.addWidget(scroll_area) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + self.setLayout(layout) + + def _update(self): + """Emit update signal""" + + self.update_rule.emit() + + def get_selected_choices(self): + """Return the list of selected choices.""" + + selected = [ + checkbox.text() for checkbox in self.checkboxes if checkbox.isChecked() + ] + return selected + class Rule(QWidget): """Widget for constructing a condition to filter by""" - def __init__(self, graph: nx.diGraph): + update = Signal() + + def __init__(self, df: pd.DataFrame, fill_color: tuple[int, int, int, int]): super().__init__() - self.graph = graph + self.data = df + self.fill_color = fill_color # assign node attributes as items to filter by - self.items = set() - for _, attrs in self.graph.nodes(data=True): - self.items.update(attrs.keys()) - if ( - "pos" in self.items - ): # not using position information right now, consider splitting in x, y, (z) coordinates for filtering? - self.items.remove("pos") - + self.items = self.data.columns self.item_dropdown = QComboBox() self.item_dropdown.addItems(self.items) - self.item_dropdown.currentIndexChanged.connect(self._set_sign_value_widget) - - # create a dropdown with signs for comparisons - self.signs = ["<", "<=", ">", ">=", "=", "\u2260"] - self.sign_dropdown = QComboBox() - self.sign_dropdown.addItems(self.signs) + self.item_dropdown.currentIndexChanged.connect(self._set_value_widget) # create a dropdown with different logical operators for combining multiple conditions self.logic = ["AND", "OR", "NOT", "XOR"] self.logic_dropdown = QComboBox() self.logic_dropdown.addItems(self.logic) + self.logic_dropdown.currentIndexChanged.connect(self._update) # Placeholder for the dynamic value widget self.value_widget = QWidget() self.value_layout = QVBoxLayout() - self.value_widget.setLayout(self.value_layout) + self.value_layout.addWidget(self.value_widget) # Create a delete button for removing the rule delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") @@ -73,73 +255,46 @@ def __init__(self, graph: nx.diGraph): # Combine widgets and assign layout layout = QHBoxLayout() - layout.addWidget(self.delete) - layout.addWidget(self.item_dropdown) - layout.addWidget(self.sign_dropdown) - layout.addWidget(self.logic_dropdown) - layout.addWidget(self.value_widget) + menu_widget_layout = QVBoxLayout() + menu_widget_layout.addWidget(self.item_dropdown) + menu_widget_layout.addWidget(self.logic_dropdown) + menu_widget_layout.addWidget(self.delete) + layout.addLayout(menu_widget_layout) + layout.addLayout(self.value_layout) + layout.setContentsMargins(0, 1, 0, 0) + layout.setSpacing(1) + self.setLayout(layout) - # Initialize value widget - self._set_sign_value_widget() + def update_color(self, color: QColor): + """Update the color on the value_widget""" - def _set_sign_value_widget(self) -> None: - """Replaces self.value_widget with a new widget of the appropriate type: combobox for string types, spinboxes for numerical values, QLineEdit for values in tuples or lists. - Also assigns the correct signs to self.sign_dropdown, as >, >=, <=, and < cannot be used for string comparisons. - """ + self.fill_color = color + if isinstance(self.value_widget, HistogramRangeSliderWidget): + self.value_widget.update_color(color) + + def _update(self): + self.update.emit() + + def _set_value_widget(self): + """Replaces self.value_widget with a new widget of the appropriate type: multiplechoice widget for categorical values, and a histogram for numerical values.""" - # Remove the existing value widget from the layout self.layout().removeWidget(self.value_widget) self.value_widget.deleteLater() - # Get the selected item (attribute name) - selected_item = self.item_dropdown.currentText() - - # Determine the attribute type by checking the nodes - attr_type = None - for _, attrs in self.graph.nodes(data=True): - if selected_item in attrs: - attr_type = type(attrs[selected_item]) - break - - # Set up the value widget based on the attribute type - if attr_type == str: - # Use a dropdown for string attributes - unique_values = { - attrs[selected_item] - for _, attrs in self.graph.nodes(data=True) - if selected_item in attrs and isinstance(attrs[selected_item], str) - } - self.value_widget = QComboBox() - self.value_widget.addItems(unique_values) - - # update the signs, cannot use >, >=, <=, < for string comparison - self.sign_dropdown.clear() - signs = ["=", "\u2260"] - self.sign_dropdown.addItems(signs) - - elif attr_type in (int, float, np.float64): - # Use a spin box for numeric attributes (int or float) - self.value_widget = QSpinBox() if attr_type == int else QDoubleSpinBox() - self.value_widget.setMinimum(0) - self.value_widget.setMaximum(100000000) - - # all signs should be allowed - self.sign_dropdown.clear() - self.sign_dropdown.addItems(self.signs) - - elif attr_type in (list, tuple): - self.value_widget = QLineEdit("Type your value here") - # update the signs, cannot use >. >=, <=, < for string comparison - self.sign_dropdown.clear() - signs = ["=", "\u2260"] - self.sign_dropdown.addItems(signs) + df = self.data[self.item_dropdown.currentText()] + if df.dtype == int or df.dtype == float: + self.value_widget = HistogramRangeSliderWidget(df, self.fill_color) + self.value_layout = QVBoxLayout() + self.value_widget.setLayout(self.value_layout) else: - # Fallback if attribute type is not a string of number - self.value_widget = QLabel("No valid attribute type") + self.value_widget = MultipleChoiceWidget(df) + self.value_layout = QVBoxLayout() + self.value_widget.setLayout(self.value_layout) # Add the new value widget to the layout + self.value_widget.update_rule.connect(self._update) self.layout().addWidget(self.value_widget) @@ -167,29 +322,33 @@ def __init__(self, tracks_viewer, item: QListWidgetItem): # available colors self.color = QComboBox() self.color.addItems(["red", "green", "blue", "magenta", "yellow", "orange"]) + self.color.setFixedSize(100, 20) self.color.currentIndexChanged.connect(self._update_filter) + self.current_color = "red" # adding, removing, updating rules - add_button = QPushButton("Add rule") + add_button = QPushButton("+") + add_button.setFixedSize(20, 20) add_button.clicked.connect(self._add_rule) - update_button = QPushButton("Apply") - update_button.clicked.connect(self._update_filter) delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") self.delete = QPushButton(icon=delete_icon) - self.delete.setFixedSize(40, 40) + self.delete.setFixedSize(20, 20) self.rule_list = QListWidget() - layout = QHBoxLayout() + layout = QVBoxLayout() layout.setSpacing(1) # combine settings widgets - settings_layout = QVBoxLayout() - settings_layout.addWidget(self.color) + settings_layout = QHBoxLayout() settings_layout.addWidget(add_button) - settings_layout.addWidget(update_button) + settings_layout.addWidget(self.color) settings_layout.addWidget(self.delete) + settings_layout.setAlignment(Qt.AlignLeft) + settings_layout.setSpacing(1) + settings_layout.setContentsMargins(1, 0, 1, 0) - layout.addWidget(self.rule_list) layout.addLayout(settings_layout) + layout.addWidget(self.rule_list) + self.setLayout(layout) # Set fixed size (keep this at the end of the init!) @@ -202,20 +361,33 @@ def sizeHint(self) -> None: return hint def _update_filter(self) -> None: + """Send a signal to apply the filter""" + if not self.item.isSelected(): self.item.setSelected(True) + + # update the colors if necessary + if self.color.currentText() != self.current_color: + for i in range(self.rule_list.count()): + item = self.rule_list.item(i) # Get the QListWidgetItem + self.rule_list.itemWidget(item).update_color( + QColor(self.color.currentText()) + ) # Get the Rule widget + self.current_color = self.color.currentText() + self.filter_updated.emit() def _add_rule(self) -> None: """Create a new custom group""" item = QListWidgetItem(self.rule_list) - group_row = Rule(self.tracks_viewer.tracks.graph) + group_row = Rule(self.tracks_viewer.track_df, QColor(self.color.currentText())) self.rule_list.setItemWidget(item, group_row) item.setSizeHint(group_row.minimumSizeHint()) item.setBackground(QColor("#262931")) self.rule_list.addItem(item) group_row.delete.clicked.connect(partial(self._remove_rule, item)) + group_row.update.connect(self._update_filter) def _remove_rule(self, item: QListWidgetItem) -> None: """Remove a rule from the list. You must pass the list item that @@ -227,6 +399,7 @@ def _remove_rule(self, item: QListWidgetItem) -> None: """ row = self.rule_list.indexFromItem(item).row() self.rule_list.takeItem(row) + self._update_filter() class FilterWidget(QGroupBox): @@ -316,19 +489,22 @@ def _filter_nodes(self) -> None: """Assign a filtered set of nodes to the tracks_viewer based on the criteria of the selected filter""" if self.selected_filter.rule_list.count() == 0: - result_set = set() + self.tracks_viewer.filtered_nodes = [] else: - OP_MAP = { - "<": operator.lt, - "<=": operator.le, - ">": operator.gt, - ">=": operator.ge, - "=": operator.eq, - "\u2260": operator.ne, - } - - result_set = set(self.tracks_viewer.tracks.graph.nodes) + mask = pd.Series(True, index=self.tracks_viewer.track_df.index) + + def apply_logic(existing_mask, new_condition, logic): + if logic == "AND": + return existing_mask & new_condition + elif logic == "OR": + return existing_mask | new_condition + elif logic == "NOT": + return existing_mask & ~new_condition + elif logic == "XOR": + return existing_mask ^ new_condition + else: + raise ValueError(f"Unknown logic operator: {logic}") for i in range(self.selected_filter.rule_list.count()): item = self.selected_filter.rule_list.item(i) # Get the QListWidgetItem @@ -336,69 +512,25 @@ def _filter_nodes(self) -> None: item ) # Get the Rule widget - # Extract information from each rule - item_value = rule.item_dropdown.currentText() - sign_value = rule.sign_dropdown.currentText() - logic_value = rule.logic_dropdown.currentText() + column_name = rule.item_dropdown.currentText() + logic = rule.logic_dropdown.currentText() + column_data = self.tracks_viewer.track_df[column_name] - if isinstance(rule.value_widget, QSpinBox | QDoubleSpinBox): - value = rule.value_widget.value() - elif isinstance(rule.value_widget, QLineEdit): - value = rule.value_widget.text() + if column_data.dtype == int or column_data.dtype == float: + # Filtering based on numerical values + min_val, max_val = rule.value_widget.region.getRegion() + new_condition = column_data.between(min_val, max_val) else: - value = rule.value_widget.currentText() - - # Define a set for nodes that satisfy this rule - current_set = set() - - # Get the correct operator function - compare_op = OP_MAP.get(sign_value) - - # Iterate over nodes and apply the rule condition - for node, attrs in self.tracks_viewer.tracks.graph.nodes(data=True): - node_attr_value = attrs.get(item_value) - - try: - if ( - type(node_attr_value) in (tuple, list) - or node_attr_value is None - ): # not all nodes may have a value for given attribute, e.g. not all nodes may belong to a group, and therefore do not have a 'group' attribute - if node_attr_value is None: - node_attr_value = [] - node_attr_value = [str(v) for v in node_attr_value] - if sign_value == "=": - condition = ( - value in node_attr_value - ) # we consider the condition to be true if the requested value is present in the list (even if other values are also present in the list). - else: - condition = value not in node_attr_value - elif isinstance(value, str): - condition = compare_op(str(node_attr_value), value) - else: - node_attr_value = float(node_attr_value) - value = float(value) - condition = compare_op(node_attr_value, value) - - # If the condition is true, add the node to the current set - if condition: - current_set.add(node) - except ( - ValueError, - TypeError, - ): # If there's a type mismatch or conversion issue, skip this node - continue - - # Apply logic to chain the result sets - if logic_value == "AND": - result_set &= current_set - elif logic_value == "OR": - result_set |= current_set - elif logic_value == "NOT": - result_set -= current_set - elif logic_value == "XOR": - result_set ^= current_set - - self.tracks_viewer.filtered_nodes = result_set + # Categorical filtering based on selected choices + selected_choices = rule.value_widget.get_selected_choices() + new_condition = column_data.isin(selected_choices) + + # Apply the condition to the mask with the specified logic + mask = apply_logic(mask, new_condition, logic) + + self.tracks_viewer.filtered_nodes = self.tracks_viewer.track_df[mask][ + "node_id" + ].tolist() self.apply_filter.emit() def _create_group(self) -> None: From 1370c636646bf25ed424606495800a8e513571c8 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Tue, 3 Dec 2024 17:27:28 +0100 Subject: [PATCH 12/25] only fire event when editing of the spinboxes is finished (not while typing) or when the user adjusted the region with the mouse --- .../data_views/views_coordinator/filter_widget.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/motile_plugin/data_views/views_coordinator/filter_widget.py b/src/motile_plugin/data_views/views_coordinator/filter_widget.py index 1e75d27..fd19d00 100644 --- a/src/motile_plugin/data_views/views_coordinator/filter_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/filter_widget.py @@ -106,8 +106,8 @@ def __init__(self, df: pd.DataFrame, fill_color: QColor): self.max_spinbox.setMinimum(self.data_min) self.max_spinbox.setMaximum(self.data_max) - self.min_spinbox.valueChanged.connect(self.update_min_value) - self.max_spinbox.valueChanged.connect(self.update_max_value) + self.min_spinbox.editingFinished.connect(self.update_min_value) + self.max_spinbox.editingFinished.connect(self.update_max_value) min_spin_box_layout = QHBoxLayout() min_spin_box_layout.addWidget(QLabel("Min value")) @@ -138,6 +138,8 @@ def update_range_label(self): self.min_spinbox.setValue(region[0]) self.max_spinbox.setValue(region[1]) + self.update_rule.emit() + def update_min_value(self): """Update the region by the value in the min_spinbox""" From 75f3aeee6ff01f6eaa48db1602f873112228a675 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Mon, 9 Dec 2024 15:53:08 +0100 Subject: [PATCH 13/25] display nodes that are not selected as contours, and nodes that are part of a group or that are selected using a fill. Since the colormap is still based on trackIDs, this does not yet distinguish between individual nodes and tracks: all nodes of the same track are shown as filled if one of them is selected or part of the active group --- .../data_views/views/layers/track_labels.py | 138 ++++++++++++++++-- .../views_coordinator/tracks_viewer.py | 19 +-- 2 files changed, 130 insertions(+), 27 deletions(-) diff --git a/src/motile_plugin/data_views/views/layers/track_labels.py b/src/motile_plugin/data_views/views/layers/track_labels.py index e64ff37..800b029 100644 --- a/src/motile_plugin/data_views/views/layers/track_labels.py +++ b/src/motile_plugin/data_views/views/layers/track_labels.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import napari import numpy as np @@ -9,6 +9,57 @@ if TYPE_CHECKING: from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer +import functools + +from napari.layers.labels._labels_utils import ( + expand_slice, +) +from scipy import ndimage as ndi + + +def get_contours( + labels: np.ndarray, + thickness: int, + background_label: int, + group_labels: list[int] | None = None, +): + """Computes the contours of a 2D label image. + + Parameters + ---------- + labels : array of integers + An input labels image. + thickness : int + It controls the thickness of the inner boundaries. The outside thickness is always 1. + The final thickness of the contours will be `thickness + 1`. + background_label : int + That label is used to fill everything outside the boundaries. + + Returns + ------- + A new label image in which only the boundaries of the input image are kept. + """ + struct_elem = ndi.generate_binary_structure(labels.ndim, 1) + + thick_struct_elem = ndi.iterate_structure(struct_elem, thickness).astype(bool) + + dilated_labels = ndi.grey_dilation(labels, footprint=struct_elem) + eroded_labels = ndi.grey_erosion(labels, footprint=thick_struct_elem) + not_boundaries = dilated_labels == eroded_labels + + contours = labels.copy() + contours[not_boundaries] = background_label + + # instead of filling with background label, fill the group label with their normal color + if group_labels is not None and len(group_labels) > 0: + group_mask = functools.reduce( + np.logical_or, (labels == val for val in group_labels) + ) + combined_mask = not_boundaries & group_mask + contours = np.where(combined_mask, labels, contours) + + return contours + class TrackLabels(napari.layers.Labels): """Extended labels layer that holds the track information and emits @@ -52,6 +103,10 @@ def __init__( ) self.viewer = viewer + self.viewer.dims.events.ndisplay.connect( + lambda: self.update_label_colormap(visible=None) + ) + self.group_labels = None # Key bindings (should be specified both on the viewer (in tracks_viewer) # and on the layer to overwrite napari defaults) @@ -223,10 +278,13 @@ def _refresh(self): self.refresh() - def update_label_colormap(self, visible: list[int] | str) -> None: + def update_label_colormap(self, visible: list[int] | str | None = None) -> None: """Updates the opacity of the label colormap to highlight the selected label and optionally hide cells not belonging to the current lineage""" + if visible is None: + visible = self.group_labels if self.group_labels is not None else "all" + highlighted = [ self.tracks_viewer.tracks.get_track_id(node) for node in self.tracks_viewer.selected_nodes @@ -240,23 +298,31 @@ def update_label_colormap(self, visible: list[int] | str) -> None: ] # set the first track_id to be the selected label color # update the opacity of the cyclic label colormap values according to whether nodes are visible/invisible/highlighted + self.colormap.color_dict = { + key: np.array( + [*value[:-1], 0.6 if key is not None and key != 0 else value[-1]], + dtype=np.float32, + ) + for key, value in self.colormap.color_dict.items() + } + if visible == "all": - self.colormap.color_dict = { - key: np.array( - [*value[:-1], 0.6 if key is not None and key != 0 else value[-1]], - dtype=np.float32, - ) - for key, value in self.colormap.color_dict.items() - } + self.contour = 0 + self.group_labels = None else: - self.colormap.color_dict = { - key: np.array([*value[:-1], 0], dtype=np.float32) - for key, value in self.colormap.color_dict.items() - } - for label in visible: - # find the index in the cyclic label colormap - self.colormap.color_dict[label][-1] = 0.6 + if self.viewer.dims.ndisplay == 2: + self.contour = 1 + self.group_labels = visible # for now we can not distinguish between individual nodes and tracks, but if we switch to unique labels this will be possible. + + else: + self.colormap.color_dict = { + key: np.array([*value[:-1], 0], dtype=np.float32) + for key, value in self.colormap.color_dict.items() + } + for label in visible: + # find the index in the cyclic label colormap + self.colormap.color_dict[label][-1] = 0.6 for label in highlighted: self.colormap.color_dict[label][-1] = 1 # full opacity @@ -264,6 +330,7 @@ def update_label_colormap(self, visible: list[int] | str) -> None: self.colormap = DirectLabelColormap( color_dict=self.colormap.color_dict ) # create a new colormap from the updated colors (otherwise it does not refresh) + self.refresh() def new_colormap(self): """Extended version of existing function, to emit refresh signal to also update colors in other layers/widgets""" @@ -286,3 +353,42 @@ def _check_selected_label(self): self.events.selected_label.connect( self._check_selected_label ) # connect again + + def _calculate_contour( + self, labels: np.ndarray, data_slice: tuple[slice, ...] + ) -> Optional[np.ndarray]: + """Calculate the contour of a given label array within the specified data slice. + + Parameters + ---------- + labels : np.ndarray + The label array. + data_slice : Tuple[slice, ...] + The slice of the label array on which to calculate the contour. + + Returns + ------- + Optional[np.ndarray] + The calculated contour as a boolean mask array. + Returns None if the contour parameter is less than 1, + or if the label array has more than 2 dimensions. + """ + if self.contour < 1: + return None + if labels.ndim > 2: + return None + + expanded_slice = expand_slice(data_slice, labels.shape, 1) + sliced_labels = get_contours( + labels[expanded_slice], + self.contour, + self.colormap.background_value, + self.group_labels, + ) + + # Remove the latest one-pixel border from the result + delta_slice = tuple( + slice(s1.start - s2.start, s1.stop - s2.start) + for s1, s2 in zip(data_slice, expanded_slice, strict=False) + ) + return sliced_labels[delta_slice] diff --git a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py index aea74c9..54a9b1d 100644 --- a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py +++ b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py @@ -211,17 +211,14 @@ def filter_visible_nodes(self) -> list[int]: elif self.mode == "group": self.group_visible = [] if self.collection_widget.selected_collection is not None: - for node_id in self.collection_widget.selected_collection.collection: - self.group_visible += extract_lineage_tree( - self.tracks.graph, node_id - ) - - return list( - { - self.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value] - for node in self.group_visible - } - ) + return list( + { + self.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value] + for node in self.collection_widget.selected_collection.collection + } + ) + else: + return [] else: return "all" From 18ecd56c6269c869dede4c6c89f1a69ab9bc7f14 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Tue, 10 Dec 2024 08:46:10 +0100 Subject: [PATCH 14/25] connect the collection widget group update event to calling the label colormap updating --- src/motile_plugin/data_views/views_coordinator/tracks_viewer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py index 54a9b1d..b1ff08b 100644 --- a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py +++ b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py @@ -73,6 +73,7 @@ def __init__( self.tracks_list.view_tracks.connect(self.update_tracks) self.collection_widget = CollectionWidget(self) + self.collection_widget.group_changed.connect(self.update_selection) self.filter_widget = FilterWidget(self) self.filter_widget.apply_filter.connect(self.apply_filter) From 0e90a81981ee538c341528e3b649f98b0bec2d92 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Tue, 10 Dec 2024 14:39:06 +0100 Subject: [PATCH 15/25] bug fix from merge; --- .../data_views/views/tree_view/tree_widget.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/motile_plugin/data_views/views/tree_view/tree_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_widget.py index cf48ae0..6334d7d 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_widget.py @@ -853,16 +853,6 @@ def _update_track_data(self, reset_view: bool | None = None) -> None: df = self.group_df else: df = self.tracks_viewer.track_df - self.tree_widget.update( - self.lineage_df, - self.view_direction, - self.feature, - self.selected_nodes, - self.tracks_viewer.filtered_nodes, - self.tracks_viewer.filter_color, - reset_view=reset_view, - allow_flip=allow_flip, - ) if self.sync: self.sync_views(force_update=True) From c5b31878440fa26cab5038996d1bda992781a6a2 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 11 Dec 2024 10:16:26 +0100 Subject: [PATCH 16/25] make a separate label class that can show both contours and filled shapes --- .../data_views/views/layers/contour_labels.py | 123 ++++++++++++++++++ .../data_views/views/layers/track_labels.py | 95 +------------- 2 files changed, 127 insertions(+), 91 deletions(-) create mode 100644 src/motile_plugin/data_views/views/layers/contour_labels.py diff --git a/src/motile_plugin/data_views/views/layers/contour_labels.py b/src/motile_plugin/data_views/views/layers/contour_labels.py new file mode 100644 index 0000000..92a9924 --- /dev/null +++ b/src/motile_plugin/data_views/views/layers/contour_labels.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import functools +from typing import Optional + +import napari +import numpy as np +from napari.layers.labels._labels_utils import ( + expand_slice, +) +from napari.utils import DirectLabelColormap +from scipy import ndimage as ndi + + +def get_contours( + labels: np.ndarray, + thickness: int, + background_label: int, + group_labels: list[int] | None = None, +): + """Computes the contours of a 2D label image. + + Parameters + ---------- + labels : array of integers + An input labels image. + thickness : int + It controls the thickness of the inner boundaries. The outside thickness is always 1. + The final thickness of the contours will be `thickness + 1`. + background_label : int + That label is used to fill everything outside the boundaries. + + Returns + ------- + A new label image in which only the boundaries of the input image are kept. + """ + struct_elem = ndi.generate_binary_structure(labels.ndim, 1) + + thick_struct_elem = ndi.iterate_structure(struct_elem, thickness).astype(bool) + + dilated_labels = ndi.grey_dilation(labels, footprint=struct_elem) + eroded_labels = ndi.grey_erosion(labels, footprint=thick_struct_elem) + not_boundaries = dilated_labels == eroded_labels + + contours = labels.copy() + contours[not_boundaries] = background_label + + # instead of filling with background label, fill the group label with their normal color + if group_labels is not None and len(group_labels) > 0: + group_mask = functools.reduce( + np.logical_or, (labels == val for val in group_labels) + ) + combined_mask = not_boundaries & group_mask + contours = np.where(combined_mask, labels, contours) + + return contours + + +class ContourLabels(napari.layers.Labels): + """Extended labels layer that allows to show contours and filled labels simultaneously""" + + @property + def _type_string(self) -> str: + return "labels" # to make sure that the layer is treated as labels layer for saving + + def __init__( + self, + viewer: napari.Viewer, + data: np.array, + name: str, + opacity: float, + scale: tuple, + colormap: DirectLabelColormap, + ): + super().__init__( + data=data, + name=name, + opacity=opacity, + scale=scale, + colormap=colormap, + ) + + self.viewer = viewer + self.group_labels = None + + def _calculate_contour( + self, labels: np.ndarray, data_slice: tuple[slice, ...] + ) -> Optional[np.ndarray]: + """Calculate the contour of a given label array within the specified data slice. + + Parameters + ---------- + labels : np.ndarray + The label array. + data_slice : Tuple[slice, ...] + The slice of the label array on which to calculate the contour. + + Returns + ------- + Optional[np.ndarray] + The calculated contour as a boolean mask array. + Returns None if the contour parameter is less than 1, + or if the label array has more than 2 dimensions. + """ + if self.contour < 1: + return None + if labels.ndim > 2: + return None + + expanded_slice = expand_slice(data_slice, labels.shape, 1) + sliced_labels = get_contours( + labels[expanded_slice], + self.contour, + self.colormap.background_value, + self.group_labels, + ) + + # Remove the latest one-pixel border from the result + delta_slice = tuple( + slice(s1.start - s2.start, s1.stop - s2.start) + for s1, s2 in zip(data_slice, expanded_slice, strict=False) + ) + return sliced_labels[delta_slice] diff --git a/src/motile_plugin/data_views/views/layers/track_labels.py b/src/motile_plugin/data_views/views/layers/track_labels.py index 800b029..ec0625c 100644 --- a/src/motile_plugin/data_views/views/layers/track_labels.py +++ b/src/motile_plugin/data_views/views/layers/track_labels.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import napari import numpy as np @@ -9,59 +9,10 @@ if TYPE_CHECKING: from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer -import functools - -from napari.layers.labels._labels_utils import ( - expand_slice, -) -from scipy import ndimage as ndi - - -def get_contours( - labels: np.ndarray, - thickness: int, - background_label: int, - group_labels: list[int] | None = None, -): - """Computes the contours of a 2D label image. - - Parameters - ---------- - labels : array of integers - An input labels image. - thickness : int - It controls the thickness of the inner boundaries. The outside thickness is always 1. - The final thickness of the contours will be `thickness + 1`. - background_label : int - That label is used to fill everything outside the boundaries. - - Returns - ------- - A new label image in which only the boundaries of the input image are kept. - """ - struct_elem = ndi.generate_binary_structure(labels.ndim, 1) - - thick_struct_elem = ndi.iterate_structure(struct_elem, thickness).astype(bool) - - dilated_labels = ndi.grey_dilation(labels, footprint=struct_elem) - eroded_labels = ndi.grey_erosion(labels, footprint=thick_struct_elem) - not_boundaries = dilated_labels == eroded_labels - - contours = labels.copy() - contours[not_boundaries] = background_label - - # instead of filling with background label, fill the group label with their normal color - if group_labels is not None and len(group_labels) > 0: - group_mask = functools.reduce( - np.logical_or, (labels == val for val in group_labels) - ) - combined_mask = not_boundaries & group_mask - contours = np.where(combined_mask, labels, contours) - - return contours +from .contour_labels import ContourLabels -class TrackLabels(napari.layers.Labels): +class TrackLabels(ContourLabels): """Extended labels layer that holds the track information and emits and responds to dynamics visualization signals""" @@ -95,6 +46,7 @@ def __init__( ) super().__init__( + viewer=viewer, data=data, name=name, opacity=opacity, @@ -353,42 +305,3 @@ def _check_selected_label(self): self.events.selected_label.connect( self._check_selected_label ) # connect again - - def _calculate_contour( - self, labels: np.ndarray, data_slice: tuple[slice, ...] - ) -> Optional[np.ndarray]: - """Calculate the contour of a given label array within the specified data slice. - - Parameters - ---------- - labels : np.ndarray - The label array. - data_slice : Tuple[slice, ...] - The slice of the label array on which to calculate the contour. - - Returns - ------- - Optional[np.ndarray] - The calculated contour as a boolean mask array. - Returns None if the contour parameter is less than 1, - or if the label array has more than 2 dimensions. - """ - if self.contour < 1: - return None - if labels.ndim > 2: - return None - - expanded_slice = expand_slice(data_slice, labels.shape, 1) - sliced_labels = get_contours( - labels[expanded_slice], - self.contour, - self.colormap.background_value, - self.group_labels, - ) - - # Remove the latest one-pixel border from the result - delta_slice = tuple( - slice(s1.start - s2.start, s1.stop - s2.start) - for s1, s2 in zip(data_slice, expanded_slice, strict=False) - ) - return sliced_labels[delta_slice] From 2f183a05fc32ad7e93cda6952312f2d81f3378c2 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 11 Dec 2024 14:55:09 +0100 Subject: [PATCH 17/25] remove viewer from contour labels --- src/motile_plugin/data_views/views/layers/contour_labels.py | 2 -- src/motile_plugin/data_views/views/layers/track_labels.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/motile_plugin/data_views/views/layers/contour_labels.py b/src/motile_plugin/data_views/views/layers/contour_labels.py index 92a9924..8077379 100644 --- a/src/motile_plugin/data_views/views/layers/contour_labels.py +++ b/src/motile_plugin/data_views/views/layers/contour_labels.py @@ -65,7 +65,6 @@ def _type_string(self) -> str: def __init__( self, - viewer: napari.Viewer, data: np.array, name: str, opacity: float, @@ -80,7 +79,6 @@ def __init__( colormap=colormap, ) - self.viewer = viewer self.group_labels = None def _calculate_contour( diff --git a/src/motile_plugin/data_views/views/layers/track_labels.py b/src/motile_plugin/data_views/views/layers/track_labels.py index ec0625c..a2f2be7 100644 --- a/src/motile_plugin/data_views/views/layers/track_labels.py +++ b/src/motile_plugin/data_views/views/layers/track_labels.py @@ -46,7 +46,6 @@ def __init__( ) super().__init__( - viewer=viewer, data=data, name=name, opacity=opacity, From 1654155e31c7855af7250a3f5738165b7b413b65 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 11 Dec 2024 14:55:43 +0100 Subject: [PATCH 18/25] bug fix: show_all_group -> show_group_radio --- .../data_views/views/tree_view/tree_view_mode_widget.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py index 8777785..0c340e4 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_view_mode_widget.py @@ -48,7 +48,7 @@ def _toggle_display_mode(self, event=None) -> None: if self.mode == "lineage": self._set_mode("group") - self.show_all_group.setChecked(True) + self.show_group_radio.setChecked(True) elif self.mode == "group": self._set_mode("all") self.show_all_radio.setChecked(True) From a3dbf50239f30ab0039b666b09712fc6829643ca Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 11 Dec 2024 15:16:12 +0100 Subject: [PATCH 19/25] use a set for the collection instead of a separate class, which has become obsolete --- .../views_coordinator/collection_widget.py | 44 +++++++++++++------ .../views_coordinator/collections.py | 39 ---------------- 2 files changed, 30 insertions(+), 53 deletions(-) delete mode 100644 src/motile_plugin/data_views/views_coordinator/collections.py diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py index 6f41194..f3e9e86 100644 --- a/src/motile_plugin/data_views/views_coordinator/collection_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -22,8 +22,6 @@ extract_lineage_tree, ) -from . import Collection - if TYPE_CHECKING: from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer @@ -35,7 +33,7 @@ def __init__(self, name: str): super().__init__() self.name = QLabel(name) self.name.setFixedHeight(20) - self.collection = Collection() + self.collection = set() delete_icon = QColoredSVGIcon.from_resources("delete").colored("white") self.node_count = QLabel(f"{len(self.collection)} nodes") self.delete = QPushButton(icon=delete_icon) @@ -186,7 +184,7 @@ def _select_nodes(self) -> None: if selected: self.selected_collection = self.collection_list.itemWidget(selected[0]) self.tracks_viewer.selected_nodes.add_list( - self.selected_collection.collection._list, append=False + list(self.selected_collection.collection), append=False ) def retrieve_existing_groups(self) -> None: @@ -209,8 +207,9 @@ def retrieve_existing_groups(self) -> None: # populate the lists based on the nodes that were assigned to the different groups for i in range(self.collection_list.count()): self.collection_list.setCurrentRow(i) - self.selected_collection.collection.add( - group_dict[self.selected_collection.name.text()] + self.selected_collection.collection = ( + self.selected_collection.collection + | set(group_dict[self.selected_collection.name.text()]) ) self.selected_collection.update_node_count() # enforce updating the node count for all elements @@ -238,7 +237,9 @@ def add_nodes(self, nodes: list[Any] | None = None) -> None: nodes = ( self.tracks_viewer.selected_nodes._list ) # take the nodes that are currently selected - self.selected_collection.collection.add(nodes) + self.selected_collection.collection = ( + self.selected_collection.collection | set(nodes) + ) for node_id in nodes: if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] @@ -269,7 +270,9 @@ def _add_track(self) -> None: if data.get("track_id") == track_id } ) - self.selected_collection.collection.add(track) + self.selected_collection.collection = ( + self.selected_collection.collection | set(track) + ) for node_id in track: if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] @@ -288,7 +291,9 @@ def _add_lineage(self) -> None: if self.selected_collection is not None: for node_id in self.tracks_viewer.selected_nodes: lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) - self.selected_collection.collection.add(lineage) + self.selected_collection.collection = ( + self.selected_collection.collection | set(lineage) + ) for node_id in lineage: if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] @@ -306,9 +311,11 @@ def _remove_node(self) -> None: if self.selected_collection is not None: # remove from the collection - self.selected_collection.collection.remove( - self.tracks_viewer.selected_nodes - ) + self.selected_collection.collection = { + item + for item in self.selected_collection.collection + if item not in self.tracks_viewer.selected_nodes + } # remove from the node attribute for node_id in self.tracks_viewer.selected_nodes: @@ -340,7 +347,12 @@ def _remove_track(self) -> None: if data.get("track_id") == track_id } ) - self.selected_collection.collection.remove(track) + + self.selected_collection.collection = { + item + for item in self.selected_collection.collection + if item not in track + } for node_id in track: self.tracks_viewer.tracks.graph.nodes[node_id]["group"].remove( self.selected_collection.name.text() @@ -353,7 +365,11 @@ def _remove_lineage(self) -> None: if self.selected_collection is not None: for node_id in self.tracks_viewer.selected_nodes: lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) - self.selected_collection.collection.remove(lineage) + self.selected_collection.collection = { + item + for item in self.selected_collection.collection + if item not in lineage + } for node_id in lineage: self.tracks_viewer.tracks.graph.nodes[node_id]["group"].remove( self.selected_collection.name.text() diff --git a/src/motile_plugin/data_views/views_coordinator/collections.py b/src/motile_plugin/data_views/views_coordinator/collections.py deleted file mode 100644 index b5c46c0..0000000 --- a/src/motile_plugin/data_views/views_coordinator/collections.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from PyQt5.QtCore import QObject - - -class Collection(QObject): - """A collection of node ids belonging to a group""" - - def __init__(self): - super().__init__() - self._list = [] - - def add(self, items: list, append: bool | None = True): - """Add nodes from a list""" - - if append: - for item in items: - if item in self._list: - continue - else: - self._list.append(item) - - else: - self._list = items - - def remove(self, items: list): - """Remove nodes from a list""" - - self._list = [item for item in self._list if item not in items] - - def reset(self): - """Empty list""" - self._list = [] - - def __getitem__(self, index): - return self._list[index] - - def __len__(self): - return len(self._list) From 5218db5161c703c1b85790d0f5a446ede5453df0 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 11 Dec 2024 15:41:38 +0100 Subject: [PATCH 20/25] combine functions to add/remove nodes from a collection --- .../data_views/views_coordinator/__init__.py | 1 - .../views_coordinator/collection_widget.py | 146 ++++++------------ 2 files changed, 50 insertions(+), 97 deletions(-) diff --git a/src/motile_plugin/data_views/views_coordinator/__init__.py b/src/motile_plugin/data_views/views_coordinator/__init__.py index ff8b08b..e69de29 100644 --- a/src/motile_plugin/data_views/views_coordinator/__init__.py +++ b/src/motile_plugin/data_views/views_coordinator/__init__.py @@ -1 +0,0 @@ -from .collections import Collection # noqa diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py index f3e9e86..4cb0978 100644 --- a/src/motile_plugin/data_views/views_coordinator/collection_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -91,7 +91,7 @@ def __init__(self, tracks_viewer: TracksViewer): add_layout = QHBoxLayout() self.add_nodes_btn = QPushButton("Add node(s)") - self.add_nodes_btn.clicked.connect(lambda: self.add_nodes(None)) + self.add_nodes_btn.clicked.connect(self._add_selection) self.add_track_btn = QPushButton("Add track(s)") self.add_track_btn.clicked.connect(self._add_track) self.add_lineage_btn = QPushButton("Add lineage(s)") @@ -102,7 +102,7 @@ def __init__(self, tracks_viewer: TracksViewer): remove_layout = QHBoxLayout() self.remove_node_btn = QPushButton("Remove node(s)") - self.remove_node_btn.clicked.connect(self._remove_node) + self.remove_node_btn.clicked.connect(self._remove_selection) self.remove_track_btn = QPushButton("Remove track(s)") self.remove_track_btn.clicked.connect(self._remove_track) self.remove_lineage_btn = QPushButton("Remove lineage(s)") @@ -233,10 +233,6 @@ def add_nodes(self, nodes: list[Any] | None = None) -> None: """ if self.selected_collection is not None: - if nodes is None: - nodes = ( - self.tracks_viewer.selected_nodes._list - ) # take the nodes that are currently selected self.selected_collection.collection = ( self.selected_collection.collection | set(nodes) ) @@ -253,72 +249,47 @@ def add_nodes(self, nodes: list[Any] | None = None) -> None: self.group_changed.emit() + def _add_selection(self) -> None: + """Add the currently selected node(s) to the collection""" + + self.add_nodes(self.tracks_viewer.selected_nodes._list) + def _add_track(self) -> None: - """Add tracks by track_ids to the selected collection and send update signal""" + """Add the tracks belonging to selected nodes to the selected collection""" - if self.selected_collection is not None: - for node_id in self.tracks_viewer.selected_nodes: - track_id = self.tracks_viewer.tracks._get_node_attr( - node_id, NodeAttr.TRACK_ID.value - ) - track = list( - { - node - for node, data in self.tracks_viewer.tracks.graph.nodes( - data=True - ) - if data.get("track_id") == track_id - } - ) - self.selected_collection.collection = ( - self.selected_collection.collection | set(track) - ) - for node_id in track: - if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: - self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] - if ( - self.selected_collection.name.text() - not in self.tracks_viewer.tracks.graph.nodes[node_id]["group"] - ): - self.tracks_viewer.tracks.graph.nodes[node_id]["group"].append( - self.selected_collection.name.text() - ) - self.group_changed.emit() + for node_id in self.tracks_viewer.selected_nodes: + track_id = self.tracks_viewer.tracks._get_node_attr( + node_id, NodeAttr.TRACK_ID.value + ) + track = list( + { + node + for node, data in self.tracks_viewer.tracks.graph.nodes(data=True) + if data.get("track_id") == track_id + } + ) + self.add_nodes(track) def _add_lineage(self) -> None: - """Add lineages to the selected collection and send update signal""" + """Add lineages to the selected collection""" - if self.selected_collection is not None: - for node_id in self.tracks_viewer.selected_nodes: - lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) - self.selected_collection.collection = ( - self.selected_collection.collection | set(lineage) - ) - for node_id in lineage: - if "group" not in self.tracks_viewer.tracks.graph.nodes[node_id]: - self.tracks_viewer.tracks.graph.nodes[node_id]["group"] = [] - if ( - self.selected_collection.name.text() - not in self.tracks_viewer.tracks.graph.nodes[node_id]["group"] - ): - self.tracks_viewer.tracks.graph.nodes[node_id]["group"].append( - self.selected_collection.name.text() - ) - self.group_changed.emit() + for node_id in self.tracks_viewer.selected_nodes: + lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) + self.add_nodes(lineage) - def _remove_node(self) -> None: - """Remove individual nodes from the selected collection and send update signal""" + def remove_nodes(self, nodes: list[Any]) -> None: + """Remove selected nodes from the selected collection""" if self.selected_collection is not None: # remove from the collection self.selected_collection.collection = { item for item in self.selected_collection.collection - if item not in self.tracks_viewer.selected_nodes + if item not in nodes } # remove from the node attribute - for node_id in self.tracks_viewer.selected_nodes: + for node_id in nodes: node_attrs = self.tracks_viewer.tracks.graph.nodes[node_id] group_list = node_attrs.get("group") if ( @@ -330,51 +301,34 @@ def _remove_node(self) -> None: self.group_changed.emit() + def _remove_selection(self) -> None: + """Remove individual nodes from the selected collection""" + + self.remove_nodes(self.tracks_viewer.selected_nodes) + def _remove_track(self) -> None: - """Remove tracks by track id from the selected collection and send update signal""" + """Remove tracks by track id from the selected collection""" - if self.selected_collection is not None: - for node_id in self.tracks_viewer.selected_nodes: - track_id = self.tracks_viewer.tracks._get_node_attr( - node_id, NodeAttr.TRACK_ID.value - ) - track = list( - { - node - for node, data in self.tracks_viewer.tracks.graph.nodes( - data=True - ) - if data.get("track_id") == track_id - } - ) - - self.selected_collection.collection = { - item - for item in self.selected_collection.collection - if item not in track + for node_id in self.tracks_viewer.selected_nodes: + track_id = self.tracks_viewer.tracks._get_node_attr( + node_id, NodeAttr.TRACK_ID.value + ) + track = list( + { + node + for node, data in self.tracks_viewer.tracks.graph.nodes(data=True) + if data.get("track_id") == track_id } - for node_id in track: - self.tracks_viewer.tracks.graph.nodes[node_id]["group"].remove( - self.selected_collection.name.text() - ) - self.group_changed.emit() + ) + + self.remove_nodes(track) def _remove_lineage(self) -> None: - """Remove lineages from the selected collection and send update signal""" + """Remove lineages from the selected collection""" - if self.selected_collection is not None: - for node_id in self.tracks_viewer.selected_nodes: - lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) - self.selected_collection.collection = { - item - for item in self.selected_collection.collection - if item not in lineage - } - for node_id in lineage: - self.tracks_viewer.tracks.graph.nodes[node_id]["group"].remove( - self.selected_collection.name.text() - ) - self.group_changed.emit() + for node_id in self.tracks_viewer.selected_nodes: + lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) + self.remove_nodes(lineage) def add_group(self, name: str | None = None, select: bool = True) -> None: """Create a new custom group From b7c347a555ff251698ac7a40a73dfeac13aae2c4 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 11 Dec 2024 16:00:00 +0100 Subject: [PATCH 21/25] get rid of for loops when setting tree view outlines --- .../data_views/views/tree_view/tree_widget.py | 52 +++++++++++-------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/src/motile_plugin/data_views/views/tree_view/tree_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_widget.py index 6334d7d..859c8be 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_widget.py @@ -329,33 +329,39 @@ def set_selection( if len(filtered_nodes) > 0: color = [c * 255 for c in color] - for node_id in filtered_nodes: - node_df = self.track_df.loc[self.track_df["node_id"] == node_id] - if not node_df.empty: - index = self.node_ids.index(node_id) - outlines[index] = pg.mkPen(color=color, width=2) + node_indices = {node_id: idx for idx, node_id in enumerate(self.node_ids)} + valid_nodes = [ + node_id for node_id in filtered_nodes if node_id in node_indices + ] + indices_to_update = [node_indices[node_id] for node_id in valid_nodes] + + new_values = [pg.mkPen(color=color, width=2)] * len(indices_to_update) + outlines = np.array(outlines) # Convert to NumPy for slicing + outlines[indices_to_update] = new_values if len(selected_nodes) > 0: - x_values = [] - t_values = [] - for node_id in selected_nodes: - node_df = self.track_df.loc[self.track_df["node_id"] == node_id] - x_axis_value = None - if not node_df.empty: - x_axis_value = node_df[axis_label].values[0] - t = node_df["t"].values[0] - - x_values.append(x_axis_value) - t_values.append(t) - - # Update size and outline - index = self.node_ids.index(node_id) - size[index] += 5 - outlines[index] = pg.mkPen(color="c", width=2) + filtered_df = self.track_df[ + self.track_df["node_id"].isin(selected_nodes._list) + ] + x_values = filtered_df[axis_label].values + t_values = filtered_df["t"].values + node_indices = {node_id: idx for idx, node_id in enumerate(self.node_ids)} + valid_indices = np.array( + [ + node_indices[node_id] + for node_id in selected_nodes + if node_id in node_indices + ] + ) + size[valid_indices] += 5 + outlines[valid_indices] = pg.mkPen(color="c", width=2) # Center point if a single node is selected, center range if multiple nodes are selected - if len(selected_nodes) == 1 and x_axis_value is not None: - self._center_view(x_axis_value, t) + if len(selected_nodes) == 1 and not filtered_df.empty: + x_axis_value = filtered_df[axis_label].values[0] + t = filtered_df["t"].values[0] + if x_axis_value is not None: + self._center_view(x_axis_value, t) else: if ( len(x_values) > 0 From 2c584c460673c6ca6bc1f096ef416190ac6354fd Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Thu, 12 Dec 2024 18:09:20 +0100 Subject: [PATCH 22/25] bugfix: remove _list from collection._list as it is a list already --- src/motile_plugin/data_views/views/tree_view/tree_widget.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/motile_plugin/data_views/views/tree_view/tree_widget.py b/src/motile_plugin/data_views/views/tree_view/tree_widget.py index 859c8be..40bfc4c 100644 --- a/src/motile_plugin/data_views/views/tree_view/tree_widget.py +++ b/src/motile_plugin/data_views/views/tree_view/tree_widget.py @@ -994,7 +994,7 @@ def _update_group_df(self) -> None: if node_id in self.tracks_viewer.track_df["node_id"].tolist(): visible += extract_lineage_tree(self.graph, node_id) else: - self.tracks_viewer.collection_widget.selected_collection.collection._list.remove( + self.tracks_viewer.collection_widget.selected_collection.collection.remove( node_id ) self.group_df = self.tracks_viewer.track_df[ @@ -1009,7 +1009,7 @@ def _update_group_df(self) -> None: self.group_df["color"] = self.group_df.apply( lambda row: [*row["color"][:3], 62.0] if row["node_id"] - not in self.tracks_viewer.collection_widget.selected_collection.collection._list + not in self.tracks_viewer.collection_widget.selected_collection.collection else row["color"], axis=1, ) From dcb2067924d9a89e91ca36861e7531491b620947 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 18 Dec 2024 10:07:37 +0100 Subject: [PATCH 23/25] ensure that deleted nodes are also deleted from the group in the collection widget --- .../views_coordinator/collection_widget.py | 82 ++++++++++++++----- .../views_coordinator/tracks_viewer.py | 1 + 2 files changed, 63 insertions(+), 20 deletions(-) diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py index 4cb0978..f8d9700 100644 --- a/src/motile_plugin/data_views/views_coordinator/collection_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -177,6 +177,19 @@ def _update_buttons_and_node_count(self) -> None: else: self.new_group_button.setEnabled(False) + def _refresh(self) -> None: + """Removes nodes that are no longer existing from the collection""" + + selected = self.collection_list.selectedItems() + if selected: + self.selected_collection = self.collection_list.itemWidget(selected[0]) + nodes = self.selected_collection.collection + graph_nodes = set(self.tracks_viewer.tracks.graph.nodes) + self.selected_collection.collection = { + item for item in nodes if item in graph_nodes + } + self.selected_collection.update_node_count() + def _select_nodes(self) -> None: """Select all nodes in the collection""" @@ -257,24 +270,39 @@ def _add_selection(self) -> None: def _add_track(self) -> None: """Add the tracks belonging to selected nodes to the selected collection""" + track_ids = [] for node_id in self.tracks_viewer.selected_nodes: track_id = self.tracks_viewer.tracks._get_node_attr( node_id, NodeAttr.TRACK_ID.value ) - track = list( - { - node - for node, data in self.tracks_viewer.tracks.graph.nodes(data=True) - if data.get("track_id") == track_id - } - ) - self.add_nodes(track) + if track_id in track_ids: + continue # skip, since we already added all nodes with this track id + else: + track = list( + { + node + for node, data in self.tracks_viewer.tracks.graph.nodes( + data=True + ) + if data.get("track_id") == track_id + } + ) + self.add_nodes(track) + track_ids.append(track_id) def _add_lineage(self) -> None: """Add lineages to the selected collection""" + track_ids = [] for node_id in self.tracks_viewer.selected_nodes: - lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) + track_id = self.tracks_viewer.tracks._get_node_attr( + node_id, NodeAttr.TRACK_ID.value + ) + if track_id in track_ids: + continue # skip, since we already added all nodes with this track id + else: + lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) + track_ids.append(track_id) self.add_nodes(lineage) def remove_nodes(self, nodes: list[Any]) -> None: @@ -309,26 +337,40 @@ def _remove_selection(self) -> None: def _remove_track(self) -> None: """Remove tracks by track id from the selected collection""" + track_ids = [] for node_id in self.tracks_viewer.selected_nodes: track_id = self.tracks_viewer.tracks._get_node_attr( node_id, NodeAttr.TRACK_ID.value ) - track = list( - { - node - for node, data in self.tracks_viewer.tracks.graph.nodes(data=True) - if data.get("track_id") == track_id - } - ) - - self.remove_nodes(track) + if track_id in track_ids: + continue + else: + track = list( + { + node + for node, data in self.tracks_viewer.tracks.graph.nodes( + data=True + ) + if data.get("track_id") == track_id + } + ) + self.remove_nodes(track) + track_ids.append(track_id) def _remove_lineage(self) -> None: """Remove lineages from the selected collection""" + track_ids = [] for node_id in self.tracks_viewer.selected_nodes: - lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) - self.remove_nodes(lineage) + track_id = self.tracks_viewer.tracks._get_node_attr( + node_id, NodeAttr.TRACK_ID.value + ) + if track_id in track_ids: + continue + else: + lineage = extract_lineage_tree(self.tracks_viewer.tracks.graph, node_id) + self.remove_nodes(lineage) + track_ids.append(track_id) def add_group(self, name: str | None = None, select: bool = True) -> None: """Create a new custom group diff --git a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py index b1ff08b..a9546eb 100644 --- a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py +++ b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py @@ -107,6 +107,7 @@ def _refresh(self, node: str | None = None, refresh_view: bool = False) -> None: self.selected_nodes._list = [] self.tracking_layers._refresh() + self.collection_widget._refresh() self.tracks_updated.emit(refresh_view) From 9ec0304b2a326e96284d184024dd8c48af73a2bd Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Thu, 12 Dec 2024 18:08:18 +0100 Subject: [PATCH 24/25] add button to invert the selection --- .../views_coordinator/collection_widget.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/motile_plugin/data_views/views_coordinator/collection_widget.py b/src/motile_plugin/data_views/views_coordinator/collection_widget.py index f8d9700..926d0f2 100644 --- a/src/motile_plugin/data_views/views_coordinator/collection_widget.py +++ b/src/motile_plugin/data_views/views_coordinator/collection_widget.py @@ -80,10 +80,13 @@ def __init__(self, tracks_viewer: TracksViewer): selection_layout = QHBoxLayout() self.select_btn = QPushButton("Select nodes in group") self.select_btn.clicked.connect(self._select_nodes) + self.invert_btn = QPushButton("Invert selection") + self.invert_btn.clicked.connect(self._invert_selection) self.deselect_btn = QPushButton("Deselect") self.deselect_btn.clicked.connect(self.tracks_viewer.selected_nodes.reset) selection_layout.addWidget(self.select_btn) selection_layout.addWidget(self.deselect_btn) + selection_layout.addWidget(self.invert_btn) # edit layout edit_widget = QGroupBox("Edit") @@ -190,6 +193,16 @@ def _refresh(self) -> None: } self.selected_collection.update_node_count() + def _invert_selection(self) -> None: + """Invert the current selection""" + + nodes = [ + node + for node in self.tracks_viewer.tracks.graph.nodes + if node not in self.tracks_viewer.selected_nodes + ] + self.tracks_viewer.selected_nodes.add_list(nodes, append=False) + def _select_nodes(self) -> None: """Select all nodes in the collection""" From 47a3ab290168420ccd56d5fdd6072ccd2a9a5046 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 18 Dec 2024 11:06:12 -0500 Subject: [PATCH 25/25] Update imports in collection and filter widgets to motile_tracker --- .../data_views/views_coordinator/collection_widget.py | 9 +++++---- .../data_views/views_coordinator/filter_widget.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/motile_tracker/data_views/views_coordinator/collection_widget.py b/src/motile_tracker/data_views/views_coordinator/collection_widget.py index ad08fe9..dc2de46 100644 --- a/src/motile_tracker/data_views/views_coordinator/collection_widget.py +++ b/src/motile_tracker/data_views/views_coordinator/collection_widget.py @@ -3,9 +3,6 @@ from functools import partial from typing import TYPE_CHECKING, Any -from motile_plugin.data_views.views.tree_view.tree_widget_utils import ( - extract_lineage_tree, -) from motile_toolbox.candidate_graph.graph_attributes import NodeAttr from napari._qt.qt_resources import QColoredSVGIcon from qtpy.QtCore import Signal @@ -21,8 +18,12 @@ QWidget, ) +from motile_tracker.data_views.views.tree_view.tree_widget_utils import ( + extract_lineage_tree, +) + if TYPE_CHECKING: - from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer + from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer class CollectionButton(QWidget): diff --git a/src/motile_tracker/data_views/views_coordinator/filter_widget.py b/src/motile_tracker/data_views/views_coordinator/filter_widget.py index fd19d00..8875936 100644 --- a/src/motile_tracker/data_views/views_coordinator/filter_widget.py +++ b/src/motile_tracker/data_views/views_coordinator/filter_widget.py @@ -27,7 +27,7 @@ ) if TYPE_CHECKING: - from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer + from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer import pyqtgraph as pg from qtpy.QtCore import Qt