Skip to content

Commit

Permalink
add loader namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
hanjinliu committed Dec 8, 2023
1 parent ca28d1e commit 8b8b7b5
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "cylindra_ext"
version = "1.0.0-alpha.6"
version = "1.0.0-alpha.7"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
4 changes: 3 additions & 1 deletion cylindra/project/_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,9 @@ def _normalize_id(self, out: pl.DataFrame, id: _IDTYPE) -> pl.DataFrame:
_map[i] = label
_appeared.add(label)
out = out.with_columns(
pl.col(Mole.image).map_dict(_map, return_dtype=pl.Categorical)
pl.col(Mole.image).replace(
_map, default=None, return_dtype=pl.Categorical
)
)
else:
raise ValueError(f"Invalid id type {id!r}.")
Expand Down
2 changes: 1 addition & 1 deletion cylindra/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cylindra.const import PREVIEW_LAYER_NAME

if TYPE_CHECKING:
from magicgui.widgets._bases import CategoricalWidget
from magicgui.widgets.bases import CategoricalWidget


def _viewer_ancestor() -> "napari.Viewer | None":
Expand Down
90 changes: 64 additions & 26 deletions cylindra/widgets/_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
TypeVar,
MutableSequence,
)
from typing_extensions import Self
import weakref
from fnmatch import fnmatch

Expand All @@ -19,36 +18,26 @@
from cylindra.types import MoleculesLayer

if TYPE_CHECKING:
from magicclass import MagicTemplate
from cylindra.widgets.main import CylindraMainWidget
from cylindra.widgets.batch import CylindraBatchWidget
from cylindra.widgets.batch._loaderlist import LoaderInfo
from acryo import Molecules
from typing_extensions import Self

_T = TypeVar("_T", bound="Accessor")
_V = TypeVar("_V")
_W = TypeVar("_W", bound="MagicTemplate")


class AccessorField(Generic[_T]):
def __init__(self, constructor: type[_T]):
self._instances = dict[int, _T]()
self._constructor = constructor
class Accessor(MutableSequence[_V], Generic[_V, _W]):
def __init__(self, widget: _W | None = None):
if widget is not None:
self._widget = weakref.ref(widget)
else:
self._widget = lambda: None
self._instances = dict[int, "Self"]()

def __get__(self, instance: Any, owner: type) -> _T:
if instance is None:
return self
_id = id(instance)
if _id not in self._instances:
self._instances[_id] = self._constructor(instance)
return self._instances[_id]


class Accessor(MutableSequence[_V]):
def __init__(self, widget: CylindraMainWidget):
self._widget = weakref.ref(widget)

@classmethod
def field(cls) -> AccessorField[Self]:
return AccessorField(cls)

def widget(self) -> CylindraMainWidget:
def widget(self) -> _W:
widget = self._widget()
if widget is None:
raise RuntimeError("Widget is already deleted.")
Expand All @@ -60,11 +49,22 @@ def viewer(self) -> napari.Viewer:
raise RuntimeError("Viewer not found.")
return viewer

def __get__(self, instance: Any, owner: type) -> Self:
if instance is None:
return self
_id = id(instance)
if _id not in self._instances:
self._instances[_id] = self.__class__(instance)
return self._instances[_id]

def __repr__(self) -> str:
return f"{type(self).__name__}({list(self)!r})"


_Condition = Callable[[MoleculesLayer], bool]


class MoleculesLayerAccessor(Accessor[MoleculesLayer]):
class MoleculesLayerAccessor(Accessor[MoleculesLayer, "CylindraMainWidget"]):
"""Accessor to the molecules layers of the viewer."""

def __getitem__(self, name: str) -> MoleculesLayer:
Expand Down Expand Up @@ -106,7 +106,7 @@ def count(self) -> int:
return len(self)

def _ipython_key_completions_(self) -> list[str]:
"""Just for autocompletion."""
"""Just for autocompletion.""" # BUG: not working
return self.names()

