Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmentation to mesh conversion #29

Merged
merged 14 commits into from
Mar 30, 2024
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ install_requires =
numpy
pooch
qtpy
pyacvd
pyvista
rich
scikit-image
starfile
Expand Down
79 changes: 71 additions & 8 deletions src/surforama/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import trimesh
from magicgui import magicgui
from napari.layers import Image, Surface
from napari.layers import Image, Labels, Surface
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
QGroupBox,
Expand All @@ -17,6 +17,7 @@
QVBoxLayout,
QWidget,
)
from scipy.ndimage import map_coordinates

from surforama.constants import (
NAPARI_NORMAL_0,
Expand All @@ -27,7 +28,7 @@
NAPARI_UP_2,
ROTATION,
)
from surforama.io import read_obj_file
from surforama.io import convert_mask_to_mesh, read_obj_file
from surforama.io.star import oriented_points_to_star_file
from surforama.utils.geometry import rotate_around_vector
from surforama.utils.napari import (
Expand Down Expand Up @@ -100,6 +101,8 @@ def __init__(
)
self.point_writer_widget.setVisible(False)

self.mesh_generator_widget = QtMeshGenerator(viewer, parent=self)

# make the layout
self.setLayout(QVBoxLayout())
self.layout().addWidget(self._layer_selection_widget.native)
Expand All @@ -110,6 +113,7 @@ def __init__(
self.layout().addWidget(self.picking_widget)
self.layout().addWidget(self.point_writer_widget)
self.layout().addStretch()
self.layout().addWidget(self.mesh_generator_widget)

# set the layers
self._set_layers(surface_layer=surface_layer, image_layer=volume_layer)
Expand Down Expand Up @@ -144,7 +148,7 @@ def _set_layers(
self.surface_layer.refresh()

self.normals = self.mesh.vertex_normals
self.volume = image_layer.data
self.volume = image_layer.data.astype(np.float32)

# make the widgets visible
self.slider.setVisible(True)
Expand All @@ -171,11 +175,9 @@ def _get_valid_image_layers(self, combo_box) -> List[Image]:
]

def get_point_colors(self, points):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also added this linear interpolation here. The direct sampling on coordinates gave some artifacts, particularly on the Chlamy thylakoid demo example.

point_indices = points.astype(int)

point_values = self.volume[
point_indices[:, 0], point_indices[:, 1], point_indices[:, 2]
]
point_values = map_coordinates(
self.volume, points.T, order=1, mode="nearest"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is the speed when doing linear interpolation? Does it still feel responsive? Do you think this should be an option? In any case, we can make that option later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the demo example on my local computer, it still ran very smooth with the demo examples (higher order interpolations were very slow though). But yes, I guess for larger meshes, this can become problematic.
Would be good to have that as an option.

)

normalized_values = (point_values - point_values.min()) / (
point_values.max() - point_values.min() + np.finfo(float).eps
Expand Down Expand Up @@ -483,6 +485,67 @@ def _write_star_file(self, output_path: Path):
)


class QtMeshGenerator(QGroupBox):
def __init__(
self, viewer: napari.Viewer, parent: Optional[QWidget] = None
):
super().__init__("Generate Mesh from Labels", parent=parent)
self.viewer = viewer

# make the labels layer selection widget
self.labels_layer_selection_widget = magicgui(
self._generate_mesh_from_labels,
labels_layer={"choices": self._get_valid_labels_layers},
barycentric_area={
"widget_type": "Slider",
"min": 0.1,
"max": 10.0,
"value": 1.0,
"step": 0.1,
},
smoothing={
"widget_type": "Slider",
"min": 0,
"max": 1000,
"value": 1000,
},
call_button="Generate Mesh",
)

# make the layout
self.setLayout(QVBoxLayout())
self.layout().addWidget(self.labels_layer_selection_widget.native)

# Add callback to update choices when layers change
self.viewer.layers.events.inserted.connect(self._on_layer_update)
self.viewer.layers.events.removed.connect(self._on_layer_update)

def _on_layer_update(self, event=None):
"""Refresh the layer choices when layers are added or removed."""
self.labels_layer_selection_widget.reset_choices()

def _get_valid_labels_layers(self, combo_box) -> List[Labels]:
return [
layer
for layer in self.viewer.layers
if isinstance(layer, napari.layers.Labels)
]

def _generate_mesh_from_labels(
self,
labels_layer: Labels,
smoothing: int = 10,
barycentric_area: float = 1.0,
):
# Assuming create_mesh_from_mask exists and generates vertices, faces, and values
vertices, faces, values = convert_mask_to_mesh(
labels_layer.data,
smoothing=smoothing,
barycentric_area=barycentric_area,
)
self.viewer.add_surface((vertices, faces, values))


if __name__ == "__main__":

obj_path = "../../examples/tomo_17_M10_grow1_1_mesh_data.obj"
Expand Down
4 changes: 2 additions & 2 deletions src/surforama/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from surforama.io.mesh import read_obj_file
from surforama.io.mesh import convert_mask_to_mesh, read_obj_file

__all__ = ("read_obj_file",)
__all__ = ("read_obj_file", "convert_mask_to_mesh")
55 changes: 55 additions & 0 deletions src/surforama/io/mesh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import pyacvd
import pyvista as pv
import trimesh
from skimage.measure import marching_cubes


def read_obj_file(file_path):
Expand All @@ -23,3 +26,55 @@ def read_obj_file(file_path):
values = np.ones((len(verts),))

return verts, faces, values


def convert_mask_to_mesh(
mask: np.ndarray,
barycentric_area: float = 1.0,
smoothing: int = 10,
):
"""
Convert a binary mask to a mesh.

Parameters
----------
mask : np.ndarray
A binary mask.
barycentric_area : float, optional
The target barycentric area of each vertex in the mesh,
by default 1.0
smoothing : int, optional
Number of iterations for Laplacian mesh smoothing,
by default 10
"""
verts, faces, _, _ = marching_cubes(
volume=mask,
level=0.5,
step_size=1,
method="lewiner",
)

# Prepend 3 for pyvista format
faces = np.concatenate(
(np.ones((faces.shape[0], 1), dtype=int) * 3, faces), axis=1
)

# Create a mesh
surf = pv.PolyData(verts, faces)
surf = surf.smooth(n_iter=smoothing)

# remesh to desired point size
cluster_points = int(surf.area / barycentric_area)
clus = pyacvd.Clustering(surf)
clus.subdivide(3)
clus.cluster(cluster_points)
remesh = clus.create_mesh()

verts = remesh.points
faces = remesh.faces.reshape(-1, 4)[:, 1:]
values = np.ones((len(verts),))

# switch face order to have outward normals
faces = faces[:, [0, 2, 1]]

return verts, faces, values
Loading