diff --git a/.gitignore b/.gitignore index 2d10d1c..4d3b05d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ playground/* data/* *.pkl src/bir_tomo.egg-info/* +assets/style.css +in_progress/* diff --git a/config/iter_config.json b/config/iter_config.json index 7b4181c..40e7da8 100644 --- a/config/iter_config.json +++ b/config/iter_config.json @@ -1,14 +1,17 @@ { - "n_epochs": 20, + "n_epochs": 100, "regularization_weight": 0.1, - "lr_birefringence": 1e-2, + "lr": 1e-3, + "lr_birefringence": 1e-3, "lr_optic_axis": 1e-1, "optimizer": "Nadam", "datafidelity": "euler", "regularization_fcns": [ ["birefringence active L2", 0], - ["birefringence active negative penalty", 0] + ["birefringence active negative penalty", 0], + ["birefringence mask", 1000] ], + "nerf_mode": false, "from_simulation": true, "mla_rays_at_once": true, "two_optic_axis_components": true, diff --git a/config/iter_config_sphere.json b/config/iter_config_sphere.json index b3b07bd..bb1369c 100644 --- a/config/iter_config_sphere.json +++ b/config/iter_config_sphere.json @@ -1,6 +1,7 @@ { "n_epochs": 200, "regularization_weight": 0.5, + "lr": 1e-3, "lr_birefringence": 1e-3, "lr_optic_axis": 1e-1, "bir_betas": [0.6, 0.9], @@ -9,7 +10,8 @@ "datafidelity": "euler", "regularization_fcns": [ ["birefringence active L2", 1000], - ["birefringence active negative penalty", 1000] + ["birefringence active negative penalty", 1000], + ["birefringence mask", 0] ], "from_simulation": true, "vox_indices_by_mla_idx_path": "", diff --git a/src/VolumeRaytraceLFM/abstract_classes.py b/src/VolumeRaytraceLFM/abstract_classes.py index febd3cd..aa1ef65 100644 --- a/src/VolumeRaytraceLFM/abstract_classes.py +++ b/src/VolumeRaytraceLFM/abstract_classes.py @@ -257,6 +257,30 @@ def safe_ravel_index(vox, microlens_offset, volume_shape): assert x >= 0 and y >= 0 and z >= 0, "Negative index detected" return RayTraceLFM.ravel_index((x, y, z), volume_shape) + @staticmethod + def unravel_index(idx, dims): + """Convert an array of 1D indices to 3D indices. + TODO: avoid idx being replaced by a zero tensor""" + if isinstance(idx, torch.Tensor): + c = torch.cumprod(torch.tensor([1] + dims[::-1], dtype=idx.dtype), dim=0)[ + :-1 + ].flip(0) + x = [] + for factor in c: + x.append(idx // factor) + idx %= factor + idx_3d = torch.stack(x, dim=-1) + else: + # Ensure idx is a numpy array + idx = np.asarray(idx) + c = np.cumprod([1] + dims[::-1])[:-1][::-1] + x = [] + for factor in c: + x.append(idx // factor) + idx %= factor + idx_3d = np.stack(x, axis=-1) + return idx_3d + @staticmethod def rotation_matrix(axis, angle): """Generates the rotation matrix that will rotate a 3D vector diff --git a/src/VolumeRaytraceLFM/birefringence_implementations.py b/src/VolumeRaytraceLFM/birefringence_implementations.py index 04dd965..5eeedb0 100644 --- a/src/VolumeRaytraceLFM/birefringence_implementations.py +++ b/src/VolumeRaytraceLFM/birefringence_implementations.py @@ -9,7 +9,25 @@ from collections import Counter from VolumeRaytraceLFM.abstract_classes import * from VolumeRaytraceLFM.birefringence_base import BirefringentElement +from VolumeRaytraceLFM.nerf import ( + ImplicitRepresentationMLP, + ImplicitRepresentationMLPSpherical, +) from VolumeRaytraceLFM.file_manager import VolumeFileManager +from VolumeRaytraceLFM.volumes.modification import ( + pad_to_region_shape, + crop_to_region_shape, +) +from VolumeRaytraceLFM.volumes.generation import ( + generate_single_voxel_volume, + generate_random_volume, + generate_planes_volume, + generate_ellipsoid_volume, +) +from VolumeRaytraceLFM.volumes.optic_axis import ( + spherical_to_unit_vector_torch, + unit_vector_to_spherical, +) from VolumeRaytraceLFM.jones.jones_calculus import ( JonesMatrixGenerators, JonesVectorGenerators, @@ -31,7 +49,6 @@ DEBUG = False - if DEBUG: from VolumeRaytraceLFM.utils.error_handling import check_for_inf_or_nan from utils import errors @@ -523,80 +540,6 @@ def get_vox_params(self, vox_idx): axis = self.optic_axis[:, vox_idx] return self.Delta_n[vox_idx], axis - @staticmethod - def crop_to_region_shape(delta_n, optic_axis, volume_shape, region_shape): - """ - Parameters: - delta_n (np.array): 3D array with dimension volume_shape - optic_axis (np.array): 4D array with dimension (3, *volume_shape) - volume_shape (np.array): dimensions of object volume - region_shape (np.array): dimensions of the region fitting the object, - values must be greater than volume_shape - Returns: - cropped_delta_n (np.array): 3D array with dimension region_shape - cropped_optic_axis (np.array): 4D array with dimension (3, *region_shape) - """ - assert ( - volume_shape >= region_shape - ).all(), "Error: volume_shape must be greater than region_shape" - crop_start = (volume_shape - region_shape) // 2 - crop_end = crop_start + region_shape - cropped_delta_n = delta_n[ - crop_start[0] : crop_end[0], - crop_start[1] : crop_end[1], - crop_start[2] : crop_end[2], - ] - cropped_optic_axis = optic_axis[ - :, - crop_start[0] : crop_end[0], - crop_start[1] : crop_end[1], - crop_start[2] : crop_end[2], - ] - return cropped_delta_n, cropped_optic_axis - - @staticmethod - def pad_to_region_shape(delta_n, optic_axis, volume_shape, region_shape): - """ - Parameters: - delta_n (np.array): 3D array with dimension volume_shape - optic_axis (np.array): 4D array with dimension (3, *volume_shape) - volume_shape (np.array): dimensions of object volume - region_shape (np.array): dimensions of the region fitting the object, - values must be less than volume_shape - Returns: - padded_delta_n (np.array): 3D array with dimension region_shape - padded_optic_axis (np.array): 4D array with dimension (3, *region_shape) - """ - assert ( - volume_shape <= region_shape - ).all(), "Error: volume_shape must be less than region_shape" - z_, y_, x_ = region_shape - z, y, x = volume_shape - z_pad = abs(z_ - z) - y_pad = abs(y_ - y) - x_pad = abs(x_ - x) - padded_delta_n = np.pad( - delta_n, - ( - (z_pad // 2, z_pad // 2 + z_pad % 2), - (y_pad // 2, y_pad // 2 + y_pad % 2), - (x_pad // 2, x_pad // 2 + x_pad % 2), - ), - mode="constant", - ).astype(np.float64) - padded_optic_axis = np.pad( - optic_axis, - ( - (0, 0), - (z_pad // 2, z_pad // 2 + z_pad % 2), - (y_pad // 2, y_pad // 2 + y_pad % 2), - (x_pad // 2, x_pad // 2 + x_pad % 2), - ), - mode="constant", - constant_values=np.sqrt(3), - ).astype(np.float64) - return padded_delta_n, padded_optic_axis - @staticmethod def init_from_file(h5_file_path, backend=BackEnds.NUMPY, optical_info=None): """Loads a birefringent volume from an h5 file and places it in the center of the volume. @@ -611,11 +554,11 @@ def init_from_file(h5_file_path, backend=BackEnds.NUMPY, optical_info=None): if (delta_n.shape == region_shape).all(): pass elif (delta_n.shape >= region_shape).all(): - delta_n, optic_axis = BirefringentVolume.crop_to_region_shape( + delta_n, optic_axis = crop_to_region_shape( delta_n, optic_axis, delta_n.shape, region_shape ) elif (delta_n.shape <= region_shape).all(): - delta_n, optic_axis = BirefringentVolume.pad_to_region_shape( + delta_n, optic_axis = pad_to_region_shape( delta_n, optic_axis, delta_n.shape, region_shape ) else: @@ -785,13 +728,9 @@ def _init_ellipsoid_or_shell(self, volume_shape, init_mode, init_args): self._apply_shell_modification() def _apply_shell_modification(self): - if self.backend == BackEnds.PYTORCH: - with torch.no_grad(): - self.get_delta_n()[ - : self.optical_info["volume_shape"][0] // 2 + 2, ... - ] = 0 - else: - self.get_delta_n()[: self.optical_info["volume_shape"][0] // 2 + 2, ...] = 0 + self.voxel_parameters[0, ...][ + : self.optical_info["volume_shape"][0] // 2 + 2, ... + ] = 0 def _set_volume_ref(self): volume_ref = BirefringentVolume( @@ -1148,7 +1087,6 @@ def create_dummy_volume( "init_args": sphere_args, }, ) - # elif 'my_volume:' # Feel free to add new volumes here else: raise NotImplementedError return volume @@ -1199,6 +1137,46 @@ def __init__( "Stacking": 0, } self.check_errors = False + self.use_nerf = False + self.inr_model = None + + def initialize_nerf_mode(self, use_nerf=True): + """Initialize the NeRF mode based on the user's preference. + Args: + use_nerf (bool): Flag to enable or disable NeRF mode. Default is True. + """ + self.use_nerf = use_nerf + if self.use_nerf: + self.inr_model = ImplicitRepresentationMLP(3, 4, [256, 128, 64]) + # self.inr_model = ImplicitRepresentationMLP(3, 4, [256, 256, 256, 256, 256]) + self.inr_model = ImplicitRepresentationMLPSpherical(3, 3, [256, 256, 256]) + self.inr_model = torch.nn.DataParallel(self.inr_model) + print("NeRF mode initialized.") + else: + self.inr_model = None + print("NeRF mode is disabled.") + + def save_nerf_model(self, filepath): + """Save the NeRF model to a file.""" + if self.use_nerf: + torch.save(self.inr_model.state_dict(), filepath) + print(f"Saved the NeRF model to {filepath}") + else: + print("NERF is not enabled, no model to save.") + + def load_nerf_model(self, filepath, eval_mode=False): + """Load the NeRF model from a file. + Args: + filepath (str): Path to the saved model file. + eval_mode (bool): Whether to set the model to evaluation mode. Default is False. + """ + if self.use_nerf: + self.inr_model.load_state_dict(torch.load(filepath)) + if eval_mode: + self.inr_model.eval() # Set the model to evaluation mode if needed + print(f"Loaded the NeRF model from {filepath}") + else: + print("NERF is not enabled, no model to load.") def __str__(self): info = [ @@ -1313,18 +1291,11 @@ def reset_timing_info(self): def to_device(self, device): """Move the BirefringentRaytraceLFM to a device""" - # self.ray_valid_indices = self.ray_valid_indices.to(device) - ## The following is needed for retrieving the voxel parameters - # self.volume.active_idx2spatial_idx_tensor.to(device) self.ray_valid_indices = self.ray_valid_indices.to(device) self.ray_direction_basis = self.ray_direction_basis.to(device) self.ray_vol_colli_lengths = self.ray_vol_colli_lengths.to(device) - err_msg = "Moving a BirefringentRaytraceLFM instance to a device has not been implemented yet." - raise_error = False - if raise_error: - raise NotImplementedError(err_msg) - else: - print("Note: ", err_msg) + if self.use_nerf: + self.inr_model = self.inr_model.to(device) def get_volume_reachable_region(self): """Returns a binary mask where the MLA's can reach into the volume""" @@ -1922,9 +1893,14 @@ def calc_cummulative_JM_of_ray_torch( try: start_time_gather_params = time.perf_counter() # Extract the birefringence and optic axis information from the volume - Delta_n, opticAxis = self.retrieve_properties_from_vox_idx( - volume_in, voxels_of_segs_tensor.long(), active_props_only=alt_props - ) + if self.use_nerf: + Delta_n, opticAxis = self.retrieve_properties_from_vox_idx_mlp( + volume_in, voxels_of_segs_tensor.long() + ) + else: + Delta_n, opticAxis = self.retrieve_properties_from_vox_idx( + volume_in, voxels_of_segs_tensor.long(), active_props_only=alt_props + ) end_time_gather_params = time.perf_counter() self.times["gather_params_for_voxRayJM"] += ( end_time_gather_params - start_time_gather_params @@ -2016,6 +1992,55 @@ def retrieve_properties_from_vox_idx( return Delta_n, opticAxis.permute(1, 0, 2) + def retrieve_properties_from_vox_idx_mlp(self, volume, vox): + """Retrieves the birefringence and optic axis from the volume + based on the provided voxel indices using an MLP. This function + is used to retrieve the properties of the voxels that each ray + segment interacts with. + + Args: + volume (BirefringentVolume): Birefringent volume object. + vox (torch.Tensor): Voxel indices in 1D. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Birefringence and optic axis. + """ + vol_shape = self.optical_info["volume_shape"] + filtered_vox = vox[self.mask[vox]] + vox_copy = filtered_vox.clone() + vox_3d = RayTraceLFM.unravel_index(vox_copy, vol_shape) + vox_3d_float = vox_3d.float().to(volume.Delta_n.device) + + # Normalize the input coordinates based on volume shape + vol_shape_tensor = torch.tensor( + vol_shape, dtype=vox_3d_float.dtype, device=vox_3d_float.device + ) + vox_3d_float = vox_3d_float / vol_shape_tensor + + # Pass the input through the MLP + properties_at_3d_position = self.inr_model(vox_3d_float) + + # Retrieve Delta_n and opticAxis from the MLP output + Delta_n_filtered = properties_at_3d_position[..., 0] + if properties_at_3d_position.shape[-1] == 3: + spherical_angles = properties_at_3d_position[..., 1:] + opticAxis_filtered = spherical_to_unit_vector_torch(spherical_angles) + else: + opticAxis_filtered = properties_at_3d_position[..., 1:] + + # Initialize with zeros and fill in with the filtered values + Delta_n = torch.zeros( + vox.shape, dtype=Delta_n_filtered.dtype, device=Delta_n_filtered.device + ) + opticAxis = torch.zeros( + (*vox.shape, 3), + dtype=opticAxis_filtered.dtype, + device=opticAxis_filtered.device, + ) + Delta_n[self.mask[vox]] = Delta_n_filtered + opticAxis[self.mask[vox], :] = opticAxis_filtered + return Delta_n, opticAxis.permute(0, 2, 1) + def _get_default_jones(self): """Returns the default Jones Matrix for a ray that does not interact with any voxels. This is the identity matrix. @@ -2646,6 +2671,8 @@ def vox_ray_matrix(self, ret, azim): jones = jones_matrix.calculate_jones_torch( ret, azim, nonzeros_only=self.only_nonzero_for_jones ) + # self.times["Diag-Offdiag"] = 0 + # self.times["Stacking"] = 0 if DEBUG: assert not torch.isnan( jones diff --git a/src/VolumeRaytraceLFM/jones/jones_matrix.py b/src/VolumeRaytraceLFM/jones/jones_matrix.py index fa30ecd..1edb9ff 100644 --- a/src/VolumeRaytraceLFM/jones/jones_matrix.py +++ b/src/VolumeRaytraceLFM/jones/jones_matrix.py @@ -32,6 +32,20 @@ def print_ret_azim_numpy(ret, azim): def vox_ray_ret_azim_torch(bir, optic_axis, rayDir, ell, wavelength): + """Calculate the retardance and azimuth angle for a given + birefringence, optic axis, and ray direction. + + Args: + bir (torch.Tensor): Birefringence values. Shape: [num_voxels, intersection_rows] + optic_axis (torch.Tensor): Optic axis vectors. Shape: [num_voxels, 3, intersection_rows] + rayDir (torch.Tensor): Ray direction vectors. Shape: [3, num_voxels, 3] + ell (torch.Tensor): Path lengths. Shape: [num_voxels, intersection_rows] + wavelength (float): Wavelength of light. + + Returns: + ret (torch.Tensor): Retardance values. Shape: [num_voxels, intersection_rows] + azim (torch.Tensor): Azimuth angles. Shape: [num_voxels, intersection_rows] + """ pi_tensor = torch.tensor(np.pi, device=bir.device, dtype=bir.dtype) # Dot product of optical axis and 3 ray-direction vectors OA_dot_rayDir = (rayDir.unsqueeze(2) @ optic_axis).squeeze(2) @@ -60,6 +74,7 @@ def normalized_projection_torch(optic_axis, rayDir): def calculate_vox_ray_ret_azim_torch( bir, optic_axis, rayDir, ell, wavelength, nonzeros_only=False ): + # TODO: update the nonzero_only version now that bir is not 1D if nonzeros_only: # Faster when the number of non-zero elements is large nonzero_indices = bir.nonzero() diff --git a/src/VolumeRaytraceLFM/metrics/metric.py b/src/VolumeRaytraceLFM/metrics/metric.py index 0e0b57d..d5c2c5d 100644 --- a/src/VolumeRaytraceLFM/metrics/metric.py +++ b/src/VolumeRaytraceLFM/metrics/metric.py @@ -16,6 +16,10 @@ neg_penalty_bir_active, pos_penalty_bir_active, pos_penalty_l2_bir_active, + masked_zero_loss, + l2_biref, + pos_penalty_biref, + pos_penalty_l2_biref, ) @@ -29,6 +33,10 @@ "birefringence active negative penalty": neg_penalty_bir_active, "birefringence active positive penalty": pos_penalty_bir_active, "birefringence active positive penalty L2": pos_penalty_l2_bir_active, + "birefringence mask": masked_zero_loss, + "biref L2": l2_biref, + "biref positive penalty": pos_penalty_biref, + "biref positive penalty L2": pos_penalty_l2_biref, } @@ -57,6 +65,7 @@ def __init__(self, params=None, json_file=None): self.optimizer = "Adam" self.datafidelity = "vector" self.regularization_fcns = [] + self.mask = None def set_retardance_target(self, target): self.target_retardance = target @@ -209,7 +218,10 @@ def compute_regularization_term(self, data): # Sum up the rest of the regularization terms if any for reg_fcn, weight in self.regularization_fcns[1:]: - term_value = weight * reg_fcn(data) * 1000 + if reg_fcn == masked_zero_loss: + term_value = weight * reg_fcn(data, self.mask) * 1000 + else: + term_value = weight * reg_fcn(data) * 1000 term_values.append(term_value) regularization_loss += term_value diff --git a/src/VolumeRaytraceLFM/metrics/regularization.py b/src/VolumeRaytraceLFM/metrics/regularization.py index 1739974..daa4a5f 100644 --- a/src/VolumeRaytraceLFM/metrics/regularization.py +++ b/src/VolumeRaytraceLFM/metrics/regularization.py @@ -53,6 +53,45 @@ def total_variation_bir_subset(volume: BirefringentVolume): return total_variation(birefringence) +def masked_zero_loss(volume: BirefringentVolume, mask: torch.Tensor): + """Compute the loss enforcing certain positions in the volume's prediction to + be zero based on a mask. + + Args: + - volume (BirefringentVolume): The volume object containing the predicted tensor. + - mask (torch.Tensor): A binary mask tensor where positions to be zeroed are marked with 1. + + Returns: + - torch.Tensor: The computed loss. + """ + # Ensure the mask is a binary tensor + mask = mask.float() + + # Invert the mask to get positions that should be zero + inverted_mask = 1 - mask + + # Apply the inverted mask to the volume's birefringence + masked_birefringence = volume.birefringence.flatten() * inverted_mask + + # Compute the loss as the mean squared error between the masked + # birefringence and a zero tensor + zero_tensor = torch.zeros_like(volume.birefringence.flatten()) + loss = F.mse_loss(masked_birefringence, zero_tensor) + return loss + + +def l2_biref(volume: BirefringentVolume): + return l2(volume.birefringence) + + +def pos_penalty_biref(volume: BirefringentVolume): + return positive_penalty(volume.birefringence) + + +def pos_penalty_l2_biref(volume: BirefringentVolume): + return positive_penalty_l2(volume.birefringence) + + class AnisotropyAnalysis: def __init__(self, volume: BirefringentVolume): self.volume = volume diff --git a/src/VolumeRaytraceLFM/my_siddon.py b/src/VolumeRaytraceLFM/my_siddon.py index 5814376..f3b450e 100644 --- a/src/VolumeRaytraceLFM/my_siddon.py +++ b/src/VolumeRaytraceLFM/my_siddon.py @@ -96,18 +96,18 @@ def siddon_midpoints(start, stop, a_list): def vox_indices(midpoints, vox_pitch): - """Identifies the voxels for which the midpoints belong by converting to + """Identifies the voxels for which the midpoints belong by converting to voxel units, then rounding down to get the voxel index used we are using to refer to the voxel""" dx, dy, dz = vox_pitch i_voxels = [] - for (x, y, z) in midpoints: + for x, y, z in midpoints: i_voxels.append((int(x / dx), int(y / dy), int(z / dz))) return i_voxels def siddon_lengths(start, stop, a_list): - """Finds length of intersections by multiplying difference in parametric + """Finds length of intersections by multiplying difference in parametric values by entire ray length""" entire_length = np.linalg.norm(stop - start) lengths = [] diff --git a/src/VolumeRaytraceLFM/nerf.py b/src/VolumeRaytraceLFM/nerf.py new file mode 100644 index 0000000..85ce9d2 --- /dev/null +++ b/src/VolumeRaytraceLFM/nerf.py @@ -0,0 +1,269 @@ +"""This script defines a PyTorch-based Implicit Neural Representation (INR) using a +Multi-Layer Perceptron (MLP) with custom sine activations and weight initialization. +The INR represents a continuous function mapping input coordinates to output properties. + +Classes: +- Sine: Custom sine activation function for use in the INR. +- ImplicitRepresentationMLP: MLP that acts as an implicit neural representation. + +Functions: +- sine_init: Custom weight initialization for sine activations. +- setup_optimizer_nerf: Sets up the optimizer for the neural network model. +- generate_voxel_grid: Generates a grid of voxel coordinates for a given volume shape. +- predict_voxel_properties: Predicts properties for each voxel in the grid using the given model. +- get_model_device: Returns the device of the parameters of the model. + +Example usage: +- Initialize the ImplicitRepresentationMLP with specified input and output dimensions. +- Generate voxel coordinates and predict properties using the model. +""" + +import torch +import torch.nn as nn +import math + + +class Sine(nn.Module): + def forward(self, x): + return torch.sin(x) + + +def sine_init(m): + with torch.no_grad(): + if isinstance(m, nn.Linear): + num_input = m.weight.size(-1) + std = math.sqrt(6 / num_input) + m.weight.uniform_(-std, std) + if m.bias is not None: + m.bias.uniform_(-std, std) + + +class ImplicitRepresentationMLP(nn.Module): + """Multi-Layer Perceptron (MLP) that acts as an + Implicit Neural Representation. + + Args: + input_dim (int): Dimensionality of the input. + output_dim (int): Dimensionality of the output. + hidden_layers (list): List of integers defining the number of + neurons in each hidden layer. + num_frequencies (int): Number of frequencies for positional encoding. + """ + + def __init__( + self, input_dim, output_dim, hidden_layers=[128, 64], num_frequencies=10 + ): + super(ImplicitRepresentationMLP, self).__init__() + self.num_frequencies = num_frequencies + self.input_dim = input_dim + self.output_dim = output_dim + + layers = [] + in_dim = input_dim * (2 * num_frequencies + 1) + for h in hidden_layers: + layers.append(nn.Linear(in_dim, h)) + layers.append(Sine()) # Using Sine activation for INR + in_dim = h + layers.append(nn.Linear(in_dim, output_dim)) + self.layers = nn.Sequential(*layers) + self._initialize_weights() + + def _initialize_weights(self): + """Initialize weights of the network using custom sine initialization.""" + for m in self.modules(): + if isinstance(m, nn.Linear): + sine_init(m) # Sine initialization for hidden layers + self._initialize_output_layer() + + def _initialize_output_layer(self): + """Initialize the weights of the output layer.""" + final_layer = self.layers[-1] + with torch.no_grad(): + final_layer.weight.uniform_(-0.01, 0.01) + final_layer.bias[0] = 0.05 # First output dimension fixed to 0.05 + final_layer.bias[1:].uniform_(-0.5, 0.5) # Initializing other biases + + def positional_encoding(self, x: torch.Tensor) -> torch.Tensor: + """Apply positional encoding to the input tensor. + Each element of x is multiplied by each frequency, effectively + encoding the input in a higher-dimensional space. + Args: + x (torch.Tensor): Input tensor of shape (N, input_dim). + Returns: + torch.Tensor: Encoded tensor of shape (N, input_dim * (2 * num_frequencies + 1)). + """ + frequencies = torch.linspace( + 0, self.num_frequencies - 1, self.num_frequencies, device=x.device + ) + frequencies = 2.0**frequencies + x_expanded = x.unsqueeze(-1) * frequencies.unsqueeze(0).unsqueeze(0) + x_sin = torch.sin(x_expanded) + x_cos = torch.cos(x_expanded) + x_encoded = torch.cat([x.unsqueeze(-1), x_sin, x_cos], dim=-1) + return x_encoded.view(x.size(0), -1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the network. + Args: + x (torch.Tensor): Input tensor of shape (N, input_dim). + Returns: + torch.Tensor: Output tensor of shape (N, output_dim). + """ + x = self.positional_encoding(x) + x = self.layers(x) + # Scaling the outputs + x[:, 0] = torch.sigmoid(x[:, 0]) * 0.1 # First output dimension around 0.05 + x[:, 1] = torch.sigmoid(x[:, 1]) # Second output dimension between 0 and 1 + x[:, 2:4] = torch.tanh(x[:, 2:4]) # Last two dimensions between -1 and 1 + return x + + +class ImplicitRepresentationMLPSpherical(nn.Module): + """Multi-Layer Perceptron (MLP) that acts as an + Implicit Neural Representation. + + Args: + input_dim (int): Dimensionality of the input. + output_dim (int): Dimensionality of the output. + hidden_layers (list): List of integers defining the number of + neurons in each hidden layer. + num_frequencies (int): Number of frequencies for positional encoding. + """ + + def __init__( + self, input_dim, output_dim, hidden_layers=[128, 64], num_frequencies=10 + ): + super(ImplicitRepresentationMLPSpherical, self).__init__() + self.num_frequencies = num_frequencies + self.input_dim = input_dim + self.output_dim = output_dim + + layers = [] + in_dim = input_dim * (2 * num_frequencies + 1) + for h in hidden_layers: + layers.append(nn.Linear(in_dim, h)) + layers.append(Sine()) # Using Sine activation for INR + in_dim = h + layers.append(nn.Linear(in_dim, output_dim)) + self.layers = nn.Sequential(*layers) + self._initialize_weights() + + def _initialize_weights(self): + """Initialize weights of the network using custom sine initialization.""" + for m in self.modules(): + if isinstance(m, nn.Linear): + sine_init(m) # Sine initialization for hidden layers + self._initialize_output_layer() + + def _initialize_output_layer(self): + """Initialize the weights of the output layer.""" + final_layer = self.layers[-1] + with torch.no_grad(): + final_layer.weight.uniform_(-0.01, 0.01) + final_layer.bias[0] = 0.05 # First output dimension fixed to 0.05 + final_layer.bias[1:].uniform_(-0.5, 0.5) # Initializing other biases + + def positional_encoding(self, x: torch.Tensor) -> torch.Tensor: + """Apply positional encoding to the input tensor. + Each element of x is multiplied by each frequency, effectively + encoding the input in a higher-dimensional space. + Args: + x (torch.Tensor): Input tensor of shape (N, input_dim). + Returns: + torch.Tensor: Encoded tensor of shape (N, input_dim * (2 * num_frequencies + 1)). + """ + frequencies = 2.0**torch.arange(0, self.num_frequencies, device=x.device) + x_expanded = x.unsqueeze(-1) * frequencies.unsqueeze(0).unsqueeze(0) + x_sin = torch.sin(x_expanded) + x_cos = torch.cos(x_expanded) + x_encoded = torch.cat([x.unsqueeze(-1), x_sin, x_cos], dim=-1) + return x_encoded.view(x.size(0), -1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the network. + Args: + x (torch.Tensor): Input tensor of shape (N, input_dim). + Returns: + torch.Tensor: Output tensor of shape (N, output_dim). + """ + x = self.positional_encoding(x) + x = self.layers(x) + # x[:, 1] = x[:, 1] % (2 * torch.pi) # Second output dimension between 0 and 2pi + # x[:, 2] = x[:, 2] % (torch.pi / 2) # Third output dimension between 0 and pi/2 + # x_new[:, 1] = torch.atan2(torch.sin(x[:, 1]), torch.cos(x[:, 1])) # Azimuthal angle (phi) between -pi and pi + # x_new[:, 2] = torch.acos(torch.clamp(x[:, 2], -1.0, 1.0)) # Polar angle (theta) between 0 and pi + + # Scaling and constraining the outputs + x_new = x.clone() # Clone the tensor to avoid in-place operations + x_new[:, 0] = torch.sigmoid(x[:, 0]) * 0.1 # Density output in [0, 0.1] + x_new[:, 1] = torch.remainder(x[:, 1], 2 * torch.pi) # Angle in [0, 2π] + x_new[:, 2] = torch.remainder(x[:, 2], torch.pi / 2) # Angle in [0, π/2] + x = x_new + return x + + +def setup_optimizer_nerf( + model: nn.Module, training_params: dict +) -> torch.optim.Optimizer: + """Set up the optimizer for the neural network model. + + Args: + model (nn.Module): The neural network model. + training_params (dict): Dictionary containing training parameters such as learning rate. + Returns: + torch.optim.Optimizer: Configured optimizer. + """ + inr_params = model.inr_model.parameters() + parameters = [ + { + "params": inr_params, + "lr": training_params.get("lr", 0.001), + } + ] + # optimizer_class = getattr(torch.optim, training_params.get("optimizer", "NAdam")) + # optimizer = optimizer_class(parameters) + optimizer = torch.optim.NAdam(parameters) # , lr=0.001) + return optimizer + + +def generate_voxel_grid(vol_shape: tuple) -> torch.Tensor: + """Generate a grid of voxel coordinates for a given volume shape. + Args: + vol_shape (tuple): Shape of the volume (D, H, W). + Returns: + torch.Tensor: Tensor of shape (D*H*W, 3) containing voxel + coordinates. + """ + x = torch.linspace(0, vol_shape[0] - 1, vol_shape[0]) + y = torch.linspace(0, vol_shape[1] - 1, vol_shape[1]) + z = torch.linspace(0, vol_shape[2] - 1, vol_shape[2]) + grid = torch.meshgrid(x, y, z, indexing="ij") + coords = torch.stack(grid, dim=-1).reshape(-1, 3) + return coords + + +def predict_voxel_properties(model: nn.Module, vol_shape: tuple, enable_grad=False): + """Predict properties for each voxel in the grid using the given model. + Args: + model (nn.Module): The neural network model. + vol_shape (tuple): Shape of the volume (D, H, W). + Returns: + torch.Tensor: Predicted properties reshaped to the + volume shape (D, H, W, C). + """ + device = get_model_device(model) + coords = generate_voxel_grid(vol_shape).float().to(device) + vol_shape_tensor = torch.tensor(vol_shape, dtype=coords.dtype, device=device) + coords_normalized = coords / vol_shape_tensor # Normalize coordinates if necessary + if enable_grad: + coords_normalized.requires_grad_(True) + output = model(coords_normalized) + else: + with torch.no_grad(): + output = model(coords_normalized) # .cpu() + return output.reshape(*vol_shape, -1) + + +def get_model_device(model: nn.Module): + """Returns the device of the parameters of the model.""" + return next(model.parameters()).device diff --git a/src/VolumeRaytraceLFM/reconstructions.py b/src/VolumeRaytraceLFM/reconstructions.py index d1cf918..d7126f6 100644 --- a/src/VolumeRaytraceLFM/reconstructions.py +++ b/src/VolumeRaytraceLFM/reconstructions.py @@ -49,8 +49,10 @@ from VolumeRaytraceLFM.volumes.optic_axis import ( fill_vector_based_on_nonaxial, stay_on_sphere, + spherical_to_unit_vector_torch, ) from VolumeRaytraceLFM.utils.mask_utils import filter_voxels_using_retardance +from VolumeRaytraceLFM.nerf import setup_optimizer_nerf, predict_voxel_properties DEBUG = False @@ -306,11 +308,12 @@ def __init__( self.rays = self.setup_raytracer( image=image_for_rays, filepath=saved_ray_path, device=device ) - + self.nerf_mode = self.iteration_params.get("nerf_mode", False) + self.initialize_nerf_mode(use_nerf=self.nerf_mode) self.from_simulation = self.iteration_params.get("from_simulation", False) self.apply_volume_mask = apply_volume_mask self.mask = torch.ones( - self.volume_initial_guess.Delta_n.shape[0], dtype=torch.bool + self.volume_initial_guess.Delta_n.shape[0], dtype=torch.bool, device=device ) # Volume that will be updated after each iteration @@ -421,7 +424,8 @@ def to_device(self, device): if self.volume_ground_truth is not None: self.volume_ground_truth = self.volume_ground_truth.to(device) self.rays.to_device(device) - self.mask.to(device) + self.mask = self.mask.to(device) + self.volume_pred = self.volume_pred.to(device) def save_parameters(self, output_dir, volume_type): """In progress. @@ -473,6 +477,9 @@ def setup_raytracer(self, image=None, filepath=None, device="cpu"): print(f"Raytracing time in seconds: {time.time() - start_time:.2f}") return rays + def initialize_nerf_mode(self, use_nerf=True): + self.rays.initialize_nerf_mode(use_nerf) + def mask_outside_rays(self): """Mask out volume that is outside FOV of the microscope. Original shapes of the volume are preserved.""" @@ -707,6 +714,7 @@ def _compute_loss(self, images_predicted: list): LossFcn.set_retardance_target(retardance_meas) LossFcn.set_orientation_target(azimuth_meas) LossFcn.set_intensity_list_target(intensity_imgs_meas) + LossFcn.mask = self.mask data_term = LossFcn.compute_datafidelity_term( LossFcn.datafidelity, images_predicted ) @@ -766,6 +774,17 @@ def one_iteration(self, optimizer, volume_estimation, scheduler=None): self.volume_pred.optic_axis[:, self.volume_pred.indices_active] = ( self.volume_pred.optic_axis_active ) + if self.nerf_mode: + # Update Delta_n before loss is computed so the the mask regularization is applied + vol_shape = self.optical_info["volume_shape"] + predicted_properties = predict_voxel_properties( + self.rays.inr_model, vol_shape, enable_grad=True + ) + Delta_n = predicted_properties[..., 0] + # # Gradients are lost when setting Delta_n as a torch nn parameter + # self.volume_pred.Delta_n = torch.nn.Parameter(Delta_n.flatten()) + self.volume_pred.birefringence = Delta_n + loss, data_term, regularization_term = self._compute_loss(img_list) if self.rays.verbose: tqdm.write(f"Computed the loss: {loss.item():.5}") @@ -792,7 +811,11 @@ def one_iteration(self, optimizer, volume_estimation, scheduler=None): optimizer.step() scheduler.step(loss) adj_lrs_dict = calculate_adjusted_lr(optimizer) - adjusted_lrs = [val.item() for val in adj_lrs_dict.values()] + if self.nerf_mode: + adjusted_lrs = [0] + else: + adjusted_lrs = [val.item() for val in adj_lrs_dict.values()] + if PRINT_GRADIENTS: print_moments(optimizer) @@ -910,7 +933,20 @@ def visualize_and_save(self, ep, fig, output_dir): # TODO: only update every 1 epoch if plotting is live if ep % 1 == 0: # plt.clf() - Delta_n = volume_estimation.get_delta_n().detach().unsqueeze(0) + if self.nerf_mode: + vol_shape = self.optical_info["volume_shape"] + predicted_properties = predict_voxel_properties( + self.rays.inr_model, vol_shape + ) + Delta_n = predicted_properties[..., 0] + volume_estimation.Delta_n = torch.nn.Parameter(Delta_n.flatten()) + # TODO: see if mask should be applied here + volume_estimation.Delta_n = torch.nn.Parameter( + volume_estimation.Delta_n * self.mask + ) + Delta_n = volume_estimation.get_delta_n().detach().unsqueeze(0) + else: + Delta_n = volume_estimation.get_delta_n().detach().unsqueeze(0) mip_image = convert_volume_to_2d_mip(Delta_n) mip_image_np = prepare_plot_mip(mip_image, plot=False) plot_iteration_update_gridspec( @@ -941,11 +977,23 @@ def visualize_and_save(self, ep, fig, output_dir): volume_estimation.optic_axis = torch.nn.Parameter( torch.zeros(3, vol_size_flat), requires_grad=False ).to(device) - if self.volume_pred.indices_active is not None: - with torch.no_grad(): - volume_estimation.optic_axis[ - :, volume_estimation.indices_active - ] = volume_estimation.optic_axis_active + if self.nerf_mode: + optic_axis_flat = predicted_properties.view( + -1, predicted_properties.shape[-1] + )[..., 1:] + if predicted_properties.shape[-1] == 3: + optic_axis_flat = spherical_to_unit_vector_torch(optic_axis_flat) + volume_estimation.optic_axis = torch.nn.Parameter( + optic_axis_flat.permute(1, 0) + ) + nerf_model_path = os.path.join(output_dir, f"nerf_model_{ep}.pth") + self.rays.save_nerf_model(nerf_model_path) + else: + if self.volume_pred.indices_active is not None: + with torch.no_grad(): + volume_estimation.optic_axis[ + :, volume_estimation.indices_active + ] = volume_estimation.optic_axis_active my_description = "Volume estimation after " + str(ep) + " iterations." volume_estimation.save_as_file( os.path.join(output_dir, f"volume_ep_{'{:04d}'.format(ep)}.h5"), @@ -1011,8 +1059,12 @@ def save_loss_lists_to_csv(self): self.loss_reg_term_list, self.adjusted_lrs_list, ) - for total, data_term, reg_term, (optax_lr, bir_lr) in zipped_lists: - writer.writerow([total, data_term, reg_term, optax_lr, bir_lr]) + if self.nerf_mode: + for total, data_term, reg_term, lr in zipped_lists: + writer.writerow([total, data_term, reg_term, lr]) + else: + for total, data_term, reg_term, (optax_lr, bir_lr) in zipped_lists: + writer.writerow([total, data_term, reg_term, optax_lr, bir_lr]) def _create_regularization_terms_csv(self): """Create a csv file to store the regularization terms.""" @@ -1113,11 +1165,14 @@ def reconstruct( self.volume_pred.optic_axis.requires_grad = False self.specify_variables_to_learn(param_list) - optimizer = self.optimizer_setup(self.volume_pred, self.iteration_params) - optax_betas = self.iteration_params.get("optax_betas", (0.9, 0.999)) - bir_betas = self.iteration_params.get("bir_betas", (0.9, 0.999)) - optimizer.param_groups[0]["betas"] = tuple(optax_betas) - optimizer.param_groups[1]["betas"] = tuple(bir_betas) + if self.nerf_mode: + optimizer = setup_optimizer_nerf(self.rays, self.iteration_params) + else: + optimizer = self.optimizer_setup(self.volume_pred, self.iteration_params) + optax_betas = self.iteration_params.get("optax_betas", (0.9, 0.999)) + bir_betas = self.iteration_params.get("bir_betas", (0.9, 0.999)) + optimizer.param_groups[0]["betas"] = tuple(optax_betas) + optimizer.param_groups[1]["betas"] = tuple(bir_betas) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", @@ -1147,7 +1202,10 @@ def reconstruct( self.prepare_volume_for_recon(self.volume_pred) initial_lr_0 = optimizer.param_groups[0]["lr"] - initial_lr_1 = optimizer.param_groups[1]["lr"] + if self.nerf_mode: + initial_lr_1 = optimizer.param_groups[0]["lr"] + else: + initial_lr_1 = optimizer.param_groups[1]["lr"] # Parameters for learning rate warmup warmup_epochs = 10 @@ -1166,10 +1224,14 @@ def reconstruct( + (1 - warmup_start_proportion) * (ep / warmup_epochs) ) optimizer.param_groups[0]["lr"] = lr_0 - optimizer.param_groups[1]["lr"] = lr_1 + if not self.nerf_mode: + optimizer.param_groups[1]["lr"] = lr_1 else: current_lr_0 = scheduler.optimizer.param_groups[0]["lr"] - current_lr_1 = scheduler.optimizer.param_groups[1]["lr"] + if self.nerf_mode: + current_lr_1 = lr_0 + else: + current_lr_1 = scheduler.optimizer.param_groups[1]["lr"] if lr_0 != current_lr_0 or lr_1 != current_lr_1: print( f"Learning rates at iteration {ep - 1}: {lr_0:.2e}, {lr_1:.2e}" @@ -1229,3 +1291,7 @@ def reconstruct( print("Saved the final volume estimation to", vol_save_path) plt.savefig(os.path.join(self.recon_directory, "optim_final.pdf")) plt.close() + + if self.nerf_mode: + nerf_model_path = os.path.join(self.recon_directory, "nerf_model.pth") + self.rays.save_nerf_model(nerf_model_path) diff --git a/src/VolumeRaytraceLFM/volumes/generation.py b/src/VolumeRaytraceLFM/volumes/generation.py new file mode 100644 index 0000000..f24e7ea --- /dev/null +++ b/src/VolumeRaytraceLFM/volumes/generation.py @@ -0,0 +1,142 @@ +import numpy as np +from math import floor + + +def generate_single_voxel_volume( + volume_shape, delta_n=0.01, optic_axis=[1, 0, 0], offset=[0, 0, 0] +): + # Identity the center of the volume after the shifts + vox_idx = [ + volume_shape[0] // 2 + offset[0], + volume_shape[1] // 2 + offset[1], + volume_shape[2] // 2 + offset[2], + ] + vol = np.zeros((4, *volume_shape)) + vol[0, vox_idx[0], vox_idx[1], vox_idx[2]] = delta_n + vol[1:, vox_idx[0], vox_idx[1], vox_idx[2]] = np.array(optic_axis) + return vol + + +def generate_random_volume( + volume_shape, init_args={"Delta_n_range": [0, 1], "axes_range": [-1, 1]} +): + np.random.seed(42) + Delta_n = np.random.uniform( + init_args["Delta_n_range"][0], init_args["Delta_n_range"][1], volume_shape + ) + # Random axis + min_axis = init_args["axes_range"][0] + max_axis = init_args["axes_range"][1] + a_0 = np.random.uniform(min_axis, max_axis, volume_shape) + a_1 = np.random.uniform(min_axis, max_axis, volume_shape) + a_2 = np.random.uniform(min_axis, max_axis, volume_shape) + norm_A = np.sqrt(a_0**2 + a_1**2 + a_2**2) + return np.concatenate( + ( + np.expand_dims(Delta_n, axis=0), + np.expand_dims(a_0 / norm_A, axis=0), + np.expand_dims(a_1 / norm_A, axis=0), + np.expand_dims(a_2 / norm_A, axis=0), + ), + 0, + ) + + +def generate_planes_volume(volume_shape, n_planes=1, z_offset=0, delta_n=0.01): + vol = np.zeros((4, *volume_shape)) + z_size = volume_shape[0] + z_ranges = np.linspace(0, z_size - 1, n_planes * 2).astype(int) + + # Set random optic axis + optic_axis = np.random.uniform(-1, 1, [3, *volume_shape]) + norms = np.linalg.norm(optic_axis, axis=0) + vol[1:, ...] = optic_axis / norms + + if n_planes == 1: + # Birefringence + vol[0, z_size // 2 + z_offset, :, :] = delta_n + # Axis + vol[1, z_size // 2 + z_offset, :, :] = 1 + vol[2, z_size // 2 + z_offset, :, :] = 0 + vol[3, z_size // 2 + z_offset, :, :] = 0 + return vol + random_data = generate_random_volume([n_planes]) + for z_ix in range(0, n_planes): + vol[:, z_ranges[z_ix * 2] : z_ranges[z_ix * 2 + 1]] = ( + np.expand_dims(random_data[:, z_ix], [1, 2, 3]) + .repeat(1, 1) + .repeat(volume_shape[1], 2) + .repeat(volume_shape[2], 3) + ) + return vol + + +def generate_ellipsoid_volume( + volume_shape, center=[0.5, 0.5, 0.5], radius=[10, 10, 10], alpha=1, delta_n=0.01 +): + """Creates an ellipsoid with optical axis normal to the ellipsoid surface. + Args: + center [3]: [cz,cy,cx] from 0 to 1 where 0.5 is the center of the volume_shape. + radius [3]: in voxels, the radius in z,y,x for this ellipsoid. + alpha (float): Border thickness. + delta_n (float): Delta_n value of birefringence in the volume + Returns: + vol (np.array): 4D array where the first dimension represents the + birefringence and optic axis properties, and the last three + dims represents the 3D spatial locations. + """ + # Originally grabbed from https://math.stackexchange.com/questions/2931909/normal-of-a-point-on-the-surface-of-an-ellipsoid, + # then modified to do the subtraction of two ellipsoids instead. + vol = np.zeros((4, *volume_shape)) + kk, jj, ii = np.meshgrid( + np.arange(volume_shape[0]), + np.arange(volume_shape[1]), + np.arange(volume_shape[2]), + indexing="ij", + ) + # shift to center + kk = floor(center[0] * volume_shape[0]) - kk.astype(float) + jj = floor(center[1] * volume_shape[1]) - jj.astype(float) + ii = floor(center[2] * volume_shape[2]) - ii.astype(float) + + # DEBUG: checking the indices + # np.argwhere(ellipsoid_border == np.min(ellipsoid_border)) + # plt.imshow(ellipsoid_border_mask[int(volume_shape[0] / 2),:,:]) + ellipsoid_border = ( + (kk**2) / (radius[0] ** 2) + + (jj**2) / (radius[1] ** 2) + + (ii**2) / (radius[2] ** 2) + ) + hollow_inner = True + if hollow_inner: + ellipsoid_border_mask = np.abs(ellipsoid_border) <= 1 + # The inner radius could also be defined as a scaled version of the outer radius. + # inner_radius = [0.9 * r for r in radius] + inner_radius = [r - alpha for r in radius] + inner_ellipsoid_border = ( + (kk**2) / (inner_radius[0] ** 2) + + (jj**2) / (inner_radius[1] ** 2) + + (ii**2) / (inner_radius[2] ** 2) + ) + inner_mask = np.abs(inner_ellipsoid_border) <= 1 + else: + ellipsoid_border_mask = np.abs(ellipsoid_border - alpha) <= 1 + + vol[0, ...] = ellipsoid_border_mask.astype(float) + # Compute normals + kk_normal = 2 * kk / radius[0] + jj_normal = 2 * jj / radius[1] + ii_normal = 2 * ii / radius[2] + norm_factor = np.sqrt(kk_normal**2 + jj_normal**2 + ii_normal**2) + # Avoid division by zero + norm_factor[norm_factor == 0] = 1 + vol[1, ...] = (kk_normal / norm_factor) * vol[0, ...] + vol[2, ...] = (jj_normal / norm_factor) * vol[0, ...] + vol[3, ...] = (ii_normal / norm_factor) * vol[0, ...] + vol[0, ...] *= delta_n + # vol = vol.permute(0,2,1,3) + if hollow_inner: + # Hollowing out the ellipsoid + combined_mask = np.logical_and(ellipsoid_border_mask, ~inner_mask) + vol[0, ...] = vol[0, ...] * combined_mask.astype(float) + return vol diff --git a/src/VolumeRaytraceLFM/volumes/modification.py b/src/VolumeRaytraceLFM/volumes/modification.py new file mode 100644 index 0000000..3218d54 --- /dev/null +++ b/src/VolumeRaytraceLFM/volumes/modification.py @@ -0,0 +1,77 @@ +"""Functions for modifying the shape of a birefringent volume.""" + +import numpy as np + + +def pad_to_region_shape(delta_n, optic_axis, volume_shape, region_shape): + """ + Args: + delta_n (np.array): 3D array with dimension volume_shape + optic_axis (np.array): 4D array with dimension (3, *volume_shape) + volume_shape (np.array): dimensions of object volume + region_shape (np.array): dimensions of the region fitting the object, + values must be less than volume_shape + Returns: + padded_delta_n (np.array): 3D array with dimension region_shape + padded_optic_axis (np.array): 4D array with dimension (3, *region_shape) + """ + assert ( + volume_shape <= region_shape + ).all(), "Error: volume_shape must be less than region_shape" + z_, y_, x_ = region_shape + z, y, x = volume_shape + z_pad = abs(z_ - z) + y_pad = abs(y_ - y) + x_pad = abs(x_ - x) + padded_delta_n = np.pad( + delta_n, + ( + (z_pad // 2, z_pad // 2 + z_pad % 2), + (y_pad // 2, y_pad // 2 + y_pad % 2), + (x_pad // 2, x_pad // 2 + x_pad % 2), + ), + mode="constant", + ).astype(np.float64) + padded_optic_axis = np.pad( + optic_axis, + ( + (0, 0), + (z_pad // 2, z_pad // 2 + z_pad % 2), + (y_pad // 2, y_pad // 2 + y_pad % 2), + (x_pad // 2, x_pad // 2 + x_pad % 2), + ), + mode="constant", + constant_values=np.sqrt(3), + ).astype(np.float64) + return padded_delta_n, padded_optic_axis + + +def crop_to_region_shape(delta_n, optic_axis, volume_shape, region_shape): + """ + Parameters: + delta_n (np.array): 3D array with dimension volume_shape + optic_axis (np.array): 4D array with dimension (3, *volume_shape) + volume_shape (np.array): dimensions of object volume + region_shape (np.array): dimensions of the region fitting the object, + values must be greater than volume_shape + Returns: + cropped_delta_n (np.array): 3D array with dimension region_shape + cropped_optic_axis (np.array): 4D array with dimension (3, *region_shape) + """ + assert ( + volume_shape >= region_shape + ).all(), "Error: volume_shape must be greater than region_shape" + crop_start = (volume_shape - region_shape) // 2 + crop_end = crop_start + region_shape + cropped_delta_n = delta_n[ + crop_start[0] : crop_end[0], + crop_start[1] : crop_end[1], + crop_start[2] : crop_end[2], + ] + cropped_optic_axis = optic_axis[ + :, + crop_start[0] : crop_end[0], + crop_start[1] : crop_end[1], + crop_start[2] : crop_end[2], + ] + return cropped_delta_n, cropped_optic_axis diff --git a/src/VolumeRaytraceLFM/volumes/optic_axis.py b/src/VolumeRaytraceLFM/volumes/optic_axis.py index 240e790..c3cda50 100644 --- a/src/VolumeRaytraceLFM/volumes/optic_axis.py +++ b/src/VolumeRaytraceLFM/volumes/optic_axis.py @@ -1,4 +1,5 @@ import torch +import numpy as np def stay_on_sphere(optic_axis): @@ -27,3 +28,47 @@ def fill_vector_based_on_nonaxial(axis_full, axis_nonaxial): axis_full[0, :] = torch.sqrt(1 - square_sum) axis_full[0, torch.isnan(axis_full[0, :])] = 0 return axis_full + + +def spherical_to_unit_vector_np(theta, phi): + """Convert spherical angles to a unit vector. + Args: + theta (float): Azimuthal angle in radians (0 <= theta < 2*pi). + phi (float): Polar angle in radians (0 <= phi <= pi/2). + Returns: + np.ndarray: Unit vector [z, y, x] where z >= 0. + """ + x = np.sin(phi) * np.cos(theta) + y = np.sin(phi) * np.sin(theta) + z = np.cos(phi) + return np.array([z, y, x]) + + +def spherical_to_unit_vector_torch(theta_phi: torch.Tensor) -> torch.Tensor: + """Convert a batch of spherical angles to unit vectors. + Args: + theta_phi (torch.Tensor): Tensor of shape (N, 2) where each row contains + [theta, phi] angles in radians. + Returns: + torch.Tensor: Tensor of shape (N, 3) containing unit vectors [z, y, x] where z >= 0. + """ + theta = theta_phi[:, 0] + phi = theta_phi[:, 1] + x = torch.sin(phi) * torch.cos(theta) + y = torch.sin(phi) * torch.sin(theta) + z = torch.cos(phi) + return torch.stack([z, y, x], dim=-1) + + +def unit_vector_to_spherical(vector): + """Convert a unit vector to spherical angles. + Args: + vector (np.ndarray): Unit vector [z, y, x] where z >= 0. + Returns: + tuple: (theta, phi) where theta is the azimuthal angle in radians (0 <= theta < 2*pi) + and phi is the polar angle in radians (0 <= phi <= pi/2). + """ + z, y, x = vector + phi = np.arccos(z) + theta = np.arctan2(y, x) + return theta, phi diff --git a/src/VolumeRaytraceLFM/volumes/volume_args.py b/src/VolumeRaytraceLFM/volumes/volume_args.py index b7f1df4..27042c0 100644 --- a/src/VolumeRaytraceLFM/volumes/volume_args.py +++ b/src/VolumeRaytraceLFM/volumes/volume_args.py @@ -57,6 +57,11 @@ "init_args": {"Delta_n_range": [-0.01, 0], "axes_range": [-1, 1]}, } +random_neg_args_min05 = { + "init_mode": "random", + "init_args": {"Delta_n_range": [-0.05, 0], "axes_range": [-1, 1]}, +} + random_args1 = { "init_mode": "random", "init_args": {"Delta_n_range": [0, 0.01], "axes_range": [-1, 1]}, @@ -202,6 +207,16 @@ }, } +shell_args_ss3 = { + "init_mode": "shell", + "init_args": { + "radius": [15.5, 29.5, 15.5], + "center": [0.5, 0.5, 0.5], + "delta_n": -0.01, + "border_thickness": 2, + }, +} + shell_small_args = { "init_mode": "shell", "init_args": { @@ -233,7 +248,7 @@ } shell1_args = { - "init_mode": "ellipsoid", + "init_mode": "shell", "init_args": { "radius": [10.5, 15.5, 10.5], "center": [0.5, 0.5, 0.5], diff --git a/src/dataset_creation/LF_viewing.ipynb b/src/dataset_creation/LF_viewing.ipynb index f8b0fed..add9a8a 100644 --- a/src/dataset_creation/LF_viewing.ipynb +++ b/src/dataset_creation/LF_viewing.ipynb @@ -18,9 +18,10 @@ "metadata": {}, "outputs": [], "source": [ - "filename = 'sphere/0_sphere.tiff'\n", + "filename = \"sphere/0_sphere.tiff\"\n", "image = imread(filename)\n", "\n", + "\n", "def transform_into_perspective(img, n_lenses, n_pix):\n", " perspective_img = np.zeros((n_lenses * n_pix, n_lenses * n_pix))\n", " n_lenses = 33\n", @@ -31,11 +32,12 @@ " for j in range(n_pix):\n", " lfx = lx * n_pix + i\n", " lfy = ly * n_pix + j\n", - " psx = i * n_lenses + lx\n", - " psy = j * n_lenses + ly\n", + " psx = i * n_lenses + lx\n", + " psy = j * n_lenses + ly\n", " perspective_img[psx, psy] = img[lfx, lfy]\n", " return perspective_img\n", "\n", + "\n", "plt.imshow(image[0])\n", "\n", "psv_img = transform_into_perspective(image[0], 33, 17)\n", @@ -99,8 +101,8 @@ } ], "source": [ - "read_plot_vol_tiff('raw/objects/6_ell.tiff')\n", - "read_plot_img_tiff('raw/images/6_ell.tiff')" + "read_plot_vol_tiff(\"raw/objects/6_ell.tiff\")\n", + "read_plot_img_tiff(\"raw/images/6_ell.tiff\")" ] }, { @@ -130,8 +132,8 @@ } ], "source": [ - "read_plot_vol_tiff('small_sphere/objects/0_sphere.tiff', axial=4)\n", - "read_plot_img_tiff('small_sphere/images/0_sphere.tiff')" + "read_plot_vol_tiff(\"small_sphere/objects/0_sphere.tiff\", axial=4)\n", + "read_plot_img_tiff(\"small_sphere/images/0_sphere.tiff\")" ] }, { @@ -152,7 +154,7 @@ ], "source": [ "# width of camera\n", - "8*4" + "8 * 4" ] }, { @@ -177,7 +179,7 @@ "from PIL import Image\n", "from IPython.display import display\n", "\n", - "img1 = Image.open('small_sphere/objects/0_sphere.tiff', 'r')\n", + "img1 = Image.open(\"small_sphere/objects/0_sphere.tiff\", \"r\")\n", "display(img1)" ] }, @@ -189,7 +191,7 @@ "source": [ "import cv2\n", "\n", - "gray = cv2.imread('small_sphere/objects/0_sphere.tiff', cv2.IMREAD_UNCHANGED)\n", + "gray = cv2.imread(\"small_sphere/objects/0_sphere.tiff\", cv2.IMREAD_UNCHANGED)\n", "\n", "cv2.namedWindow(\"MyImage\", cv2.WINDOW_NORMAL)\n", "cv2.imshow(\"MyImage\", gray)\n", diff --git a/tests/test_birefringent_volume.py b/tests/test_birefringent_volume.py index 146ee22..3cb40af 100644 --- a/tests/test_birefringent_volume.py +++ b/tests/test_birefringent_volume.py @@ -9,6 +9,10 @@ from tests.fixtures_optical_info import optical_info_vol11 from VolumeRaytraceLFM.abstract_classes import BackEnds from VolumeRaytraceLFM.birefringence_implementations import BirefringentVolume +from VolumeRaytraceLFM.volumes.modification import ( + pad_to_region_shape, + crop_to_region_shape, +) @pytest.mark.parametrize("backend_fixture", ["numpy", "pytorch"], indirect=True) @@ -149,7 +153,7 @@ def test_crop_to_region_shape(backend_fixture): volume_shape = np.array([10, 10, 10]) region_shape = np.array([5, 5, 5]) - cropped_delta_n, cropped_optic_axis = BirefringentVolume.crop_to_region_shape( + cropped_delta_n, cropped_optic_axis = crop_to_region_shape( delta_n, optic_axis, volume_shape, region_shape ) @@ -165,7 +169,7 @@ def test_pad_to_region_shape(backend_fixture): volume_shape = np.array([5, 5, 5]) region_shape = np.array([10, 10, 10]) - padded_delta_n, padded_optic_axis = BirefringentVolume.pad_to_region_shape( + padded_delta_n, padded_optic_axis = pad_to_region_shape( delta_n, optic_axis, volume_shape, region_shape ) diff --git a/tests/test_optic_axis.py b/tests/test_optic_axis.py new file mode 100644 index 0000000..abd49fb --- /dev/null +++ b/tests/test_optic_axis.py @@ -0,0 +1,44 @@ +import numpy as np +import torch +from VolumeRaytraceLFM.volumes.optic_axis import spherical_to_unit_vector_np, unit_vector_to_spherical, spherical_to_unit_vector_torch + +def test_spherical_to_unit_vector_and_back(): + # Test angles + theta = np.pi / 4 + phi = np.pi / 6 + + # Convert to unit vector and back + unit_vector = spherical_to_unit_vector_np(theta, phi) + theta_back, phi_back = unit_vector_to_spherical(unit_vector) + + # Allow for some numerical tolerance + assert np.isclose(theta, theta_back, atol=1e-6) + assert np.isclose(phi, phi_back, atol=1e-6) + + +def test_unit_vector_to_spherical_and_back(): + # Test unit vector + vector = np.array([0.5, 0.5, np.sqrt(2)/2]) + + # Convert to spherical angles and back + theta, phi = unit_vector_to_spherical(vector) + vector_back = spherical_to_unit_vector_np(theta, phi) + + # Allow for some numerical tolerance + assert np.allclose(vector, vector_back, atol=1e-6) + + +def test_spherical_to_unit_vector(): + # Generate random angles + theta = np.random.uniform(0, 2 * np.pi, 100) + phi = np.random.uniform(0, np.pi / 2, 100) + + # Convert to unit vectors using numpy + unit_vectors_np = np.array([spherical_to_unit_vector_np(t, p) for t, p in zip(theta, phi)]) + + # Convert to unit vectors using torch + angles_torch = torch.tensor(np.stack([theta, phi], axis=-1), dtype=torch.float32) + unit_vectors_torch = spherical_to_unit_vector_torch(angles_torch).numpy() + + # Compare the results + assert np.allclose(unit_vectors_np, unit_vectors_torch, atol=1e-6)