def __iter__(self) -> Iterator[MoleculesLayer]:
Expand Down Expand Up @@ -253,3 +253,41 @@ def _get_monomer_layer(viewer: napari.Viewer, name: str) -> MoleculesLayer:
if not isinstance(layer, MoleculesLayer):
raise TypeError(f"Layer {name} is not a MoleculesLayer.")
return layer


class BatchLoaderAccessor(Accessor["LoaderInfo", "CylindraBatchWidget"]):
def __getitem__(self, name: str) -> LoaderInfo:
return self.widget()._loaders[name]

def __setitem__(self, name: str, layer: LoaderInfo) -> None:
if name in self:
raise ValueError(f"Layer {name} already exists.")
return self.append(layer)

def __delitem__(self, name: str) -> None:
for i, info in enumerate(self.widget()._loaders):
if info.name == name:
self.widget()._loaders.pop(i)
return
raise KeyError(f"Loader {name} not found.")

def insert(self, index: int, info: LoaderInfo) -> None:
return self.widget()._loaders.insert(index, info)

def __iter__(self) -> Iterator[LoaderInfo]:
return iter(self.widget()._loaders)

def __len__(self) -> int:
return len(self.widget()._loaders)

def names(self) -> list[str]:
"""All molecules layer names."""
return list(layer.name for layer in self)

def count(self) -> int:
"""Number of molecules layers."""
return len(self)

def _ipython_key_completions_(self) -> list[str]:
"""Just for autocompletion."""
return self.names()
2 changes: 2 additions & 0 deletions cylindra/widgets/batch/_loaderlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def find(self, name: str, default: int | None = None) -> int:
return -1

def insert(self, index: int, value: LoaderInfo) -> None:
if not isinstance(value, LoaderInfo):
raise TypeError(f"Expected LoaderInfo, got {type(value)}")
name = value.name
if self.find(name, default=-1) >= 0:
if re.match(r".+-\d+$", name):
Expand Down
2 changes: 2 additions & 0 deletions cylindra/widgets/batch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from cylindra.core import ACTIVE_WIDGETS
from cylindra.widget_utils import POLARS_NAMESPACE, capitalize
from cylindra.widgets._accessors import BatchLoaderAccessor
from cylindra.project import CylindraProject, CylindraBatchProject
from cylindra._config import get_config
from .sta import BatchSubtomogramAveraging
Expand Down Expand Up @@ -54,6 +55,7 @@
class CylindraBatchWidget(MagicTemplate):
constructor = field(ProjectSequenceEdit)
sta = field(BatchSubtomogramAveraging)
loaders = BatchLoaderAccessor()

def __init__(self):
self._loaders = LoaderList()
Expand Down
33 changes: 14 additions & 19 deletions cylindra/widgets/batch/sta.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _get_loader_names(self, _=None) -> list[str]:
parent = self._get_parent()
except Exception:
return []
return [info.name for info in parent._loaders]
return parent.loaders.names()

