Skip to content

Commit

Permalink
Merge pull request #108 from PolarizedLightFieldMicroscopy/nerf
Browse files Browse the repository at this point in the history
Nerf
  • Loading branch information
gschlafly authored Sep 21, 2024
2 parents 29e7bf4 + 6a61413 commit faa1609
Show file tree
Hide file tree
Showing 18 changed files with 926 additions and 138 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ playground/*
data/*
*.pkl
src/bir_tomo.egg-info/*
assets/style.css
in_progress/*
9 changes: 6 additions & 3 deletions config/iter_config.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
{
"n_epochs": 20,
"n_epochs": 100,
"regularization_weight": 0.1,
"lr_birefringence": 1e-2,
"lr": 1e-3,
"lr_birefringence": 1e-3,
"lr_optic_axis": 1e-1,
"optimizer": "Nadam",
"datafidelity": "euler",
"regularization_fcns": [
["birefringence active L2", 0],
["birefringence active negative penalty", 0]
["birefringence active negative penalty", 0],
["birefringence mask", 1000]
],
"nerf_mode": false,
"from_simulation": true,
"mla_rays_at_once": true,
"two_optic_axis_components": true,
Expand Down
4 changes: 3 additions & 1 deletion config/iter_config_sphere.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"n_epochs": 200,
"regularization_weight": 0.5,
"lr": 1e-3,
"lr_birefringence": 1e-3,
"lr_optic_axis": 1e-1,
"bir_betas": [0.6, 0.9],
Expand All @@ -9,7 +10,8 @@
"datafidelity": "euler",
"regularization_fcns": [
["birefringence active L2", 1000],
["birefringence active negative penalty", 1000]
["birefringence active negative penalty", 1000],
["birefringence mask", 0]
],
"from_simulation": true,
"vox_indices_by_mla_idx_path": "",
Expand Down
24 changes: 24 additions & 0 deletions src/VolumeRaytraceLFM/abstract_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,30 @@ def safe_ravel_index(vox, microlens_offset, volume_shape):
assert x >= 0 and y >= 0 and z >= 0, "Negative index detected"
return RayTraceLFM.ravel_index((x, y, z), volume_shape)

@staticmethod
def unravel_index(idx, dims):
"""Convert an array of 1D indices to 3D indices.
TODO: avoid idx being replaced by a zero tensor"""
if isinstance(idx, torch.Tensor):
c = torch.cumprod(torch.tensor([1] + dims[::-1], dtype=idx.dtype), dim=0)[
:-1
].flip(0)
x = []
for factor in c:
x.append(idx // factor)
idx %= factor
idx_3d = torch.stack(x, dim=-1)
else:
# Ensure idx is a numpy array
idx = np.asarray(idx)
c = np.cumprod([1] + dims[::-1])[:-1][::-1]
x = []
for factor in c:
x.append(idx // factor)
idx %= factor
idx_3d = np.stack(x, axis=-1)
return idx_3d

@staticmethod
def rotation_matrix(axis, angle):
"""Generates the rotation matrix that will rotate a 3D vector
Expand Down
221 changes: 124 additions & 97 deletions src/VolumeRaytraceLFM/birefringence_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,25 @@
from collections import Counter
from VolumeRaytraceLFM.abstract_classes import *
from VolumeRaytraceLFM.birefringence_base import BirefringentElement
from VolumeRaytraceLFM.nerf import (
ImplicitRepresentationMLP,
ImplicitRepresentationMLPSpherical,
)
from VolumeRaytraceLFM.file_manager import VolumeFileManager
from VolumeRaytraceLFM.volumes.modification import (
pad_to_region_shape,
crop_to_region_shape,
)
from VolumeRaytraceLFM.volumes.generation import (
generate_single_voxel_volume,
generate_random_volume,
generate_planes_volume,
generate_ellipsoid_volume,
)
from VolumeRaytraceLFM.volumes.optic_axis import (
spherical_to_unit_vector_torch,
unit_vector_to_spherical,
)
from VolumeRaytraceLFM.jones.jones_calculus import (
JonesMatrixGenerators,
JonesVectorGenerators,
Expand All @@ -31,7 +49,6 @@


DEBUG = False

if DEBUG:
from VolumeRaytraceLFM.utils.error_handling import check_for_inf_or_nan
from utils import errors
Expand Down Expand Up @@ -523,80 +540,6 @@ def get_vox_params(self, vox_idx):
axis = self.optic_axis[:, vox_idx]
return self.Delta_n[vox_idx], axis

@staticmethod
def crop_to_region_shape(delta_n, optic_axis, volume_shape, region_shape):
"""
Parameters:
delta_n (np.array): 3D array with dimension volume_shape
optic_axis (np.array): 4D array with dimension (3, *volume_shape)
volume_shape (np.array): dimensions of object volume
region_shape (np.array): dimensions of the region fitting the object,
values must be greater than volume_shape
Returns:
cropped_delta_n (np.array): 3D array with dimension region_shape
cropped_optic_axis (np.array): 4D array with dimension (3, *region_shape)
"""
assert (
volume_shape >= region_shape
).all(), "Error: volume_shape must be greater than region_shape"
crop_start = (volume_shape - region_shape) // 2
crop_end = crop_start + region_shape
cropped_delta_n = delta_n[
crop_start[0] : crop_end[0],
crop_start[1] : crop_end[1],
crop_start[2] : crop_end[2],
]
cropped_optic_axis = optic_axis[
:,
crop_start[0] : crop_end[0],
crop_start[1] : crop_end[1],
crop_start[2] : crop_end[2],
]
return cropped_delta_n, cropped_optic_axis

@staticmethod
def pad_to_region_shape(delta_n, optic_axis, volume_shape, region_shape):
"""
Parameters:
delta_n (np.array): 3D array with dimension volume_shape
optic_axis (np.array): 4D array with dimension (3, *volume_shape)
volume_shape (np.array): dimensions of object volume
region_shape (np.array): dimensions of the region fitting the object,
values must be less than volume_shape
Returns:
padded_delta_n (np.array): 3D array with dimension region_shape
padded_optic_axis (np.array): 4D array with dimension (3, *region_shape)
"""
assert (
volume_shape <= region_shape
).all(), "Error: volume_shape must be less than region_shape"
z_, y_, x_ = region_shape
z, y, x = volume_shape
z_pad = abs(z_ - z)
y_pad = abs(y_ - y)
x_pad = abs(x_ - x)
padded_delta_n = np.pad(
delta_n,
(
(z_pad // 2, z_pad // 2 + z_pad % 2),
(y_pad // 2, y_pad // 2 + y_pad % 2),
(x_pad // 2, x_pad // 2 + x_pad % 2),
),
mode="constant",
).astype(np.float64)
padded_optic_axis = np.pad(
optic_axis,
(
(0, 0),
(z_pad // 2, z_pad // 2 + z_pad % 2),
(y_pad // 2, y_pad // 2 + y_pad % 2),
(x_pad // 2, x_pad // 2 + x_pad % 2),
),
mode="constant",
constant_values=np.sqrt(3),
).astype(np.float64)
return padded_delta_n, padded_optic_axis

@staticmethod
def init_from_file(h5_file_path, backend=BackEnds.NUMPY, optical_info=None):
"""Loads a birefringent volume from an h5 file and places it in the center of the volume.
Expand All @@ -611,11 +554,11 @@ def init_from_file(h5_file_path, backend=BackEnds.NUMPY, optical_info=None):
if (delta_n.shape == region_shape).all():
pass
elif (delta_n.shape >= region_shape).all():
delta_n, optic_axis = BirefringentVolume.crop_to_region_shape(
delta_n, optic_axis = crop_to_region_shape(
delta_n, optic_axis, delta_n.shape, region_shape
)
elif (delta_n.shape <= region_shape).all():
delta_n, optic_axis = BirefringentVolume.pad_to_region_shape(
delta_n, optic_axis = pad_to_region_shape(
delta_n, optic_axis, delta_n.shape, region_shape
)
else:
Expand Down Expand Up @@ -785,13 +728,9 @@ def _init_ellipsoid_or_shell(self, volume_shape, init_mode, init_args):
self._apply_shell_modification()

def _apply_shell_modification(self):
if self.backend == BackEnds.PYTORCH:
with torch.no_grad():
self.get_delta_n()[
: self.optical_info["volume_shape"][0] // 2 + 2, ...
] = 0
else:
self.get_delta_n()[: self.optical_info["volume_shape"][0] // 2 + 2, ...] = 0
self.voxel_parameters[0, ...][
: self.optical_info["volume_shape"][0] // 2 + 2, ...
] = 0

def _set_volume_ref(self):
volume_ref = BirefringentVolume(
Expand Down Expand Up @@ -1148,7 +1087,6 @@ def create_dummy_volume(
"init_args": sphere_args,
},
)
# elif 'my_volume:' # Feel free to add new volumes here
else:
raise NotImplementedError
return volume
Expand Down Expand Up @@ -1199,6 +1137,46 @@ def __init__(
"Stacking": 0,
}
self.check_errors = False
self.use_nerf = False
self.inr_model = None

def initialize_nerf_mode(self, use_nerf=True):
"""Initialize the NeRF mode based on the user's preference.
Args:
use_nerf (bool): Flag to enable or disable NeRF mode. Default is True.
"""
self.use_nerf = use_nerf
if self.use_nerf:
self.inr_model = ImplicitRepresentationMLP(3, 4, [256, 128, 64])
# self.inr_model = ImplicitRepresentationMLP(3, 4, [256, 256, 256, 256, 256])
self.inr_model = ImplicitRepresentationMLPSpherical(3, 3, [256, 256, 256])
self.inr_model = torch.nn.DataParallel(self.inr_model)
print("NeRF mode initialized.")
else:
self.inr_model = None
print("NeRF mode is disabled.")

def save_nerf_model(self, filepath):
"""Save the NeRF model to a file."""
if self.use_nerf:
torch.save(self.inr_model.state_dict(), filepath)
print(f"Saved the NeRF model to {filepath}")
else:
print("NERF is not enabled, no model to save.")

def load_nerf_model(self, filepath, eval_mode=False):
"""Load the NeRF model from a file.
Args:
filepath (str): Path to the saved model file.
eval_mode (bool): Whether to set the model to evaluation mode. Default is False.
"""
if self.use_nerf:
self.inr_model.load_state_dict(torch.load(filepath))
if eval_mode:
self.inr_model.eval() # Set the model to evaluation mode if needed
print(f"Loaded the NeRF model from {filepath}")
else:
print("NERF is not enabled, no model to load.")

def __str__(self):
info = [
Expand Down Expand Up @@ -1313,18 +1291,11 @@ def reset_timing_info(self):

def to_device(self, device):
"""Move the BirefringentRaytraceLFM to a device"""
# self.ray_valid_indices = self.ray_valid_indices.to(device)
## The following is needed for retrieving the voxel parameters
# self.volume.active_idx2spatial_idx_tensor.to(device)
self.ray_valid_indices = self.ray_valid_indices.to(device)
self.ray_direction_basis = self.ray_direction_basis.to(device)
self.ray_vol_colli_lengths = self.ray_vol_colli_lengths.to(device)
err_msg = "Moving a BirefringentRaytraceLFM instance to a device has not been implemented yet."
raise_error = False
if raise_error:
raise NotImplementedError(err_msg)
else:
print("Note: ", err_msg)
if self.use_nerf:
self.inr_model = self.inr_model.to(device)

def get_volume_reachable_region(self):
"""Returns a binary mask where the MLA's can reach into the volume"""
Expand Down Expand Up @@ -1922,9 +1893,14 @@ def calc_cummulative_JM_of_ray_torch(
try:
start_time_gather_params = time.perf_counter()
# Extract the birefringence and optic axis information from the volume
Delta_n, opticAxis = self.retrieve_properties_from_vox_idx(
volume_in, voxels_of_segs_tensor.long(), active_props_only=alt_props
)
if self.use_nerf:
Delta_n, opticAxis = self.retrieve_properties_from_vox_idx_mlp(
volume_in, voxels_of_segs_tensor.long()
)
else:
Delta_n, opticAxis = self.retrieve_properties_from_vox_idx(
volume_in, voxels_of_segs_tensor.long(), active_props_only=alt_props
)
end_time_gather_params = time.perf_counter()
self.times["gather_params_for_voxRayJM"] += (
end_time_gather_params - start_time_gather_params
Expand Down Expand Up @@ -2016,6 +1992,55 @@ def retrieve_properties_from_vox_idx(

return Delta_n, opticAxis.permute(1, 0, 2)

def retrieve_properties_from_vox_idx_mlp(self, volume, vox):
"""Retrieves the birefringence and optic axis from the volume
based on the provided voxel indices using an MLP. This function
is used to retrieve the properties of the voxels that each ray
segment interacts with.
Args:
volume (BirefringentVolume): Birefringent volume object.
vox (torch.Tensor): Voxel indices in 1D.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Birefringence and optic axis.
"""
vol_shape = self.optical_info["volume_shape"]
filtered_vox = vox[self.mask[vox]]
vox_copy = filtered_vox.clone()
vox_3d = RayTraceLFM.unravel_index(vox_copy, vol_shape)
vox_3d_float = vox_3d.float().to(volume.Delta_n.device)

# Normalize the input coordinates based on volume shape
vol_shape_tensor = torch.tensor(
vol_shape, dtype=vox_3d_float.dtype, device=vox_3d_float.device
)
vox_3d_float = vox_3d_float / vol_shape_tensor

# Pass the input through the MLP
properties_at_3d_position = self.inr_model(vox_3d_float)

# Retrieve Delta_n and opticAxis from the MLP output
Delta_n_filtered = properties_at_3d_position[..., 0]
if properties_at_3d_position.shape[-1] == 3:
spherical_angles = properties_at_3d_position[..., 1:]
opticAxis_filtered = spherical_to_unit_vector_torch(spherical_angles)
else:
opticAxis_filtered = properties_at_3d_position[..., 1:]

# Initialize with zeros and fill in with the filtered values
Delta_n = torch.zeros(
vox.shape, dtype=Delta_n_filtered.dtype, device=Delta_n_filtered.device
)
opticAxis = torch.zeros(
(*vox.shape, 3),
dtype=opticAxis_filtered.dtype,
device=opticAxis_filtered.device,
)
Delta_n[self.mask[vox]] = Delta_n_filtered
opticAxis[self.mask[vox], :] = opticAxis_filtered
return Delta_n, opticAxis.permute(0, 2, 1)

def _get_default_jones(self):
"""Returns the default Jones Matrix for a ray that does not
interact with any voxels. This is the identity matrix.
Expand Down Expand Up @@ -2646,6 +2671,8 @@ def vox_ray_matrix(self, ret, azim):
jones = jones_matrix.calculate_jones_torch(
ret, azim, nonzeros_only=self.only_nonzero_for_jones
)
# self.times["Diag-Offdiag"] = 0
# self.times["Stacking"] = 0
if DEBUG:
assert not torch.isnan(
jones
Expand Down
Loading

0 comments on commit faa1609

Please sign in to comment.