Skip to content

Commit

Permalink
Merge pull request #138 from yfukai/frames_argment_in_transform
Browse files Browse the repository at this point in the history
Frames argument in transform
  • Loading branch information
yfukai authored Dec 12, 2023
2 parents fb5cdc0 + 3a6bb2c commit 32d6762
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ However, thanks to [cloudhan/jax-windows-builder](https://github.com/cloudhan/ja

```bash
pip install "jax[cpu]==0.4.11" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install ml-dtypes==0.2.0
pip install basicpy
```

Expand Down
3 changes: 2 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ def tests(session: Session) -> None:
"""Run the test suite."""
if platform.system() == "Windows":
session.install(
"jax[cpu]===0.4.11",
"jax[cpu]==0.4.11",
"-f",
"https://whls.blob.core.windows.net/unstable/index.html",
"--use-deprecated",
"legacy-resolver",
)
session.install("ml-dtypes==0.2.0")
session.install(".")
session.install(
"dask",
Expand Down
17 changes: 12 additions & 5 deletions src/basicpy/basicpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from enum import Enum
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax.numpy as jnp

Expand Down Expand Up @@ -388,7 +388,7 @@ def fit(
if self.fitting_mode == FittingMode.approximate:
init_mu = self.mu_coef / spectral_norm
else:
init_mu = self.mu_coef / spectral_norm / np.product(Im2.shape)
init_mu = self.mu_coef / spectral_norm / np.prod(Im2.shape)
fit_params = self.dict()
fit_params.update(
dict(
Expand Down Expand Up @@ -537,7 +537,10 @@ def fit(
)

def transform(
self, images: np.ndarray, timelapse: Union[bool, TimelapseTransformMode] = False
self,
images: np.ndarray,
timelapse: Union[bool, TimelapseTransformMode] = False,
frames: Optional[Sequence[Union[int, np.int_]]] = None,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
"""Apply profile to images.
Expand All @@ -548,6 +551,7 @@ def transform(
the object fluorescence. Also accepts "multiplicative"
(the same as `True`) or "additive" (residual is the object
fluorescence).
frames: Frames to use for transformation. Defaults to None (all frames).
Returns:
corrected images
Expand Down Expand Up @@ -588,8 +592,11 @@ def transform(
if timelapse:
if timelapse is True:
timelapse = TimelapseTransformMode.multiplicative

baseline_inds = tuple([slice(None)] + ([np.newaxis] * (im_float.ndim - 1)))
if frames is None:
_frames = slice(None)
else:
_frames = np.array(frames)
baseline_inds = tuple([_frames] + ([np.newaxis] * (im_float.ndim - 1)))
if timelapse == TimelapseTransformMode.multiplicative:
output = (im_float - self.darkfield[np.newaxis]) / self.flatfield[
np.newaxis
Expand Down

0 comments on commit 32d6762

Please sign in to comment.