# Menus
BatchSubtomogramAnalysis = field(
Expand All @@ -113,8 +113,7 @@ def _get_current_loader_name(self, _=None) -> str:
@do_not_record
def show_loader_info(self):
"""Show information about this loader"""
loaderlist = self._get_parent()._loaders
info = loaderlist[self.loader_name]
info = self._get_parent().loaders[self.loader_name]
loader = info.loader
img_info = "\n" + "\n".join(
f"{img_id}: {img_path}" for img_id, img_path in info.image_paths.items()
Expand All @@ -135,8 +134,7 @@ def remove_loader(
self, loader_name: Annotated[str, {"bind": _get_current_loader_name}]
):
"""Remove this loader"""
loaderlist = self._get_parent()._loaders
del loaderlist[loader_name]
del self._get_parent().loaders[loader_name]

params = field(StaParameters)

Expand All @@ -161,14 +159,14 @@ def split_loader(
delete_old: bool = False,
):
parent = self._get_parent()
batch_info = parent._loaders[loader_name]
loaders = parent._loaders
batch_info = loaders[loader_name]
batch_loader = batch_info.loader
n_unique = batch_loader.molecules.features[by].n_unique()
if n_unique > 48:
raise ValueError(
f"Too many groups ({n_unique}). Did you choose a float column?"
)
loaders = parent._loaders
for _key, loader in batch_loader.groupby(by):
existing_id = set(loader.features[Mole.image])
image_paths = {
Expand Down Expand Up @@ -224,8 +222,7 @@ def filter_loader(
@nogui
def get_loader(self, name: str) -> BatchLoader:
"""Return the acryo.BatchLoader object with the given name"""
info = self._get_parent()._loaders[name]
return info.loader
return self._get_parent().loaders[name].loader

@set_design(text="Average all molecules", location=BatchSubtomogramAnalysis)
@dask_thread_worker.with_progress(desc="Averaging all molecules in projects")
Expand All @@ -237,8 +234,7 @@ def average_all(
bin_size: _BINSIZE = 1,
):
t0 = timer()
loaderlist = self._get_parent()._loaders
loader = loaderlist[loader_name].loader
loader = self._get_parent().loaders[loader_name].loader
shape = self._get_shape_in_px(size, loader)
img = ip.asarray(
loader.replace(output_shape=shape, order=interpolation)
Expand Down Expand Up @@ -274,8 +270,7 @@ def average_groups(
{interpolation}{bin_size}
"""
t0 = timer()
loaderlist = self._get_parent()._loaders
loader = loaderlist[loader_name].loader
loader = self._get_parent().loaders[loader_name].loader
shape = self._get_shape_in_px(size, loader)
img = ip.asarray(
loader.replace(output_shape=shape, order=interpolation)
Expand Down Expand Up @@ -372,8 +367,9 @@ def calculate_fsc(
at frequency 0.01, 0.03, 0.05, ..., 0.45.
"""
t0 = timer()
loader = self._get_parent()._loaders[loader_name].loader
loader = loader.replace(order=interpolation)
loader = (
self._get_parent().loaders[loader_name].loader.replace(order=interpolation)
)

template, mask = loader.normalize_input(
template=self.params._norm_template_param(template_path, allow_none=True),
Expand Down Expand Up @@ -440,8 +436,7 @@ def classify_pca(
from cylindra.widgets.subwidgets import PcaViewer

t0 = timer()
loaderlist = self._get_parent()._loaders
loader = loaderlist[loader_name].loader
loader = self._get_parent().loaders[loader_name].loader
shape = self._get_shape_in_px(size, loader)

_, mask = loader.normalize_input(
Expand Down Expand Up @@ -512,7 +507,7 @@ def show_template_original(self):
@do_not_record
def show_mask(self):
"""Load and show mask image in the scale of the tomogram."""
loader = self._get_parent()._loaders[self.loader_name].loader
loader = self._get_parent().loaders[self.loader_name].loader
_, mask = loader.normalize_input(
self.params._norm_template_param(allow_none=True), self.params._get_mask()
)
Expand All @@ -535,7 +530,7 @@ def _get_shape_in_px(
return (roundint(default / loader.scale),) * 3

def _get_template_image(self) -> ip.ImgArray:
scale = self._get_parent()._loaders[self.loader_name].loader.scale
scale = self._get_parent().loaders[self.loader_name].loader.scale

template = self.params._norm_template_param(
self.params._get_template_input(allow_multiple=True),
Expand Down
2 changes: 1 addition & 1 deletion cylindra/widgets/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class CylindraMainWidget(MagicTemplate):
# Widget for subtomogram analysis
sta = field(SubtomogramAveraging, name="_Subtomogram averaging")

mole_layers = MoleculesLayerAccessor.field()
mole_layers = MoleculesLayerAccessor()

@property
def batch(self) -> "CylindraBatchWidget":
Expand Down
3 changes: 2 additions & 1 deletion cylindra/widgets/sta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2033,7 +2033,8 @@ def _define_correlation_function(
temp = ip.asarray(temp, axes="zyx")

def _fn(img: "NDArray[np.float32]") -> float:
return func(ip.asarray(img * mask, axes="zyx"), temp * mask)
corr = float(func(ip.asarray(img * mask, axes="zyx"), temp * mask))
return corr

return _fn

Expand Down

0 comments on commit 8b8b7b5

Please sign in to comment.