Skip to content

Commit

Permalink
Merge pull request #104 from PolarizedLightFieldMicroscopy/jones-matr…
Browse files Browse the repository at this point in the history
…ix-mult

Jones matrix multiplication loop
  • Loading branch information
gschlafly authored Jul 23, 2024
2 parents b1da992 + 5aa86c8 commit cbda367
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 187 deletions.
141 changes: 73 additions & 68 deletions src/VolumeRaytraceLFM/abstract_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class BackEnds(Enum):


class OpticalElement(OpticBlock):
"""Abstract class defining a elements, with a back-end ans some optical information"""
"""Abstract class defining optical elements with a back-end and
some optical information."""

default_optical_info = {
# Volume information
Expand All @@ -86,89 +87,93 @@ class OpticalElement(OpticBlock):
}

def __init__(
self, backend: BackEnds = BackEnds.NUMPY, torch_args={}, optical_info={}
self, backend: BackEnds = BackEnds.NUMPY, torch_args=None, optical_info=None
):
# torch args could be {'optic_config' : None, 'members_to_learn' : []},

# Optical info is needed
assert (
len(optical_info) > 0
), f"Optical info (optical_info) dictionary needed: \
use OpticalElement.default_optical_info as reference \
{OpticalElement.default_optical_info}"
torch_args = torch_args or {}
optical_info = optical_info or {}
assert len(optical_info) > 0, (
f"Optical info (optical_info) dictionary needed so using the default: "
f"{OpticalElement.default_optical_info}"
)

# Compute voxel size
if optical_info["cube_voxels"] is False:
optical_info["voxel_size_um"] = [
optical_info["axial_voxel_size_um"],
] + 2 * [
optical_info["pixels_per_ml"]
* optical_info["camera_pix_pitch"]
/ optical_info["M_obj"]
/ optical_info["n_voxels_per_ml"]
]
else:
# Option to make voxel size uniform
optical_info["voxel_size_um"] = 3 * [
optical_info["pixels_per_ml"]
* optical_info["camera_pix_pitch"]
/ optical_info["M_obj"]
/ optical_info["n_voxels_per_ml"]
]
voxel_size = self._compute_voxel_size(optical_info)
optical_info["voxel_size_um"] = voxel_size

# Check if back-end is torch and overwrite self with an optic block, for Waveblocks
# compatibility.
if backend == BackEnds.PYTORCH:
# We need to make a copy if we don't want to modify the torch_args default argument,
# very weird.
new_torch_args = copy.deepcopy(torch_args)
# If no optic_config is provided, create one
if "optic_config" not in torch_args.keys() or (
"optic_config" not in torch_args.keys()
and not isinstance(torch_args["optic_config"], OpticConfig)
# We need to make a copy if we don't want to modify the
# torch_args default argument, very weird.
new_args = copy.deepcopy(torch_args)
if "optic_config" not in torch_args or not isinstance(
torch_args["optic_config"], OpticConfig
):
new_torch_args["optic_config"] = OpticConfig()
new_torch_args["optic_config"].volume_config.volume_shape = (
optical_info["volume_shape"]
)
new_torch_args["optic_config"].volume_config.voxel_size_um = (
optical_info["voxel_size_um"]
)
new_torch_args["optic_config"].mla_config.n_pixels_per_mla = (
optical_info["pixels_per_ml"]
)
new_torch_args["optic_config"].mla_config.n_micro_lenses = optical_info[
"n_micro_lenses"
]
new_torch_args["optic_config"].PSF_config.NA = optical_info["na_obj"]
new_torch_args["optic_config"].PSF_config.ni = optical_info["n_medium"]
new_torch_args["optic_config"].PSF_config.wvl = optical_info[
"wavelength"
]
try:
new_torch_args["optic_config"].pol_config.polarizer = optical_info[
"polarizer"
]
new_torch_args["optic_config"].pol_config.analyzer = optical_info[
"analyzer"
]
except:
print("Error: Polarizer and Analyzer not found in optical_info")
super(OpticalElement, self).__init__(
optic_config=new_torch_args["optic_config"],
members_to_learn=(
new_torch_args["members_to_learn"]
if "members_to_learn" in new_torch_args.keys()
else []
),
new_args["optic_config"] = self.create_optic_config(optical_info)
super().__init__(
optic_config=new_args["optic_config"],
members_to_learn=new_args.get("members_to_learn", []),
)
# Store variables

self.backend = backend
self.simul_type = SimulType.NOT_SPECIFIED
self.optical_info = optical_info

def _compute_voxel_size(self, optical_info):
if not optical_info["cube_voxels"]:
return [optical_info["axial_voxel_size_um"]] + 2 * [
optical_info["pixels_per_ml"]
* optical_info["camera_pix_pitch"]
/ optical_info["M_obj"]
/ optical_info["n_voxels_per_ml"]
]
return 3 * [
optical_info["pixels_per_ml"]
* optical_info["camera_pix_pitch"]
/ optical_info["M_obj"]
/ optical_info["n_voxels_per_ml"]
]

@staticmethod
def get_optical_info_template():
return copy.deepcopy(OpticalElement.default_optical_info)

def create_optic_config(self, optical_info):
"""Creates an OpticConfig instance and populates it with the provided optical information."""
optic_config = OpticConfig()

# Populate volume configuration
optic_config.volume_config.volume_shape = optical_info.get(
"volume_shape", [1, 1, 1]
)
optic_config.volume_config.voxel_size_um = optical_info.get(
"voxel_size_um", [1.0, 1.0, 1.0]
)

# Populate microlens array configuration
optic_config.mla_config.n_pixels_per_mla = optical_info.get("pixels_per_ml", 1)
optic_config.mla_config.n_micro_lenses = optical_info.get("n_micro_lenses", 1)

# Populate PSF configuration
optic_config.PSF_config.NA = optical_info.get("na_obj", 1.0)
optic_config.PSF_config.ni = optical_info.get("n_medium", 1.0)
optic_config.PSF_config.wvl = optical_info.get("wavelength", 0.550)

# Populate polarizer and analyzer if they exist
if "polarizer" in optical_info:
optic_config.pol_config.polarizer = optical_info["polarizer"]
if "analyzer" in optical_info:
optic_config.pol_config.analyzer = optical_info["analyzer"]

if DEBUG and "polarizer" not in optical_info and "analyzer" not in optical_info:
print(
"Warning: polarizer and analyzer not found in optical_info. "
+ "This could be problematic if simulating with intensity images."
)
return optic_config


###########################################################################################
class RayTraceLFM(OpticalElement):
Expand All @@ -195,7 +200,7 @@ def __init__(
},
):
# Initialize the OpticalElement class
super(RayTraceLFM, self).__init__(
super().__init__(
backend=backend, torch_args=torch_args, optical_info=optical_info
)

Expand Down
2 changes: 1 addition & 1 deletion src/VolumeRaytraceLFM/birefringence_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class BirefringentElement(OpticalElement):
def __init__(
self, backend: BackEnds = BackEnds.NUMPY, torch_args={}, optical_info=None
):
super(BirefringentElement, self).__init__(
super().__init__(
backend=backend, torch_args=torch_args, optical_info=optical_info
)
self.simul_type = SimulType.BIREFRINGENT
Loading

0 comments on commit cbda367

Please sign in to comment.