Skip to content

Commit

Permalink
Merge pull request #68 from PolarizedLightFieldMicroscopy/zero-pixels
Browse files Browse the repository at this point in the history
Refactor birefringence implementations module
  • Loading branch information
gschlafly authored Dec 8, 2023
2 parents 67c54ca + 403cc74 commit 3ccf271
Show file tree
Hide file tree
Showing 21 changed files with 996 additions and 486 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
*.pyc
*.pyc
reconstructions/*
14 changes: 14 additions & 0 deletions VolumeRaytraceLFM/birefringence_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from VolumeRaytraceLFM.abstract_classes import (
BackEnds, OpticalElement, SimulType
)

class BirefringentElement(OpticalElement):
''' Birefringent element, such as voxel, raytracer, etc,
extending optical element, so it has a back-end and optical information'''
def __init__(self, backend : BackEnds = BackEnds.NUMPY, torch_args={},
optical_info=None):
super(BirefringentElement, self).__init__(backend=backend,
torch_args=torch_args,
optical_info=optical_info
)
self.simul_type = SimulType.BIREFRINGENT
806 changes: 342 additions & 464 deletions VolumeRaytraceLFM/birefringence_implementations.py

Large diffs are not rendered by default.

135 changes: 135 additions & 0 deletions VolumeRaytraceLFM/file_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import numpy as np
import h5py
import tifffile

class VolumeFileManager:
def __init__(self):
"""Initializes the VolumeFileManager class."""
pass

def extract_data_from_h5(self, file_path):
"""
Extracts birefringence (delta_n) and optic axis data from an H5 file.
Args:
- file_path (str): Path to the H5 file from which data is to be extracted.
Returns:
- tuple: A tuple containing numpy arrays for delta_n and optic_axis.
"""
volume_file = h5py.File(file_path, "r")
delta_n = np.array(volume_file['data/delta_n'])
optic_axis = np.array(volume_file['data/optic_axis'])

return delta_n, optic_axis

def extract_all_data_from_h5(self, file_path):
"""
Extracts birefringence (delta_n), optic axis data, and optical information from an H5 file.
Args:
- file_path (str): Path to the H5 file from which data is to be extracted.
Returns:
- tuple: A tuple containing numpy arrays for
delta_n, optic_axis, volume_shape, and voxel_size_um.
"""
volume_file = h5py.File(file_path, "r")

# Fetch birefringence and optic axis
delta_n = np.array(volume_file['data/delta_n'])
optic_axis = np.array(volume_file['data/optic_axis'])

# Fetch optical info
volume_shape = np.array(volume_file['optical_info/volume_shape'])
voxel_size_um = np.array(volume_file['optical_info/voxel_size_um'])

return delta_n, optic_axis, volume_shape, voxel_size_um

def save_as_channel_stack_tiff(self, filename, delta_n, optic_axis):
"""
Saves the provided volume data as a multi-channel TIFF file.
Args:
- filename (str): The file path where the TIFF file will be saved.
- delta_n (np.ndarray): Numpy array containing the birefringence information of the volume.
- optic_axis (np.ndarray): Numpy array containing the optic axis data of the volume.
The method combines delta_n and optic_axis data into a single multi-channel array
and saves it as a TIFF file. Exceptions related to file operations are caught and logged.
"""
try:
print(f'Saving volume to file: {filename}')
combined_data = np.stack([delta_n, optic_axis[0], optic_axis[1], optic_axis[2]], axis=0)
tifffile.imwrite(filename, combined_data)
print('Volume saved successfully.')
except Exception as e:
print(f"Error saving file: {e}")

def save_as_h5(self, h5_file_path, delta_n, optic_axis, optical_info, description, optical_all):
"""
Saves the volume data, including birefringence information (delta_n) and optic axis data,
along with optical metadata into an H5 file.
The method creates an H5 file at the specified path and writes the provided data
to this file, organizing the data into appropriate groups and datasets within the file.
Args:
- h5_file_path (str): The file path where the H5 file will be saved.
- delta_n (np.ndarray): Numpy array containing the birefringence information of the volume.
- optic_axis (np.ndarray): Numpy array containing the optic axis data of the volume.
- optical_info (dict): Dictionary containing optical metadata about the volume. This may
include properties like volume shape, voxel size, etc.
- description (str): A brief description or note to be included in the optical information
of the H5 file. Useful for providing context or additional details about the data.
- optical_all (bool): A flag indicating whether to save all optical metadata present in
`optical_info` to the H5 file. If False, only specific predefined metadata (like volume
shape and voxel size) will be saved.
Returns:
None. The result of this method is the creation of an H5 file with the specified data.
"""
with h5py.File(h5_file_path, "w") as f:
self._save_optical_info(f, optical_info, description, optical_all)
self._save_data(f, delta_n, optic_axis)

def _save_optical_info(self, file_handle, optical_info, description, optical_all):
"""
Private method to save optical information to an H5 file.
Args:
- file_handle (File): An open H5 file handle.
- optical_info (dict): Dictionary containing optical metadata.
- description (str): Description to be included in the H5 file.
- optical_all (bool): Flag indicating whether to save all optical metadata.
This method creates a group for optical information and adds datasets to it.
"""
optics_grp = file_handle.create_group('optical_info')
optics_grp.create_dataset('description', data=np.string_(description))
# optics_grp.create_dataset('description', data=description)
if not optical_all:
vol_shape = optical_info.get('volume_shape', None)
voxel_size_um = optical_info.get('voxel_size_um', None)
if vol_shape is not None:
optics_grp.create_dataset('volume_shape', data=np.array(vol_shape))
if voxel_size_um is not None:
optics_grp.create_dataset('voxel_size_um', data=np.array(voxel_size_um))
else:
for k, v in optical_info.items():
optics_grp.create_dataset(k, data=np.array(v))

def _save_data(self, file_handle, delta_n, optic_axis):
"""
Private method to save delta_n and optic_axis data to an H5 file.
Args:
- file_handle (File): An open H5 file handle.
- delta_n (np.ndarray): Numpy array of delta_n data.
- optic_axis (np.ndarray): Numpy array of optic_axis data.
This method creates a group for volume data and adds datasets for delta_n and optic_axis.
"""
data_grp = file_handle.create_group('data')
data_grp.create_dataset("delta_n", delta_n.shape, data=delta_n)
data_grp.create_dataset("optic_axis", optic_axis.shape, data=optic_axis)
195 changes: 195 additions & 0 deletions VolumeRaytraceLFM/jones_calculus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
'''Jones Calculus Matrices and Vector Generators
Constructors for different types of elements.
These methods are constructors only. They don't support torch
optimization of internal variables.
'''
import numpy as np
import torch
from VolumeRaytraceLFM.abstract_classes import BackEnds
from VolumeRaytraceLFM.birefringence_base import BirefringentElement


class JonesMatrixGenerators(BirefringentElement):
'''2x2 Jones matrices representing various of polariztion elements'''

def __init__(self, backend : BackEnds = BackEnds.NUMPY):
super(BirefringentElement, self).__init__(backend=backend, torch_args={}, optical_info={})

@staticmethod
def rotator(angle, backend=BackEnds.NUMPY):
'''2D rotation matrix
Args:
angle: angle to rotate by counterclockwise [radians]
Return: Jones matrix'''
if backend == BackEnds.NUMPY:
s = np.sin(angle)
c = np.cos(angle)
R = np.array([[c, -s], [s, c]])
elif backend == BackEnds.PYTORCH:
s = torch.sin(angle)
c = torch.cos(angle)
R = torch.tensor([[c, -s], [s, c]])
return R

@staticmethod
def linear_retarder(ret, azim, backend=BackEnds.NUMPY):
'''Linear retarder
Args:
ret (float): retardance [radians]
azim (float): azimuth angle of fast axis [radians]
Return: Jones matrix
'''
retarder_azim0 = JonesMatrixGenerators.linear_retarder_azim0(ret, backend=backend)
R = JonesMatrixGenerators.rotator(azim, backend=backend)
Rinv = JonesMatrixGenerators.rotator(-azim, backend=backend)
return R @ retarder_azim0 @ Rinv

@staticmethod
def linear_retarder_azim0(ret, backend=BackEnds.NUMPY):
'''todo'''
if backend == BackEnds.NUMPY:
return np.array([[np.exp(1j * ret / 2), 0], [0, np.exp(-1j * ret / 2)]])
else:
return torch.cat(
(torch.cat((torch.exp(1j * ret / 2).unsqueeze(1), torch.zeros(len(ret),1)),1).unsqueeze(2),
torch.cat((torch.zeros(len(ret),1), torch.exp(-1j * ret / 2).unsqueeze(1)),1).unsqueeze(2)),
2
)

@staticmethod
def linear_retarter_azim90(ret, backend=BackEnds.NUMPY):
'''Linear retarder, convention not establisted yet'''
# TODO: using same convention as linear_retarder_azim0
if backend == BackEnds.NUMPY:
return np.array([[np.exp(1j * ret / 2), 0], [0, np.exp(-1j * ret / 2)]])
else:
return torch.tensor([torch.exp(1j * ret / 2), 0], [0, torch.exp(-1j * ret / 2)])

@staticmethod
def quarter_waveplate(azim):
'''Quarter Waveplate
Linear retarder with lambda/4 or equiv pi/2 radians
Commonly used to convert linear polarized light to circularly polarized light'''
ret = np.pi / 2
return JonesMatrixGenerators.linear_retarder(ret, azim)

@staticmethod
def half_waveplate(azim):
'''Half Waveplate
Linear retarder with lambda/2 or equiv pi radians
Commonly used to rotate the plane of linear polarization'''
# Faster method
s = np.sin(2 * azim)
c = np.cos(2 * azim)
# # Alternative method
# ret = np.pi
# JM = self.LR(ret, azim)
return np.array([[c, s], [s, -c]])

@staticmethod
def linear_polarizer(theta):
'''Linear Polarizer
Args:
theta: angle that light can pass through
Returns: Jones matrix
'''
c = np.cos(theta)
s = np.sin(theta)
J00 = c ** 2
J11 = s ** 2
J01 = s * c
J10 = J01
return np.array([[J00, J01], [J10, J11]])

@staticmethod
def right_circular_polarizer():
'''Right Circular Polarizer'''
return 1 / 2 * np.array([[1, -1j], [1j, 1]])

@staticmethod
def left_circular_polarizer():
'''Left Circular Polarizer'''
return 1 / 2 * np.array([[1, 1j], [-1j, 1]])
@staticmethod
def right_circular_retarder(ret):
'''Right Circular Retarder'''
return JonesMatrixGenerators.rotator(-ret / 2)
@staticmethod
def left_circular_retarder(ret):
'''Left Circular Retarder'''
return JonesMatrixGenerators.rotator(ret / 2)

@staticmethod
def polscope_analyzer():
'''Acts as a circular polarizer
Inhomogeneous elements because eigenvectors are linear (-45 deg) and
(right) circular polarization states
Source: 2010 Polarized Light pg. 224'''
return 1 / (2 * np.sqrt(2)) * np.array([[1 + 1j, 1 - 1j], [1 + 1j, 1 - 1j]])

@staticmethod
def universal_compensator(retA, retB):
'''Universal Polarizer
Used as the polarizer for the LC-PolScope'''
LP = JonesMatrixGenerators.linear_polarizer(0)
LCA = JonesMatrixGenerators.linear_retarder(retA, -np.pi / 4)
LCB = JonesMatrixGenerators.linear_retarder_azim0(retB)
return LCB @ LCA @ LP

@staticmethod
def universal_compensator_modes(setting=0, swing=0):
'''Settings for the LC-PolScope polarizer
Parameters:
setting (int): LC-PolScope setting number between 0 and 4
swing (float): proportion of wavelength, for ex 0.03
Returns:
Jones matrix'''
swing_rad = swing * 2 * np.pi
if setting == 0:
retA = np.pi / 2
retB = np.pi
elif setting == 1:
retA = np.pi / 2 + swing_rad
retB = np.pi
elif setting == 2:
retA = np.pi / 2
retB = np.pi + swing_rad
elif setting == 3:
retA = np.pi / 2
retB = np.pi - swing_rad
elif setting == 4:
retA = np.pi / 2 - swing_rad
retB = np.pi
return JonesMatrixGenerators.universal_compensator(retA, retB)


class JonesVectorGenerators(BirefringentElement):
'''2x1 Jones vectors representing various states of polarized light'''
def __init__(self, backend : BackEnds = BackEnds.NUMPY):
super(BirefringentElement, self).__init__(backend=backend, torch_args={}, optical_info={})

@staticmethod
def right_circular():
'''Right circularly polarized light'''
return np.array([1, -1j]) / np.sqrt(2)

@staticmethod
def left_circular():
'''Left circularly polarized light'''
return np.array([1, 1j]) / np.sqrt(2)

@staticmethod
def linear(angle):
'''Linearlly polarized light at an angle in radians'''
return JonesMatrixGenerators.rotator(angle) @ np.array([1, 0])

@staticmethod
def horizonal():
'''Horizontally polarized light'''
return np.array([1, 0])

@staticmethod
def vertical():
'''Vertically polarized light'''
return np.array([0, 1])
23 changes: 19 additions & 4 deletions VolumeRaytraceLFM/optic_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,41 @@
except:
pass

class OpticBlock(nn.Module): # pure virtual class
class OpticBlock(nn.Module):
"""Base class containing all the basic functionality of an optic block"""

def __init__(
self, optic_config=None, members_to_learn=None,
): # Contains a list of members which should be optimized (In case none are provided members are created without gradients)
):
"""
Initialize the OpticBlock.
Args:
optic_config (optional): Configuration for the optic block. Defaults to None.
members_to_learn (optional): List of members to be optimized. Defaults to None.
"""
super(OpticBlock, self).__init__()
self.optic_config = optic_config
self.members_to_learn = [] if members_to_learn is None else members_to_learn
self.device_dummy = nn.Parameter(torch.tensor([1.0]))


def get_trainable_variables(self):
"""
Get the trainable variables of the optic block.
Returns:
list: List of trainable variables.
"""
trainable_vars = []
for name, param in self.named_parameters():
if name in self.members_to_learn:
trainable_vars.append(param)
return list(trainable_vars)

def get_device(self):
"""
Get the device of the optic block.
Returns:
torch.device: The device of the optic block.
"""
return self.device_dummy.device


Expand Down
Loading

0 comments on commit 3ccf271

Please sign in to comment.