diff --git a/config/iter_config.json b/config/iter_config.json index f0880c3..1469394 100644 --- a/config/iter_config.json +++ b/config/iter_config.json @@ -1,22 +1,64 @@ { - "num_iterations": 100, - "regularization_weight": 0.1, - "lr": 1e-3, - "lr_birefringence": 1e-3, - "lr_optic_axis": 1e-1, - "optimizer": "Nadam", - "datafidelity": "euler", - "regularization_fcns": [ - ["birefringence active L2", 0], - ["birefringence active negative penalty", 0], - ["birefringence mask", 1000] - ], - "nerf_mode": false, - "from_simulation": true, - "mla_rays_at_once": true, - "two_optic_axis_components": true, - "free_memory_by_del_large_arrays": false, - "save_freq": 50, - "output_posfix": "", - "notes": "" + "general": { + "num_iterations": 1000, + "save_freq": 100, + "output_directory_postfix": "", + "notes": "" + }, + "learning_rates": { + "birefringence": 1e-4, + "optic_axis": 1e-1 + }, + "regularization": { + "weight": 1.0, + "functions": [ + ["birefringence active L2", 100], + ["birefringence active negative penalty", 0] + ] + }, + "file_paths": { + "initial_volume": null, + "saved_rays": null, + "vox_indices_by_mla_idx": null, + "ret_image": null, + "azim_image": null, + "radiometry": null + }, + "schedulers": { + "birefringence": { + "type": "ReduceLROnPlateau", + "params": { + "mode": "min", + "factor": 0.8, + "patience": 5, + "threshold": 1e-6, + "min_lr": 1e-8 + } + }, + "optic_axis": { + "type": "CosineAnnealingWarmRestarts", + "params": { + "T_0": 20, + "T_mult": 2, + "eta_min": 1e-4 + } + } + }, + "visualization": { + "plot_live": true, + "fig_size": [10, 11] + }, + "learnables": { + "all_prop_elements": false, + "two_optic_axis_components": true + }, + "misc": { + "from_simulation": true, + "save_ray_geometry": true, + "optimizer": "Nadam", + "datafidelity": "euler", + "warmup_iterations": 10, + "mla_rays_at_once": true, + "free_memory_by_del_large_arrays": false + } } \ No newline at end of file diff --git a/config/iter_config_sphere.json b/config/iter_config_sphere.json deleted file mode 100644 index 49ca7d9..0000000 --- a/config/iter_config_sphere.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "num_iterations": 200, - "regularization_weight": 0.5, - "lr": 1e-3, - "lr_birefringence": 1e-3, - "lr_optic_axis": 1e-1, - "bir_betas": [0.6, 0.9], - "optax_betas": [0.6, 0.9], - "optimizer": "Nadam", - "datafidelity": "euler", - "regularization_fcns": [ - ["birefringence active L2", 1000], - ["birefringence active negative penalty", 1000], - ["birefringence mask", 0] - ], - "nerf_mode": false, - "from_simulation": true, - "vox_indices_by_mla_idx_path": "", - "mla_rays_at_once": true, - "two_optic_axis_components": true, - "free_memory_by_del_large_arrays": false, - "save_rays": false, - "save_freq": 20, - "output_posfix": "", - "notes": "" -} \ No newline at end of file diff --git a/config/optical_config_default.json b/config/optical_config.json similarity index 93% rename from config/optical_config_default.json rename to config/optical_config.json index 653ae8e..4a4ebc8 100644 --- a/config/optical_config_default.json +++ b/config/optical_config.json @@ -1,15 +1,16 @@ { "volume_shape" : [1, 1, 1], - "axial_voxel_size_um" : 1.0, - "cube_voxels" : true, - "pixels_per_ml" : 17, "n_micro_lenses" : 1, + "pixels_per_ml" : 17, "n_voxels_per_ml" : 1, "M_obj" : 60, "na_obj" : 1.2, "n_medium" : 1.35, "wavelength" : 0.550, "camera_pix_pitch" : 6.5, + "aperture_radius_px": 7.5, + "cube_voxels" : true, + "axial_voxel_size_um" : 1.0, "polarizer" : [[1, 0], [0, 1]], "analyzer" : [[1, 0], [0, 1]], "polarizer_swing" : 0.03 diff --git a/docs/Reconstruction Configuration Parameters.md b/docs/Reconstruction Configuration Parameters.md new file mode 100644 index 0000000..eca901c --- /dev/null +++ b/docs/Reconstruction Configuration Parameters.md @@ -0,0 +1,117 @@ +# Reconstruction Configuration Key Descriptions + +This document describes the various keys and their potential values in the JSON reconstruction configuration file. + +--- + +## General Settings + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `general.num_iterations` | Number of iterations for the training or reconstruction process. | Integer (e.g., 200) | 200 | +| `general.save_freq` | Frequency for saving intermediate results (in iterations). | Integer (e.g., 100) | 100 | +| `general.output_directory_postfix` | A string appended to the output directory for easier identification. | String | `""` (empty) | +| `general.notes` | Additional notes for the configuration run. | String | `""` (empty) | + +--- + +## Learning Rates + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `learning_rates.birefringence` | Learning rate for birefringence optimization. | Float (e.g., 1e-4) | 1e-4 | +| `learning_rates.optic_axis` | Learning rate for optic axis optimization. | Float (e.g., 1e-1) | 1e-1 | + +--- + +## Regularization + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `regularization.weight` | Weight of the regularization term in the loss function. | Float (e.g., 0.5) | 0.5 | +| `regularization.functions` | List of regularization functions and their associated weights. | List (e.g., `["function_name", weight]`) | N/A | + +--- + +## File Paths + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `file_paths.initial_volume` | Filepath for the initial volume, if any. | Filepath (string) | `null` | +| `file_paths.saved_rays` | Filepath for saved rays, if any. | Filepath (string) | `null` | +| `file_paths.vox_indices_by_mla_idx` | Filepath for voxel indices mapped by MLA index. | Filepath (string) | `null` | +| `file_paths.ret_image` | Filepath for the measured retardance image. | Filepath (string) | `null` | +| `file_paths.azim_image` | Filepath for the measured azimuth image. | Filepath (string) | `null` | +| `file_paths.radiometry` | Filepath for the radiometry data. | Filepath (string) | `null` | + +--- + +## Schedulers + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `schedulers.birefringence.type` | Scheduler type for birefringence learning rate. | String (e.g., "ReduceLROnPlateau") | "ReduceLROnPlateau" | +| `schedulers.birefringence.params.mode` | Mode for scheduler, controls if the scheduler is reducing learning rates based on min/max of the loss. | `"min"`, `"max"` | `"min"` | +| `schedulers.birefringence.params.factor` | Factor by which the learning rate is reduced. | Float (e.g., 0.8) | 0.8 | +| `schedulers.birefringence.params.patience` | Number of epochs with no improvement before reducing learning rate. | Integer (e.g., 5) | 5 | +| `schedulers.birefringence.params.threshold` | Threshold for measuring new optimal value. | Float (e.g., 1e-6) | 1e-6 | +| `schedulers.birefringence.params.min_lr` | Minimum learning rate after reductions. | Float (e.g., 1e-8) | 1e-8 | +| `schedulers.optic_axis.type` | Scheduler type for optic axis learning rate. | String (e.g., "CosineAnnealingWarmRestarts") | "ReduceLROnPlateau" | +| `schedulers.optic_axis.params.T_0` | Number of iterations for the first restart cycle. | Integer (e.g., 20) | N/A | +| `schedulers.optic_axis.params.T_mult` | Multiplication factor to increase the length of each cycle. | Integer (e.g., 2) | N/A | +| `schedulers.optic_axis.params.eta_min` | Minimum learning rate during annealing. | Float (e.g., 1e-4) | N/A | + +--- + +## NeRF Settings + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `nerf.enabled` | Boolean flag to enable or disable NeRF mode. | `true`, `false` | `false` | +| `nerf.learning_rates.fc1` | Learning rate for the first NeRF fully-connected layer (fc1). | Float (e.g., 1e-2) | 1e-2 | +| `nerf.learning_rates.fc2` | Learning rate for the second NeRF fully-connected layer (fc2). | Float (e.g., 1e-4) | 1e-4 | +| `nerf.learning_rates.fc3` | Learning rate for the third NeRF fully-connected layer (fc3). | Float (e.g., 1e-4) | 1e-4 | +| `nerf.learning_rates.output` | Learning rate for the NeRF output layer. | Float (e.g., 1e-4) | 1e-4 | +| `nerf.optimizer.type` | Type of optimizer used in NeRF mode. | String (e.g., "NAdam") | `"NAdam"` | +| `nerf.optimizer.betas` | Betas for momentum terms in the NeRF optimizer. | List (e.g., `[0.9, 0.999]`) | `[0.9, 0.999]`| +| `nerf.optimizer.eps` | Epsilon value for numerical stability in NeRF optimizer. | Float (e.g., 1e-7) | 1e-7 | +| `nerf.optimizer.weight_decay` | Weight decay (L2 regularization) for the NeRF optimizer. | Float (e.g., 1e-4) | 1e-4 | +| `nerf.scheduler.type` | Scheduler type for NeRF learning rates. | String (e.g., "CosineAnnealingLR") | N/A | +| `nerf.scheduler.params` | Parameters for the NeRF scheduler. | Dictionary | N/A | +| `nerf.MLP.hidden_layers` | Hidden layers for the NeRF MLP. | List (e.g., `[256, 256, 256]`) | `[256, 256, 256]` | +| `nerf.MLP.num_frequencies` | Number of frequencies for the NeRF MLP. | Integer (e.g., 10) | 10 | +| `nerf.MLP.final_layer_bias_birefringence` | Bias for the final layer of the NeRF MLP for birefringence. | Float (e.g., -0.05) | -0.05 | +| `nerf.MLP.final_layer_weight_range` | Weight range for the final layer of the NeRF MLP for birefringence. | List (e.g., `[-0.01, 0.01]`) | `[-0.01, 0.01]` | + +--- + +## Visualization + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `visualization.plot_live` | Boolean flag to determine whether to plot the reconstruction live. | `true`, `false` | `true` | + +--- + +## Learnables + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `learnables.all_prop_elements` | Boolean flag to indicate if all properties are learned. | `true`, `false` | `false` | +| `learnables.two_optic_axis_components` | Boolean flag to indicate if two components are used for optic axis. | `true`, `false` | `true` | + +--- + +## Miscellaneous Settings + +| Key | Description | Possible Values | Default | +|-----------------------------|-----------------------------------------------------------------------------|---------------------------|----------------| +| `misc.from_simulation` | Boolean flag to indicate if data comes from simulation or real-world measurements. | `true`, `false` | `false` | +| `misc.save_ray_geometry` | Boolean flag to determine whether to save the ray geometry. | `true`, `false` | `false` | +| `misc.optimizer` | Type of optimizer used for training or reconstruction. | String (e.g., "Nadam") | `"Nadam"` | +| `misc.datafidelity` | Term used in the loss function for data fidelity. | String (e.g., "euler") | `"euler"` | +| `misc.mla_rays_at_once` | Boolean flag to process MLA rays in batches. | `true`, `false` | `true` | +| `misc.free_memory_by_del_large_arrays` | Boolean flag to free memory by deleting large arrays when possible. | `true`, `false` | `false` | +| `misc.save_to_logfile` | Boolean flag to determine whether to save the output to a logfile. | `true`, `false` | `true` | + +--- diff --git a/examples/reconstruction_basic.py b/examples/reconstruction_basic.py index 7f8cf96..44e50f9 100644 --- a/examples/reconstruction_basic.py +++ b/examples/reconstruction_basic.py @@ -28,12 +28,10 @@ # Path to the directory where the reconstruction will be saved recon_output_dir = os.path.join("..", "reconstructions", "voxel") -recon_output_dir_postfix = "postfix" -recon_directory = create_unique_directory(recon_output_dir, postfix=recon_output_dir_postfix) # Whether to continue a previous reconstruction continue_recon = False -recon_file_path = r"to be alterned.h5" +recon_init_file_path = r"to be alterned.h5" # For loading forward images that were saved in a previous reconstruction folder measurement_dir = os.path.join(recon_output_dir, "to be altered") @@ -66,8 +64,10 @@ # %% Run reconstruction recon_optical_info = optical_info.copy() iteration_params = setup_iteration_parameters(iter_config_file) +recon_dir_postfix = iteration_params["general"]["output_directory_postfix"] +recon_directory = create_unique_directory(recon_output_dir, postfix=recon_dir_postfix) if continue_recon: - initial_volume = BirefringentVolume.init_from_file(recon_file_path, BACKEND, recon_optical_info) + initial_volume = BirefringentVolume.init_from_file(recon_init_file_path, BACKEND, recon_optical_info) else: initial_volume = BirefringentVolume( backend=BACKEND, @@ -88,7 +88,7 @@ output_dir=recon_directory, device=DEVICE ) -reconstructor.reconstruct(plot_live=True) +reconstructor.reconstruct() print("Reconstruction complete") # %% diff --git a/src/VolumeRaytraceLFM/abstract_classes.py b/src/VolumeRaytraceLFM/abstract_classes.py index 92704e0..128eecf 100644 --- a/src/VolumeRaytraceLFM/abstract_classes.py +++ b/src/VolumeRaytraceLFM/abstract_classes.py @@ -74,6 +74,7 @@ class OpticalElement(OpticBlock): "pixels_per_ml": 17, "n_micro_lenses": 1, "n_voxels_per_ml": 1, + "aperture_radius_px": 7.5, # Objective lens information "M_obj": 60, "na_obj": 1.2, @@ -390,7 +391,7 @@ def calc_ray_direction(ray): """ Allows to the calculations to be done in ray-space coordinates as oppossed to laboratory coordinates - Parameters: + Args: ray (np.array): normalized 3D vector giving the direction of the light ray Returns: @@ -398,78 +399,70 @@ def calc_ray_direction(ray): ray_perp1 (np.array): normalized 3D vector ray_perp2 (np.array): normalized 3D vector """ - # in case ray is not a unit vector <- does not need to be normalized - # ray = ray / np.linalg.norm(ray) theta = np.arccos(np.dot(ray, np.array([1, 0, 0]))) # Unit vectors that give the laboratory axes, can be changed scope_axis = np.array([1, 0, 0]) scope_perp1 = np.array([0, 1, 0]) scope_perp2 = np.array([0, 0, 1]) theta = np.arccos(np.dot(ray, scope_axis)) - # print(f"Rotating by {np.around(np.rad2deg(theta), decimals=0)} degrees") + print(f"Maximum ray angle is {np.around(np.rad2deg(theta), decimals=0)} degrees") normal_vec = RayTraceLFM.find_orthogonal_vec(ray, scope_axis) Rinv = RayTraceLFM.rotation_matrix(normal_vec, -theta) # Extracting basis vectors that are orthogonal to the ray and will be parallel # to the laboratory axes that are not the optic axis after a rotation. # Note: If scope_perp1 if the y-axis, then ray_perp1 if the 2nd column of Rinv. ray_perp1 = np.dot(Rinv, scope_perp1) - ray_perp2 = np.dot(Rinv, scope_perp2) + ray_perp2 = -np.dot(Rinv, scope_perp2) return [ray, ray_perp1, ray_perp2] @staticmethod def calc_ray_direction_torch(ray_in): - """ - Allows to the calculations to be done in ray-space coordinates - as oppossed to laboratory coordinates - Parameters: - ray_in [n_rays,3] (torch.array): normalized 3D vector giving the direction - of the light ray + """Allows to the calculations to be done in ray-space coordinates + as oppossed to laboratory coordinates. For each ray, we calculate + a set of three ray basis vectors. + Args: + ray_in [n_rays, 3] (torch.array): normalized 3D vector giving + the direction of the light ray Returns: + torch.array: [3, n_rays, 3] where... ray (torch.array): same as input ray_perp1 (torch.array): normalized 3D vector ray_perp2 (torch.array): normalized 3D vector """ - if not torch.is_tensor(ray_in): - ray = torch.from_numpy(ray_in) - else: - ray = ray_in - theta = torch.arccos( - torch.linalg.multi_dot( - (ray, torch.tensor([1.0, 0, 0], dtype=ray.dtype, device=ray_in.device)) - ) - ) + ray = torch.from_numpy(ray_in) if not torch.is_tensor(ray_in) else ray_in # Unit vectors that give the laboratory axes, can be changed - scope_axis = torch.tensor([1.0, 0, 0], dtype=ray.dtype, device=ray_in.device) - scope_perp1 = torch.tensor([0, 1.0, 0], dtype=ray.dtype, device=ray_in.device) - scope_perp2 = torch.tensor([0, 0, 1.0], dtype=ray.dtype, device=ray_in.device) - # print(f"Rotating by {np.around(torch.rad2deg(theta).numpy(), decimals=0)} degrees") + scope_axis = torch.tensor([1.0, 0, 0], dtype=ray.dtype, device=ray.device) + scope_perp1 = torch.tensor([0, 1.0, 0], dtype=ray.dtype, device=ray.device) + scope_perp2 = torch.tensor([0, 0, 1.0], dtype=ray.dtype, device=ray.device) + theta = torch.arccos(torch.matmul(ray, scope_axis)) + print(f"Maximum ray angle is {torch.round(torch.rad2deg(theta).max(), decimals=0)} degrees") normal_vec = RayTraceLFM.find_orthogonal_vec_torch(ray, scope_axis) Rinv = RayTraceLFM.rotation_matrix_torch(normal_vec, -theta) # Extracting basis vectors that are orthogonal to the ray and will be parallel # to the laboratory axes that are not the optic axis after a rotation. - # Note: If scope_perp1 if the y-axis, then ray_perp1 if the 2nd column of Rinv. - if scope_perp1[0] == 0 and scope_perp1[1] == 1 and scope_perp1[2] == 0: + # Note: If scope_perp1 is the y-axis, then ray_perp1 is the 2nd column of Rinv. + if torch.equal(scope_perp1, torch.tensor([0, 1.0, 0], dtype=ray.dtype)): ray_perp1 = Rinv[:, :, 1] # dot product needed else: - # todo: we need to put a for loop to do this operation # ray_perp1 = torch.linalg.multi_dot((Rinv, scope_perp1)) raise NotImplementedError - if scope_perp2[0] == 0 and scope_perp2[1] == 0 and scope_perp2[2] == 1: - ray_perp2 = Rinv[:, :, 2] + if torch.equal(scope_perp2, torch.tensor([0, 0, 1.0], dtype=ray.dtype)): + ray_perp2 = -Rinv[:, :, 2] else: - # todo: we need to put a for loop to do this operation # ray_perp2 = torch.linalg.multi_dot((Rinv, scope_perp2)) raise NotImplementedError + + # Unsqueeze tensors to make them of shape [1, n_rays, 3] + ray = ray.unsqueeze(0) + ray_perp1 = ray_perp1.unsqueeze(0) + ray_perp2 = ray_perp2.unsqueeze(0) - # Returns a list size 3, where each element is a torch tensor shaped [n_rays, 3] - return torch.cat( - [ray.unsqueeze(0), ray_perp1.unsqueeze(0), ray_perp2.unsqueeze(0)], 0 - ) + return torch.cat([ray, ray_perp1, ray_perp2], dim=0) ########################################################################################### # Ray-tracing functions @staticmethod - def rays_through_vol(pixels_per_ml, naObj, nMedium, volume_ctr_um): + def rays_through_vol(pixels_per_ml, naObj, nMedium, volume_ctr_um, aperture_radius_px): """Identifies the rays that pass through the volume and the central lenslet Args: pixels_per_ml (int): number of pixels per microlens in one direction, @@ -479,6 +472,8 @@ def rays_through_vol(pixels_per_ml, naObj, nMedium, volume_ctr_um): nMedium (float): refractive index of the volume volume_ctr_um (np.array): 3D vector containing the coordinates of the center of the volume in volume space units (um) + aperture_radius_px (float): radius of the effective aperture of a microlens + in pixels, about pixels_per_ml/2 Returns: ray_enter (np.array): (3, X, X) array where (3, i, j) gives the coordinates within the volume ray entrance plane for which the @@ -492,17 +487,16 @@ def rays_through_vol(pixels_per_ml, naObj, nMedium, volume_ctr_um): # Units are in pixel indicies, referring to the pixel that is centered up 0.5 units # Ex: if ml_ctr = [8, 8], then the spatial center pixel is at [8.5, 8.5] ml_ctr = [(pixels_per_ml - 1) / 2, (pixels_per_ml - 1) / 2] - ml_radius = 7.5 # pixels_per_ml / 2 i = np.linspace(0, pixels_per_ml - 1, pixels_per_ml) j = np.linspace(0, pixels_per_ml - 1, pixels_per_ml) jv, iv = np.meshgrid(i, j) dist_from_ctr = np.sqrt((iv - ml_ctr[0]) ** 2 + (jv - ml_ctr[1]) ** 2) # Angles that reach the pixels - cam_pixels_azim = np.arctan2(jv - ml_ctr[1], iv - ml_ctr[0]) - cam_pixels_azim[dist_from_ctr > ml_radius] = np.nan - dist_from_ctr[dist_from_ctr > ml_radius] = np.nan - cam_pixels_tilt = np.arcsin(dist_from_ctr / ml_radius * naObj / nMedium) + cam_pixels_azim = np.atan2(jv - ml_ctr[1], iv - ml_ctr[0]) + cam_pixels_azim[dist_from_ctr > aperture_radius_px] = np.nan + dist_from_ctr[dist_from_ctr > aperture_radius_px] = np.nan + cam_pixels_tilt = np.arcsin(dist_from_ctr / aperture_radius_px * naObj / nMedium) # Plotting if DEBUG: @@ -781,6 +775,7 @@ def _initialize_ray_geometry(self): pixels_per_ml = self.optical_info["pixels_per_ml"] naObj = self.optical_info["na_obj"] nMedium = self.optical_info["n_medium"] + aperture_radius_px = self.optical_info["aperture_radius_px"] valid_vol_shape = ( self.optical_info["n_micro_lenses"] * self.optical_info["n_voxels_per_ml"] ) @@ -801,7 +796,7 @@ def _initialize_ray_geometry(self): # Calculate the ray geometry ray_enter, ray_exit, ray_diff = RayTraceLFM.rays_through_vol( - pixels_per_ml, naObj, nMedium, volume_ctr_um_restricted + pixels_per_ml, naObj, nMedium, volume_ctr_um_restricted, aperture_radius_px ) # Store locally diff --git a/src/VolumeRaytraceLFM/birefringence_implementations.py b/src/VolumeRaytraceLFM/birefringence_implementations.py index e19ba39..6001424 100644 --- a/src/VolumeRaytraceLFM/birefringence_implementations.py +++ b/src/VolumeRaytraceLFM/birefringence_implementations.py @@ -6,8 +6,11 @@ from math import floor from tqdm import tqdm import time +import torch +import numpy as np from collections import Counter from VolumeRaytraceLFM.abstract_classes import * +from VolumeRaytraceLFM.abstract_classes import BackEnds, RayTraceLFM from VolumeRaytraceLFM.birefringence_base import BirefringentElement from VolumeRaytraceLFM.nerf import ( ImplicitRepresentationMLP, @@ -26,8 +29,6 @@ ) from VolumeRaytraceLFM.volumes.optic_axis import ( spherical_to_unit_vector_torch, - unit_vector_to_spherical, - fill_vector_based_on_nonaxial, adjust_optic_axis_positive_axial, ) from VolumeRaytraceLFM.jones.jones_calculus import ( @@ -43,6 +44,7 @@ from VolumeRaytraceLFM.jones import jones_matrix from VolumeRaytraceLFM.utils.dict_utils import filter_keys_by_count, convert_to_tensors from VolumeRaytraceLFM.utils.error_handling import check_for_negative_values_dict +from VolumeRaytraceLFM.utils.orientation_utils import transpose_and_flip from VolumeRaytraceLFM.combine_lenslets import ( gather_voxels_of_rays_pytorch_batch, calculate_offsets_vectorized, @@ -226,6 +228,11 @@ def _handle_single_optic_axis_torch(self, optic_axis): ) self.optic_axis = optic_axis_tensor.repeat(1, *self.volume_shape) + def set_requires_grad(self, requires_grad=False): + """Set the requires_grad attribute for Delta_n and optic_axis.""" + self.Delta_n.requires_grad = requires_grad + self.optic_axis.requires_grad = requires_grad + def get_delta_n(self): """Retrieves the birefringence as a 3D array""" if self.backend == BackEnds.PYTORCH: @@ -275,13 +282,11 @@ def __iadd__(self, other): requires_grad = getattr(self.Delta_n, "requires_grad", False) if requires_grad: torch.set_grad_enabled(False) - self.Delta_n.requires_grad = False - self.optic_axis.requires_grad = False + self.set_requires_grad(False) # Perform the addition self.Delta_n += other.Delta_n self.optic_axis += other.optic_axis - # Maybe normalize axis again? # Normalize the optic axis norm = ( @@ -293,8 +298,7 @@ def __iadd__(self, other): # Re-enable gradients if they were disabled if requires_grad: - self.Delta_n.requires_grad = True - self.optic_axis.requires_grad = True + self.set_requires_grad(True) torch.set_grad_enabled(True) return self @@ -991,7 +995,7 @@ def __init__( self.use_nerf = False self.inr_model = None - def initialize_nerf_mode(self, use_nerf=True): + def initialize_nerf_mode(self, use_nerf=True, mlp_params_dict=None): """Initialize the NeRF mode based on the user's preference. Args: use_nerf (bool): Flag to enable or disable NeRF mode. Default is True. @@ -999,8 +1003,7 @@ def initialize_nerf_mode(self, use_nerf=True): self.use_nerf = use_nerf if self.use_nerf: # self.inr_model = ImplicitRepresentationMLP(3, 4, [256, 128, 64]) - # self.inr_model = ImplicitRepresentationMLP(3, 4, [256, 256, 256, 256, 256]) - self.inr_model = ImplicitRepresentationMLPSpherical(3, 3, [256, 256, 256]) + self.inr_model = ImplicitRepresentationMLPSpherical(3, 3, mlp_params_dict) self.inr_model = torch.nn.DataParallel(self.inr_model) print("NeRF mode initialized.") else: @@ -1225,10 +1228,10 @@ def store_shifted_vox_indices(self): n_ml_half = floor(n_micro_lenses / 2.0) collision_indices = self.ray_vol_colli_indices if self.verbose: - print(f"Storing shifted voxel indices for each microlens:") + print("Storing shifted voxel indices for each microlens:") row_iterable = tqdm( range(n_micro_lenses), - desc=f"Computing rows of microlenses for storing voxel indices", + desc="Computing rows of microlenses for storing voxel indices", position=1, leave=True, ) @@ -1413,6 +1416,8 @@ def ray_trace_through_volume( self.times["ray_trace_through_volume"] += ( end_time_raytrace - start_time_raytrace ) + for i, img in enumerate(full_img_list): + full_img_list[i] = transpose_and_flip(img) return full_img_list def _get_row_iterable(self, n_ml_half, odd_mla_shift): @@ -1632,10 +1637,6 @@ def calc_cummulative_JM_of_ray_torch( torch.Tensor: The cumulative Jones Matrices for the rays. torch.Size([n_rays_with_voxels, 2, 2]) """ - if False: # DEBUG - assert not all(element == voxels_of_segs[0] for element in voxels_of_segs) - # Note: if all elements of voxels_of_segs are equal, then all of - # self.ray_vol_colli_indices may be equal if False: # DEBUG print("DEBUG: making the optical info of volume and self the same") print("vol in: ", volume_in.optical_info) diff --git a/src/VolumeRaytraceLFM/jones/eigenanalysis.py b/src/VolumeRaytraceLFM/jones/eigenanalysis.py index 3959ff1..7f42666 100644 --- a/src/VolumeRaytraceLFM/jones/eigenanalysis.py +++ b/src/VolumeRaytraceLFM/jones/eigenanalysis.py @@ -95,7 +95,8 @@ def azimuth_from_jones_numpy(jones, simple=True): j12 = jones[0, 1] imag_j11 = np.imag(j11) imag_j12 = np.imag(j12) - azimuth = 0.5 * np.arctan2(imag_j12, imag_j11) + np.pi / 2.0 + azimuth = 0.5 * np.atan2(imag_j11, imag_j12) - np.pi / 4.0 + azimuth = np.remainder(azimuth, np.pi) if np.isclose(np.abs(imag_j11), 0.0) and np.isclose(np.abs(imag_j12), 0.0): azimuth = 0.0 return azimuth @@ -124,8 +125,9 @@ def azimuth_from_jones_torch(jones): ) # Compute azimuth for non-zero mask elements azimuth_non_zero = ( - 0.5 * torch.atan2(imag_j12[non_zero_mask], imag_j11[non_zero_mask]) - + torch.pi / 2.0 + 0.5 * torch.atan2(imag_j11[non_zero_mask], imag_j12[non_zero_mask]) - torch.pi / 4.0 ) + # Ensure the azimuth is within the range [0, pi) + azimuth_non_zero = torch.remainder(azimuth_non_zero, torch.pi) azimuth[non_zero_mask] = azimuth_non_zero return azimuth diff --git a/src/VolumeRaytraceLFM/jones/intensity.py b/src/VolumeRaytraceLFM/jones/intensity.py index 5352fa1..0d9af10 100644 --- a/src/VolumeRaytraceLFM/jones/intensity.py +++ b/src/VolumeRaytraceLFM/jones/intensity.py @@ -4,7 +4,7 @@ def ret_and_azim_from_intensity(image_list, swing): """Note: this function is still in development.""" if len(image_list) != 5: - raise ValueError(f"Expected 5 images, got {len(imgs)}.") + raise ValueError(f"Expected 5 images, got {len(image_list)}.") # The order of the images matters! imgs = [image_list[0], image_list[2], image_list[3], image_list[1], image_list[4]] # using arctan vs arctan2 does not seem to make a difference @@ -25,6 +25,6 @@ def ret_and_azim_from_intensity(image_list, swing): test_value = imgs[1] + imgs[2] - 2 * imgs[0] indices = np.where(test_value < 0) ret[indices] = 2 * np.pi - ret[indices] - # azim = 0.5 * np.arctan2(A, B) + np.pi / 2 - azim = 0.5 * np.arctan2(B, A) + np.pi / 2 + # azim = 0.5 * np.atan2(A, B) + np.pi / 2 + azim = 0.5 * np.atan2(B, A) + np.pi / 2 return [ret, azim] diff --git a/src/VolumeRaytraceLFM/jones/jones_matrix.py b/src/VolumeRaytraceLFM/jones/jones_matrix.py index dfe28d1..99666a0 100644 --- a/src/VolumeRaytraceLFM/jones/jones_matrix.py +++ b/src/VolumeRaytraceLFM/jones/jones_matrix.py @@ -5,7 +5,7 @@ def vox_ray_ret_azim_numpy(bir, optic_axis, rayDir, ell, wavelength): # Azimuth is the angle of the slow axis of retardance. # TODO: verify the order of these two components - azim = np.arctan2(np.dot(optic_axis, rayDir[1]), np.dot(optic_axis, rayDir[2])) + azim = np.atan2(np.dot(optic_axis, rayDir[1]), np.dot(optic_axis, rayDir[2])) azim = 0 if bir == 0 else (azim + np.pi / 2 if bir < 0 else azim) # proj_along_ray = np.dot(optic_axis, rayDir[0]) ret = ( @@ -21,11 +21,11 @@ def vox_ray_ret_azim_numpy(bir, optic_axis, rayDir, ell, wavelength): def print_ret_azim_numpy(ret, azim): print( - f"Azimuth angle of index ellipsoid is " + "Azimuth angle of index ellipsoid is " + f"{np.around(np.rad2deg(azim), decimals=0)} degrees." ) print( - f"Accumulated retardance from index ellipsoid is " + "Accumulated retardance from index ellipsoid is " + f"{np.around(np.rad2deg(ret), decimals=0)} ~ {int(np.rad2deg(ret)) % 360} degrees." ) @@ -48,8 +48,8 @@ def vox_ray_ret_azim_torch(bir, optic_axis, rayDir, ell, wavelength): pi_tensor = torch.tensor(np.pi, device=bir.device, dtype=bir.dtype) # Dot product of optical axis and 3 ray-direction vectors OA_dot_rayDir = (rayDir.unsqueeze(2) @ optic_axis).squeeze(2) - # TODO: verify x2 should be mult by the azimuth angle - azim = 2 * torch.arctan2(OA_dot_rayDir[1], OA_dot_rayDir[2]) + # There is the x2 here because it is not in the jones matrix function + azim = 2 * torch.atan2(OA_dot_rayDir[1], OA_dot_rayDir[2]) ret = abs(bir) * (1 - (OA_dot_rayDir[0]) ** 2) * ell * pi_tensor / wavelength neg_delta_mask = bir < 0 diff --git a/src/VolumeRaytraceLFM/metrics/metric.py b/src/VolumeRaytraceLFM/metrics/metric.py index d5c2c5d..5c130cf 100644 --- a/src/VolumeRaytraceLFM/metrics/metric.py +++ b/src/VolumeRaytraceLFM/metrics/metric.py @@ -49,21 +49,21 @@ def __init__(self, params=None, json_file=None): self.weight_retardance = params.get("weight_retardance", 1.0) self.weight_orientation = params.get("weight_orientation", 1.0) self.weight_datafidelity = params.get("weight_datafidelity", 1.0) - self.weight_regularization = params.get("regularization_weight", 1.0) + self.weight_regularization = params.get("regularization", {}).get("weight", 1.0) # Initialize specific loss functions - self.optimizer = params.get("optimizer", "Adam") - self.datafidelity = params.get("datafidelity", "vector") + self.optimizer = params.get("misc", {}).get("optimizer", "Nadam") + self.datafidelity = params.get("misc", {}).get("datafidelity", "euler") self.regularization_fcns = [ (REGULARIZATION_FCNS[fn_name], weight) - for fn_name, weight in params.get("regularization_fcns", []) + for fn_name, weight in params.get("regularization", {}).get("functions", []) ] else: self.weight_retardance = 1.0 self.weight_orientation = 1.0 self.weight_datafidelity = 1.0 - self.weight_regularization = 0.1 - self.optimizer = "Adam" - self.datafidelity = "vector" + self.weight_regularization = 1.0 + self.optimizer = "Nadam" + self.datafidelity = "euler" self.regularization_fcns = [] self.mask = None @@ -259,6 +259,6 @@ def __init__(self): def forward(self, predicted_volume, target_volume): """Compute the birefringence field loss""" vector_field1 = predicted_volume[0, ...] * predicted_volume[1:, ...] - vector_field2 = target_volume[0, ...] * target_volume[1, ...] + vector_field2 = target_volume[0, ...] * target_volume[1:, ...] loss = F.mse_loss(vector_field1, vector_field2) return loss diff --git a/src/VolumeRaytraceLFM/nerf.py b/src/VolumeRaytraceLFM/nerf.py index 85ce9d2..3531e83 100644 --- a/src/VolumeRaytraceLFM/nerf.py +++ b/src/VolumeRaytraceLFM/nerf.py @@ -131,12 +131,15 @@ class ImplicitRepresentationMLPSpherical(nn.Module): """ def __init__( - self, input_dim, output_dim, hidden_layers=[128, 64], num_frequencies=10 + self, input_dim, output_dim, params_dict=None ): super(ImplicitRepresentationMLPSpherical, self).__init__() - self.num_frequencies = num_frequencies self.input_dim = input_dim self.output_dim = output_dim + self.params_dict = params_dict + hidden_layers = self.params_dict.get("hidden_layers", [128, 64]) + num_frequencies = self.params_dict.get("num_frequencies", 10) + self.num_frequencies = num_frequencies layers = [] in_dim = input_dim * (2 * num_frequencies + 1) @@ -158,10 +161,13 @@ def _initialize_weights(self): def _initialize_output_layer(self): """Initialize the weights of the output layer.""" final_layer = self.layers[-1] + weight_range = self.params_dict.get("final_layer_weight_range", [-0.01, 0.01]) + birefringence_bias = self.params_dict.get("final_layer_bias_birefringence", 0.05) with torch.no_grad(): - final_layer.weight.uniform_(-0.01, 0.01) - final_layer.bias[0] = 0.05 # First output dimension fixed to 0.05 - final_layer.bias[1:].uniform_(-0.5, 0.5) # Initializing other biases + + final_layer.weight.uniform_(weight_range[0], weight_range[1]) + final_layer.bias[0] = birefringence_bias + final_layer.bias[1:].uniform_(-0.5, 0.5) def positional_encoding(self, x: torch.Tensor) -> torch.Tensor: """Apply positional encoding to the input tensor. @@ -205,7 +211,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def setup_optimizer_nerf( model: nn.Module, training_params: dict ) -> torch.optim.Optimizer: - """Set up the optimizer for the neural network model. + """Set up the optimizer for the neural network model with layer-specific learning rates. Args: model (nn.Module): The neural network model. @@ -213,16 +219,54 @@ def setup_optimizer_nerf( Returns: torch.optim.Optimizer: Configured optimizer. """ - inr_params = model.inr_model.parameters() + if isinstance(model, nn.DataParallel): + model = model.module # Access the actual model inside DataParallel + + # Extract the NeRF-specific parameters + nerf_params = training_params.get("nerf", {}) + + # Extract NeRF learning rates + lr_fc1 = nerf_params.get("learning_rates", {}).get("fc1", 1e-2) # Learning rate for fc1 + lr_fc2 = nerf_params.get("learning_rates", {}).get("fc2", 1e-4) # Learning rate for fc2 + lr_fc3 = nerf_params.get("learning_rates", {}).get("fc3", 1e-4) # Learning rate for fc3 + lr_output = nerf_params.get("learning_rates", {}).get("output", 1e-4) # Learning rate for output layer + + # Extract optimizer parameters from the JSON + optimizer_type = nerf_params.get("optimizer", {}).get("type", "NAdam") + betas = tuple(nerf_params.get("optimizer", {}).get("betas", [0.9, 0.999])) # Tuple for betas + eps = nerf_params.get("optimizer", {}).get("eps", 1e-8) + weight_decay = nerf_params.get("optimizer", {}).get("weight_decay", 1e-4) + + # Access layers from model (assuming it's an instance of ImplicitRepresentationMLPSpherical) parameters = [ + # fc1 layer + { + "params": model.layers[0].parameters(), # First Linear layer (fc1) + "lr": lr_fc1, + }, + # fc2 layer + { + "params": model.layers[2].parameters(), # Second Linear layer (fc2) + "lr": lr_fc2, + }, + # fc3 layer + { + "params": model.layers[4].parameters(), # Third Linear layer (fc3) + "lr": lr_fc3, + }, + # Output layer { - "params": inr_params, - "lr": training_params.get("lr", 0.001), - } + "params": model.layers[-1].parameters(), # Output Linear layer + "lr": lr_output, + }, ] - # optimizer_class = getattr(torch.optim, training_params.get("optimizer", "NAdam")) - # optimizer = optimizer_class(parameters) - optimizer = torch.optim.NAdam(parameters) # , lr=0.001) + # Setup the optimizer using the NAdam parameters from the JSON + optimizer = torch.optim.NAdam( + parameters, + betas=betas, # Momentum coefficients from the JSON + eps=eps, # Epsilon for numerical stability + weight_decay=weight_decay, # Weight decay for regularization + ) return optimizer diff --git a/src/VolumeRaytraceLFM/reconstructions.py b/src/VolumeRaytraceLFM/reconstructions.py index 0b89896..e20c933 100644 --- a/src/VolumeRaytraceLFM/reconstructions.py +++ b/src/VolumeRaytraceLFM/reconstructions.py @@ -1,21 +1,23 @@ """This module contains the ReconstructionConfig and Reconstructor classes.""" import sys -import copy -import time import os import json -import torch -import numpy as np -from tqdm import tqdm +import time +import copy import csv import pickle +import gc + +import torch +import numpy as np import tifffile import matplotlib.pyplot as plt +from tqdm import tqdm ## For analyzing the memory usage of a function # from memory_profiler import profile -import gc + from VolumeRaytraceLFM.abstract_classes import BackEnds from VolumeRaytraceLFM.birefringence_implementations import ( BirefringentVolume, @@ -38,6 +40,7 @@ reshape_and_crop, store_as_pytorch_parameter, ) +from VolumeRaytraceLFM.utils.orientation_utils import undo_transpose_and_flip from VolumeRaytraceLFM.utils.error_handling import ( check_for_inf_or_nan, check_for_negative_values, @@ -59,7 +62,9 @@ ) from VolumeRaytraceLFM.utils.mask_utils import filter_voxels_using_retardance from VolumeRaytraceLFM.nerf import setup_optimizer_nerf, predict_voxel_properties +from VolumeRaytraceLFM.utils.gradient_utils import monitor_gradients, clip_gradient_norms_nerf, print_grad_info from utils.logging import redirect_output_to_log, restore_output +from VolumeRaytraceLFM.volumes.compare import compare_volumes DEBUG = False PRINT_GRADIENTS = False @@ -128,7 +133,7 @@ def __init__( self.optical_info = optical_info self.retardance_image = self._to_numpy(ret_image) self.azimuth_image = self._to_numpy(azim_image) - radiometry_path = iteration_params.get("radiometry_path", None) + radiometry_path = iteration_params.get("file_paths", {}).get("radiometry", None) if radiometry_path: self.radiometry = tifffile.imread(radiometry_path) else: @@ -266,10 +271,13 @@ def __init__( omit_rays_based_on_pixels=True, apply_volume_mask=False, ): - """ - Initialize the Reconstructor with the provided parameters. + """Initialize the Reconstructor with the provided parameters. recon_info (class): containing reconstruction parameters + output_dir (str): directory to save the reconstruction results + device (str): device to run the reconstruction on + omit_rays_based_on_pixels (bool): whether to omit rays based on pixels with zero retardance + apply_volume_mask (bool): whether to apply a mask to the volume """ start_time = time.perf_counter() print(f"\nInitializing a Reconstructor, using computing device {device}") @@ -316,14 +324,14 @@ def __init__( if omit_rays_based_on_pixels: image_for_rays = self.ret_img_meas print("Omitting rays based on pixels with zero retardance.") - saved_ray_path = self.iteration_params.get("saved_ray_path", None) + saved_ray_path = self.iteration_params.get("file_paths", {}).get("saved_rays", None) self.rays = self.setup_raytracer( image=image_for_rays, filepath=saved_ray_path, device=device ) self.rays.verbose = False - self.nerf_mode = self.iteration_params.get("nerf_mode", False) + self.nerf_mode = self.iteration_params.get("nerf", {}).get("enabled", False) self.initialize_nerf_mode(use_nerf=self.nerf_mode) - self.from_simulation = self.iteration_params.get("from_simulation", False) + self.from_simulation = self.iteration_params.get("misc", {}).get("from_simulation", False) self.apply_volume_mask = apply_volume_mask self.mask = torch.ones( self.volume_initial_guess.Delta_n.shape[0], dtype=torch.bool, device=device @@ -332,7 +340,7 @@ def __init__( # Volume that will be updated after each iteration self.volume_pred = copy.deepcopy(self.volume_initial_guess) - self.remove_large_arrs = self.iteration_params.get( + self.remove_large_arrs = self.iteration_params.get("misc", {}).get( "free_memory_by_del_large_arrays", False ) if self.remove_large_arrs and self.apply_volume_mask: @@ -341,14 +349,14 @@ def __init__( "volume gradient at the same time." ) self.two_optic_axis_components = self.iteration_params.get( - "two_optic_axis_components", False - ) + "learnables", {}).get("two_optic_axis_components", True) - self.mla_rays_at_once = self.iteration_params.get("mla_rays_at_once", False) + self.mla_rays_at_once = self.iteration_params.get("misc", {}).get("mla_rays_at_once", False) if self.mla_rays_at_once and not self.rays.MLA_volume_geometry_ready: + print("Preparing rays for all rays at once...") self.rays.prepare_for_all_rays_at_once() if not self.from_simulation: - radiometry_path = self.iteration_params.get("radiometry_path", None) + radiometry_path = self.iteration_params.get("file_paths", {}).get("radiometry", None) if radiometry_path: num_rays_og = self.rays.ray_valid_indices_all.shape[1] radiometry = torch.tensor(recon_info.radiometry) @@ -366,8 +374,7 @@ def __init__( dict_save_dir = os.path.join(self.recon_directory, "config_parameters") if not os.path.exists(dict_save_dir): os.makedirs(dict_save_dir) - dict_filename = "vox_indices_by_mla_idx.pkl" - dict_save_path = os.path.join(dict_save_dir, dict_filename) + dict_save_path = os.path.join(dict_save_dir, "vox_indices_by_mla_idx.pkl") with open(dict_save_path, "wb") as f: pickle.dump(vox_indices_by_mla_idx, f) print(f"Saving voxel indices by MLA index to {dict_save_path}") @@ -377,7 +384,7 @@ def __init__( except AttributeError: self.voxel_mask_setup() - save_rays = self.iteration_params.get("save_rays", False) + save_rays = self.iteration_params.get("misc", {}).get("save_ray_geometry", False) # Ray saving should be done after self.rays.prepare_for_all_rays_at_once() if save_rays: rays_save_path = os.path.join( @@ -392,7 +399,7 @@ def __init__( pass gc.collect() - datafidelity_method = self.iteration_params.get("datafidelity", "vector") + datafidelity_method = self.iteration_params.get("misc", {}).get("datafidelity", "euler") first_word = datafidelity_method.split()[0] if first_word == "intensity": self.intensity_bool = True @@ -406,7 +413,7 @@ def __init__( self.loss_data_term_list = [] self.loss_reg_term_list = [] self.adjusted_lrs_list = [] - + self.volume_discrepancy_list = [] self.to_device(device) end_time = time.perf_counter() print(f"Reconstructor initialized in {end_time - start_time:.2f} seconds\n") @@ -454,23 +461,6 @@ def save_parameters(self, output_dir, volume_type): f"{output_dir}/parameters.pt", ) - @staticmethod - def replace_nans(volume, ep): - """Used in response to an error message.""" - with torch.no_grad(): - num_nan_vecs = torch.sum(torch.isnan(volume.optic_axis[0, :])) - if num_nan_vecs > 0: - replacement_vecs = torch.nn.functional.normalize( - torch.rand(3, int(num_nan_vecs)), p=2, dim=0 - ) - volume.optic_axis[:, torch.isnan(volume.optic_axis[0, :])] = ( - replacement_vecs - ) - if ep == 0: - print( - f"Replaced {num_nan_vecs} NaN optic axis vectors with random unit vectors." - ) - def setup_raytracer(self, image=None, filepath=None, device="cpu"): """Initialize Birefringent Raytracer.""" if filepath: @@ -491,7 +481,8 @@ def setup_raytracer(self, image=None, filepath=None, device="cpu"): return rays def initialize_nerf_mode(self, use_nerf=True): - self.rays.initialize_nerf_mode(use_nerf) + nerf_params_dict = self.iteration_params.get("nerf", {}).get("MLP", {}) + self.rays.initialize_nerf_mode(use_nerf, nerf_params_dict) def mask_outside_rays(self): """Mask out volume that is outside FOV of the microscope. @@ -536,29 +527,43 @@ def restrict_volume_to_reachable_region(self): def _turn_off_initial_volume_gradients(self): """Turn off the gradients for the initial volume guess.""" - self.volume_initial_guess.Delta_n.requires_grad = False - self.volume_initial_guess.optic_axis.requires_grad = False + self.volume_initial_guess.set_requires_grad(False) - def specify_variables_to_learn(self, learning_vars=None): - """ - Specify which variables of the initial volume object should be considered for learning. - This method updates the 'members_to_learn' attribute of the initial volume object, ensuring - no duplicates are added. + def _specify_variables_to_learn(self): + """Specify which variables of the initial volume object should be + considered for learning. This method updates the 'members_to_learn' + attribute of the initial volume object, ensuring no duplicates are added. The variable names must be attributes of the BirefringentVolume class. - Args: - learning_vars (list): Variable names to be appended for learning. - Defaults to ['Delta_n', 'optic_axis']. """ volume = self.volume_pred - if learning_vars is None: + learning_vars = [] + + # Determine learnable variables based on iteration params and volume properties + all_prop_elements = self.iteration_params.get("learnables", {}).get("all_prop_elements", False) + + if all_prop_elements: learning_vars = ["Delta_n", "optic_axis"] + else: + self.create_parameters_from_mask(volume, self.mask) + if self.two_optic_axis_components: + learning_vars = ["birefringence_active", "optic_axis_planar"] + else: + learning_vars = ["birefringence_active", "optic_axis_active"] + + # Handle large arrays if necessary + if self.remove_large_arrs: + del volume.Delta_n + del volume.optic_axis + gc.collect() + else: + self._detach_volume_params(volume) for var in learning_vars: if var not in volume.members_to_learn: volume.members_to_learn.append(var) def optimizer_setup(self, parameters, training_params): """Setup optimizer.""" - optimizer_type = training_params.get("optimizer", "Nadam") + optimizer_type = training_params.get("misc", {}).get("optimizer", "Nadam") optimizers = { "Adam": lambda params: torch.optim.Adam(params), "SGD": lambda params: torch.optim.SGD(params, nesterov=True, momentum=0.7), @@ -594,7 +599,7 @@ def voxel_mask_setup(self): filtered_voxels = filter_voxels_using_retardance( self.rays.vox_indices_ml_shifted_all, self.rays.ray_valid_indices_all, - self.ret_img_meas, + undo_transpose_and_flip(self.ret_img_meas) ) mask = torch.zeros(num_vox_in_volume, dtype=torch.bool) @@ -606,7 +611,7 @@ def voxel_mask_setup(self): print(f"Voxel mask created in {end_time - start_time:.2f} seconds") else: try: - vox_indices_path = self.iteration_params["vox_indices_by_mla_idx_path"] + vox_indices_path = self.iteration_params["file_paths"]["vox_indices_by_mla_idx"] if not vox_indices_path: raise ValueError("Vox indices path is empty.") start_time = time.perf_counter() @@ -709,10 +714,8 @@ def _compute_loss(self, images_predicted: list): ) # Compute regularization term - if isinstance(params["regularization_weight"], list): - params["regularization_weight"] = params["regularization_weight"][0] reg_loss, reg_term_values = LossFcn.compute_regularization_term(vol_pred) - regularization_term = params["regularization_weight"] * reg_loss + regularization_term = params["regularization"]["weight"] * reg_loss self.reg_term_values = [reg.item() for reg in reg_term_values] # Total loss (which has gradients enabled) @@ -743,6 +746,46 @@ def fill_optaxis_component(self, volume): ) return + def _assign_nerf_output_to_volume(self, volume): + """Method to assign the output of the NeRF model to the volume.""" + vol_shape = self.optical_info["volume_shape"] + predicted_properties = predict_voxel_properties( + self.rays.inr_model, vol_shape, enable_grad=False + ) + Delta_n = predicted_properties[..., 0] + volume.Delta_n = torch.nn.Parameter(Delta_n.flatten() * self.mask) + optic_axis_flat = predicted_properties.view( + -1, predicted_properties.shape[-1] + )[..., 1:] + if predicted_properties.shape[-1] == 3: + optic_axis_flat = spherical_to_unit_vector_torch(optic_axis_flat) + volume.optic_axis = torch.nn.Parameter(optic_axis_flat.permute(1, 0)) + return volume + + def _create_placeholder_volume_attributes(self, volume, grad=False): + """Create the Delta_n and optic_axis attributes for the volume. + This method is intended to be used when the large arrays are deleted. + """ + vol_shape = self.optical_info["volume_shape"] + device = volume.birefringence_active.device + volume.Delta_n = torch.nn.Parameter( + torch.zeros(vol_shape).flatten(), requires_grad=grad + ).to(device) + vol_size_flat = volume.Delta_n.size(0) + volume.optic_axis = torch.nn.Parameter( + torch.zeros(3, vol_size_flat), requires_grad=grad + ).to(device) + return volume + + def _assign_active_params_to_volume(self, volume): + """Assign the active parameters to the volume.""" + if volume.indices_active is None: + raise ValueError("Indices active is None") + with torch.no_grad(): + volume.Delta_n[volume.indices_active] = volume.birefringence_active + volume.optic_axis[:, volume.indices_active] = volume.optic_axis_active + return volume + # @profile # to see the memory breakdown of the function def one_iteration(self, volume_estimation, optimizers, schedulers): """Performs one iteration of the reconstruction process. @@ -770,22 +813,17 @@ def one_iteration(self, volume_estimation, optimizers, schedulers): # In case the entire volume is needed for the loss computation: total_vol_needed = False if total_vol_needed and self.volume_pred.indices_active is not None: - with torch.no_grad(): - self.volume_pred.Delta_n[self.volume_pred.indices_active] = ( - self.volume_pred.birefringence_active - ) - self.volume_pred.optic_axis[:, self.volume_pred.indices_active] = ( - self.volume_pred.optic_axis_active - ) + self._assign_active_params_to_volume(volume_estimation) + if self.nerf_mode: - # Update Delta_n before loss is computed so the the mask regularization is applied + # TODO: only update if regularization weight is nonzero + # Update Delta_n before loss is computed so regularization can be applied vol_shape = self.optical_info["volume_shape"] predicted_properties = predict_voxel_properties( self.rays.inr_model, vol_shape, enable_grad=True ) Delta_n = predicted_properties[..., 0] # # Gradients are lost when setting Delta_n as a torch nn parameter - # self.volume_pred.Delta_n = torch.nn.Parameter(Delta_n.flatten()) self.volume_pred.birefringence = Delta_n loss, data_term, regularization_term = self._compute_loss(img_list) @@ -795,13 +833,17 @@ def one_iteration(self, volume_estimation, optimizers, schedulers): # Verify the gradients before and after the backward pass if PRINT_GRADIENTS: print("\nBefore backward pass:") - self.print_grad_info(volume_estimation) + print_grad_info(volume_estimation) loss.backward() if PRINT_GRADIENTS: print("\nAfter backward pass:") - self.print_grad_info(volume_estimation) + print_grad_info(volume_estimation) + + if self.nerf_mode: + monitor_gradients(self.rays.inr_model) + clip_gradient_norms_nerf(self.rays.inr_model, self.ep, verbose=True) if CLIP_GRADIENT_NORM: self.clip_gradient_norms(volume_estimation) @@ -818,20 +860,21 @@ def one_iteration(self, volume_estimation, optimizers, schedulers): if scheduler is not None: step_scheduler(scheduler, loss) + # Keep the optic axis on the unit sphere + if self.two_optic_axis_components: + self.fill_optaxis_component(volume_estimation) + self.keep_optic_axis_on_sphere(volume_estimation) + if self.nerf_mode: adjusted_lrs = optimizer_nerf.param_groups[0]["lr"] + self._assign_nerf_output_to_volume(volume_estimation) else: adjusted_lrs = optimizer_opticaxis.param_groups[0]["lr"], optimizer_birefringence.param_groups[0]["lr"] if PRINT_GRADIENTS: print_moments(optimizer) - # Keep the optic axis on the unit sphere - if self.two_optic_axis_components: - self.fill_optaxis_component(volume_estimation) - self.keep_optic_axis_on_sphere(volume_estimation) - - if self.ep % 50 == 0 and False: + if self.ep % 50 == 0 and DEBUG: tqdm.write(f"Iteration {self.ep} first 5 values:") tqdm.write( f"birefringence: {volume_estimation.birefringence_active[:5].detach().cpu().numpy()}" @@ -880,29 +923,6 @@ def one_iteration(self, volume_estimation, optimizers, schedulers): ) return - def print_grad_info(self, volume_estimation): - if False: - print( - "Delta_n requires_grad:", - volume_estimation.Delta_n.requires_grad, - "birefringence_active requires_grad:", - volume_estimation.birefringence_active.requires_grad, - ) - if volume_estimation.Delta_n.grad is not None: - print( - "Gradient for Delta_n (up to 10 values):", - volume_estimation.Delta_n.grad[:10], - ) - else: - print("Gradient for Delta_n is None") - if volume_estimation.birefringence_active.grad is not None: - print( - "Gradient for birefringence_active (up to 10 values):", - volume_estimation.birefringence_active.grad[:10], - ) - else: - print("Gradient for birefringence_active is None") - def store_results( self, ret_image_current, @@ -921,41 +941,53 @@ def store_results( self.loss_reg_term_list.append(regularization_term.item()) self.adjusted_lrs_list.append(adjusted_lrs) + def _save_figure_as_pdf(self, ep, output_dir, in_progress=False): + """Saves the current figure as a PDF.""" + if in_progress: + filename = f"optim_iter_{'{:04d}'.format(ep)}.pdf" + plt.savefig(os.path.join(output_dir, "results_in_progress", filename)) + else: + plt.savefig(os.path.join(output_dir, "optimization.pdf")) + + def _save_volume_as_h5(self, volume, output_dir, ep, in_progress=False): + """Saves the volume as a h5 file.""" + desc = f"Volume estimation after {ep} iterations." + if in_progress: + volume_filename = f"volume_iter_{'{:04d}'.format(ep)}.h5" + volume.save_as_file(os.path.join(output_dir, "results_in_progress", volume_filename), description=desc) + else: + vol_save_path = os.path.join(output_dir, "volume.h5") + volume.save_as_file(vol_save_path, description=desc) + print("Saved the final volume estimation to", vol_save_path) + + def _save_nerf_model(self, rays, output_dir, ep, in_progress=False): + """Saves the NeRF model as a pth file.""" + if in_progress: + filename = f"nerf_model_iter_{'{:04d}'.format(ep)}.pth" + rays.save_nerf_model(os.path.join(output_dir, "results_in_progress", filename)) + else: + filepath = os.path.join(output_dir, "nerf_model.pth") + rays.save_nerf_model(filepath) + print("Saved the final NeRF model to", filepath) + def visualize_and_save(self, ep, fig, output_dir): + """Visualize and save the results of the reconstruction.""" + self.save_loss_lists_to_csv() + self._save_regularization_terms_to_csv(ep) + if self.volume_ground_truth is not None: + self._save_volume_discrepancy_to_csv(ep) + volume_estimation = self.volume_pred if self.remove_large_arrs: - vol_shape = self.optical_info["volume_shape"] - temp_bir = torch.zeros(vol_shape).flatten() - device = volume_estimation.birefringence_active.device - volume_estimation.Delta_n = torch.nn.Parameter( - temp_bir, requires_grad=False - ).to(device) - if self.volume_pred.indices_active is not None: - with torch.no_grad(): - volume_estimation.Delta_n[volume_estimation.indices_active] = ( - volume_estimation.birefringence_active - ) + self._create_placeholder_volume_attributes(volume_estimation, grad=False) + if self.two_optic_axis_components: + self._assign_active_params_to_volume(volume_estimation) - save_freq = self.iteration_params.get("save_freq", 5) - # TODO: only update every 1 iteration if plotting is live - if ep % 1 == 0: + save_freq = self.iteration_params["general"]["save_freq"] + plot_live = self.iteration_params.get("visualization", {}).get("plot_live", True) + if plot_live or ep % save_freq == 0: # plt.clf() - if self.nerf_mode: - vol_shape = self.optical_info["volume_shape"] - predicted_properties = predict_voxel_properties( - self.rays.inr_model, vol_shape - ) - Delta_n = predicted_properties[..., 0] - volume_estimation.Delta_n = torch.nn.Parameter(Delta_n.flatten()) - # TODO: see if mask should be applied here - volume_estimation.Delta_n = torch.nn.Parameter( - volume_estimation.Delta_n * self.mask - ) - Delta_n = volume_estimation.get_delta_n().detach().unsqueeze(0) - else: - Delta_n = volume_estimation.get_delta_n().detach().unsqueeze(0) - if False: - tqdm.write(f"{[round(val.item(), 4) for val in Delta_n[Delta_n != 0]]}") + Delta_n = volume_estimation.get_delta_n().detach().unsqueeze(0) vol_size_um = self.optical_info["voxel_size_um"] rel_scaling_factor = vol_size_um[0] / vol_size_um[2] mip_image = convert_volume_to_2d_mip( @@ -972,46 +1004,17 @@ def visualize_and_save(self, ep, fig, output_dir): self.loss_total_list, self.loss_data_term_list, self.loss_reg_term_list, + discrepancy_losses=self.volume_discrepancy_list, figure=fig, ) fig.canvas.draw() fig.canvas.flush_events() - time.sleep(0.1) - self.save_loss_lists_to_csv() - self._save_regularization_terms_to_csv(ep) - if ep % save_freq == 0: - filename = f"optim_iter_{'{:04d}'.format(ep)}.pdf" - plt.savefig(os.path.join(output_dir, filename)) - time.sleep(0.1) + time.sleep(0.001) if ep % save_freq == 0: - if self.remove_large_arrs: - vol_size_flat = volume_estimation.Delta_n.size(0) - device = volume_estimation.optic_axis_active.device - volume_estimation.optic_axis = torch.nn.Parameter( - torch.zeros(3, vol_size_flat), requires_grad=False - ).to(device) + self._save_figure_as_pdf(ep, output_dir, in_progress=True) + self._save_volume_as_h5(volume_estimation, output_dir, ep, in_progress=True) if self.nerf_mode: - optic_axis_flat = predicted_properties.view( - -1, predicted_properties.shape[-1] - )[..., 1:] - if predicted_properties.shape[-1] == 3: - optic_axis_flat = spherical_to_unit_vector_torch(optic_axis_flat) - volume_estimation.optic_axis = torch.nn.Parameter( - optic_axis_flat.permute(1, 0) - ) - nerf_model_path = os.path.join(output_dir, f"nerf_model_{ep}.pth") - self.rays.save_nerf_model(nerf_model_path) - else: - if self.volume_pred.indices_active is not None: - with torch.no_grad(): - volume_estimation.optic_axis[ - :, volume_estimation.indices_active - ] = volume_estimation.optic_axis_active - my_description = "Volume estimation after " + str(ep) + " iterations." - volume_estimation.save_as_file( - os.path.join(output_dir, f"volume_iter_{'{:04d}'.format(ep)}.h5"), - description=my_description, - ) + self._save_nerf_model(self.rays, output_dir, ep, in_progress=True) if self.remove_large_arrs: del volume_estimation.optic_axis gc.collect() @@ -1020,10 +1023,12 @@ def visualize_and_save(self, ep, fig, output_dir): gc.collect() return - def __visualize_and_update_streamlit( - self, progress_bar, ep, n_iterations, recon_img_plot, my_loss - ): + def __visualize_and_update_streamlit(self, streamlit_elements, ep, n_iterations): import pandas as pd + + progress_bar = streamlit_elements['progress_bar'] + recon_img_plot = streamlit_elements['my_recon_img_plot'] + my_loss = streamlit_elements['my_loss'] percent_complete = int(ep / n_iterations * 100) progress_bar.progress(percent_complete + 1) @@ -1083,11 +1088,11 @@ def _create_regularization_terms_csv(self): """Create a csv file to store the regularization terms.""" filename = "regularization_terms.csv" filepath = os.path.join(self.recon_directory, filename) - reg_fcns = self.iteration_params["regularization_fcns"] + reg_fcns = self.iteration_params["regularization"]["functions"] fcn_names = [sublist[0] for sublist in reg_fcns] with open(filepath, mode="w", newline="") as file: writer = csv.writer(file) - writer.writerow(["ep", *fcn_names]) + writer.writerow(["Iteration", *fcn_names]) def _save_regularization_terms_to_csv(self, ep): """Save the regularization terms to a csv file.""" @@ -1097,6 +1102,22 @@ def _save_regularization_terms_to_csv(self, ep): writer = csv.writer(file) writer.writerow([ep, *self.reg_term_values]) + def _save_volume_discrepancy_to_csv(self, iteration_num): + """Append the latest volume discrepancy value to a CSV file after each iteration.""" + filename = "volume_discrepancy.csv" + filepath = os.path.join(self.recon_directory, filename) + + # Check if the file exists (write header if it does not) + file_exists = os.path.isfile(filepath) + + with open(filepath, mode="a", newline="") as file: + writer = csv.writer(file) + if not file_exists: + writer.writerow(["Iteration", "Discrepancy"]) + # Write the latest value from self.volume_discrepancy_list + latest_discrepancy = self.volume_discrepancy_list[-1] + writer.writerow([iteration_num, latest_discrepancy]) + def clip_gradient_norms(self, model, verbose=False): # Gradient clipping max_norm = 1.0 @@ -1144,46 +1165,63 @@ def prepare_volume_for_recon(self, volume): check_for_inf_or_nan(volume.birefringence_active) check_for_inf_or_nan(volume.optic_axis_active) - def reconstruct( - self, - use_streamlit=False, - plot_live=False, - all_prop_elements=False, - ): - """Method to perform the actual reconstruction based on the provided parameters. - """ - log_file = True - log_file_path = os.path.join(self.recon_directory, "output_log.txt") + def update_ret_azim_when_missing(self): + """Update the ret_img_pred and azim_img_pred attributes when + they are not present, such as when using intensity boolean.""" + with torch.no_grad(): + [ret_image_current, azim_image_current] = ( + self.rays.ray_trace_through_volume(self.volume_pred) + ) + self.ret_img_pred = ret_image_current.detach().cpu().numpy() + self.azim_img_pred = azim_image_current.detach().cpu().numpy() + + def _detach_volume_params(self, volume): + """Detach and disable gradient for Delta_n and optic_axis.""" + volume.Delta_n.detach_() + volume.Delta_n.requires_grad = False + volume.optic_axis.detach_() + volume.optic_axis.requires_grad = False + + def _create_results_subdirectory(self): + """Create the results subdirectory if it does not exist.""" + results_directory = os.path.join(self.recon_directory, "results_in_progress") + if not os.path.exists(results_directory): + os.makedirs(results_directory) + return results_directory + + def _setup_logging(self): + log_file = self.iteration_params.get("misc", {}).get("save_to_logfile", True) if log_file: - # Redirect output to the log file if provided - log_file_handle = redirect_output_to_log(log_file_path) - print(f"Beginning reconstruction iterations...") - # Turn off the gradients for the initial volume guess - self._turn_off_initial_volume_gradients() + log_file_path = os.path.join(self.recon_directory, "output_log.txt") + return redirect_output_to_log(log_file_path) + return None - # Specify variables to learn - if all_prop_elements: - param_list = ["Delta_n", "optic_axis"] - else: - self.create_parameters_from_mask(self.volume_pred, self.mask) - if self.two_optic_axis_components: - param_list = ["birefringence_active", "optic_axis_planar"] - else: - param_list = ["birefringence_active", "optic_axis_active"] - if self.remove_large_arrs: - del self.volume_pred.Delta_n - del self.volume_pred.optic_axis - gc.collect() - else: - self.volume_pred.Delta_n.detach() - self.volume_pred.Delta_n.requires_grad = False - self.volume_pred.optic_axis.detach() - self.volume_pred.optic_axis.requires_grad = False - self.specify_variables_to_learn(param_list) + def _setup_streamlit(self, use_streamlit, n_iterations): + if use_streamlit: + import streamlit as st + st.write("Working on these ", n_iterations, "iterations...") + return { + 'my_recon_img_plot': st.empty(), + 'my_loss': st.empty(), + 'progress_bar': st.progress(0) + } + return None + + def reconstruct(self, use_streamlit=False): + """Method to perform the actual reconstruction based on the + provided parameters. + """ + log_file_handle = self._setup_logging() + print("Beginning reconstruction...") + self._create_results_subdirectory() + self._turn_off_initial_volume_gradients() + self._specify_variables_to_learn() + print("Setting up optimizer and scheduler...") if self.nerf_mode: - optimizer = setup_optimizer_nerf(self.rays, self.iteration_params) - scheduler_nerf_config = get_scheduler_configs_nerf(self.iteration_params) + nerf_params = self.iteration_params.get("nerf", {}) + optimizer = setup_optimizer_nerf(self.rays.inr_model, nerf_params) + scheduler_nerf_config = get_scheduler_configs_nerf(nerf_params) scheduler_nerf = create_scheduler(optimizer, scheduler_nerf_config) optimizer_opticaxis, optimizer_birefringence = None, None scheduler_opticaxis, scheduler_birefringence = None, None @@ -1191,17 +1229,18 @@ def reconstruct( else: training_params = self.iteration_params volume_estimation = self.volume_pred + self.prepare_volume_for_recon(volume_estimation) trainable_parameters = volume_estimation.get_trainable_variables() trainable_vars_names = volume_estimation.get_names_of_trainable_variables() parameters_optic_axis = [{ "params": trainable_parameters[0], - "lr": training_params["lr_optic_axis"], + "lr": training_params["learning_rates"]["optic_axis"], "name": trainable_vars_names[0], }] optimizer_opticaxis = self.optimizer_setup(parameters_optic_axis, training_params) parameters_birefringence = [{ "params": trainable_parameters[1], - "lr": training_params["lr_birefringence"], + "lr": training_params["learning_rates"]["birefringence"], "name": trainable_vars_names[1], }] optimizer_birefringence = self.optimizer_setup(parameters_birefringence, training_params) @@ -1215,32 +1254,24 @@ def reconstruct( initial_lr_0 = optimizer_opticaxis.param_groups[0]["lr"] initial_lr_1 = optimizer_birefringence.param_groups[0]["lr"] + plot_live = self.iteration_params.get("visualization", {}).get("plot_live", True) + fig_size = self.iteration_params.get("visualization", {}).get("fig_size", (10, 11)) figure = setup_visualization( - window_title=self.recon_directory, plot_live=plot_live + window_title=self.recon_directory, plot_live=plot_live, fig_size=fig_size ) self._create_regularization_terms_csv() - n_iterations = self.iteration_params["num_iterations"] - if use_streamlit: - import streamlit as st - - st.write("Working on these ", n_iterations, "iterations...") - my_recon_img_plot = st.empty() - my_loss = st.empty() - my_plot = st.empty() # set up a place holder for the plot - my_3D_plot = st.empty() # set up a place holder for the 3D plot - progress_bar = st.progress(0) - - self.prepare_volume_for_recon(self.volume_pred) + n_iterations = self.iteration_params["general"]["num_iterations"] + streamlit_elements = self._setup_streamlit(use_streamlit, n_iterations) # Parameters for learning rate warmup - warmup_iterations = 10 - warmup_start_proportion = 0.1 + warmup_iterations = self.iteration_params.get("misc", {}).get("warmup_iterations", 10) + warmup_start_proportion = 1 / warmup_iterations + print("Starting iterations...") # Iterations for ep in tqdm(range(1, n_iterations + 1), "Minimizing"): self.ep = ep - # Learning rate warmup if ep <= warmup_iterations: warmup_factor = warmup_start_proportion + (1 - warmup_start_proportion) * (ep / warmup_iterations) lr_0 = initial_lr_0 * warmup_factor @@ -1264,58 +1295,38 @@ def reconstruct( optimizers = (optimizer, optimizer_opticaxis, optimizer_birefringence) schedulers = (scheduler_nerf, scheduler_birefringence, scheduler_opticaxis) self.one_iteration(self.volume_pred, optimizers, schedulers) + + if self.volume_ground_truth is not None: + volume_discrepancy = compare_volumes(self.volume_pred, self.volume_ground_truth) + self.volume_discrepancy_list.append(volume_discrepancy.item()) if ep == 1 and PRINT_TIMING_INFO: self.rays.print_timing_info() if ep % 20 == 0 and self.intensity_bool: - with torch.no_grad(): - [ret_image_current, azim_image_current] = ( - self.rays.ray_trace_through_volume(self.volume_pred) - ) - self.ret_img_pred = ret_image_current.detach().cpu().numpy() - self.azim_img_pred = azim_image_current.detach().cpu().numpy() + self.update_ret_azim_when_missing() sys.stdout.flush() azim_damp_mask = self._to_numpy(self.ret_img_meas / self.ret_img_meas.max()) self.azim_img_pred[azim_damp_mask == 0] = 0 if use_streamlit: self.__visualize_and_update_streamlit( - progress_bar, ep, n_iterations, my_recon_img_plot, my_loss + streamlit_elements, ep, n_iterations ) self.visualize_and_save(ep, figure, self.recon_directory) + self._save_figure_as_pdf(ep, self.recon_directory, in_progress=False) + plt.close() + self.save_loss_lists_to_csv() - if self.remove_large_arrs: - vol_shape = self.optical_info["volume_shape"] - temp_bir = torch.zeros(vol_shape).flatten() - device = self.volume_pred.birefringence_active.device - self.volume_pred.Delta_n = torch.nn.Parameter( - temp_bir, requires_grad=False - ).to(device) - self.volume_pred.Delta_n[self.volume_pred.indices_active] = ( - self.volume_pred.birefringence_active - ) - vol_size_flat = self.volume_pred.Delta_n.size(0) - self.volume_pred.optic_axis = torch.nn.Parameter( - torch.zeros(3, vol_size_flat), requires_grad=False - ).to(device) - self.volume_pred.optic_axis[:, self.volume_pred.indices_active] = ( - self.volume_pred.optic_axis_active - ) - my_description = "Volume estimation after " + str(ep) + " iterations." - vol_save_path = os.path.join( - self.recon_directory, f"volume_iter_{'{:04d}'.format(ep)}.h5" - ) - self.volume_pred.save_as_file(vol_save_path, description=my_description) - print("Saved the final volume estimation to", vol_save_path) - plt.savefig(os.path.join(self.recon_directory, "optim_final.pdf")) - plt.close() + if self.remove_large_arrs: + self._create_placeholder_volume_attributes(self.volume_pred, grad=False) + self._assign_active_params_to_volume(self.volume_pred) + self._save_volume_as_h5(self.volume_pred, self.recon_directory, ep, in_progress=False) if self.nerf_mode: - nerf_model_path = os.path.join(self.recon_directory, "nerf_model.pth") - self.rays.save_nerf_model(nerf_model_path) + self._save_nerf_model(self.rays, self.recon_directory, ep, in_progress=False) - if log_file: - # Restore the standard output + if log_file_handle: restore_output(log_file_handle) + print("Reconstruction complete.") diff --git a/src/VolumeRaytraceLFM/setup_parameters.py b/src/VolumeRaytraceLFM/setup_parameters.py index 69a8ac4..b4af8aa 100644 --- a/src/VolumeRaytraceLFM/setup_parameters.py +++ b/src/VolumeRaytraceLFM/setup_parameters.py @@ -19,10 +19,21 @@ def setup_iteration_parameters(config_file=None): iteration_params = json.load(f) else: iteration_params = { - "num_iterations": 201, - "azimuth_weight": 0.5, - "regularization_weight": 0.1, - "lr": 1e-3, - "output_posfix": "", + "general": { + "num_iterations": 1000, + "save_freq": 100 + }, + + "learning_rates": { + "birefringence": 1e-4, + "optic_axis": 1e-1 + }, + + "regularization": { + "weight": 1.0, + "functions": [ + ["birefringence active L2", 100] + ] + }, } return iteration_params diff --git a/src/VolumeRaytraceLFM/simulations.py b/src/VolumeRaytraceLFM/simulations.py index cb805b1..b4cb94e 100644 --- a/src/VolumeRaytraceLFM/simulations.py +++ b/src/VolumeRaytraceLFM/simulations.py @@ -138,7 +138,7 @@ def view_images(self, azimuth_plot_type="hsv"): ret_image, azim_image, azimuth_plot_type, include_labels=True ) # my_fig.tight_layout() - plt.pause(0.2) + plt.pause(0.001) plt.show(block=True) def view_intensity_images(self): @@ -147,7 +147,7 @@ def view_intensity_images(self): self.img_list[i] = self.convert_to_numpy(self.img_list[i]) my_fig = plot_intensity_images(self.img_list) my_fig.tight_layout() - plt.pause(0.2) + plt.pause(0.001) plt.show(block=True) def save_ret_azim_images(self): diff --git a/src/VolumeRaytraceLFM/utils/error_handling.py b/src/VolumeRaytraceLFM/utils/error_handling.py index c60113b..09e3c4b 100644 --- a/src/VolumeRaytraceLFM/utils/error_handling.py +++ b/src/VolumeRaytraceLFM/utils/error_handling.py @@ -57,3 +57,19 @@ def check_for_negative_values_dict(my_dict): raise ValueError("The dictionary contains negative values.") else: print("All entries are nonnegative.") + + +def replace_nans_in_optic_axis(volume): + """Used in response to an error message.""" + with torch.no_grad(): + num_nan_vecs = torch.sum(torch.isnan(volume.optic_axis[0, :])) + if num_nan_vecs > 0: + replacement_vecs = torch.nn.functional.normalize( + torch.rand(3, int(num_nan_vecs)), p=2, dim=0 + ) + volume.optic_axis[:, torch.isnan(volume.optic_axis[0, :])] = ( + replacement_vecs + ) + print( + f"Replaced {num_nan_vecs} NaN optic axis vectors with random unit vectors." + ) diff --git a/src/VolumeRaytraceLFM/utils/gradient_utils.py b/src/VolumeRaytraceLFM/utils/gradient_utils.py new file mode 100644 index 0000000..4e73d59 --- /dev/null +++ b/src/VolumeRaytraceLFM/utils/gradient_utils.py @@ -0,0 +1,68 @@ +"""Functions for monitoring the gradients of the neural network layers.""" +import torch + + +def monitor_gradients(model): + print("Monitoring layer gradients:") + + if isinstance(model, torch.nn.DataParallel): + model = model.module + + # First hidden layer (fc1) + if model.layers[0].weight.grad is not None: + print(f"Layer 1 (fc1) weight gradient norm: {model.layers[0].weight.grad.norm(2).item():.4f}") + if model.layers[0].bias.grad is not None: + print(f"Layer 1 (fc1) bias gradient norm: {model.layers[0].bias.grad.norm(2).item():.4f}") + + # Second hidden layer (fc2) + if model.layers[2].weight.grad is not None: + print(f"Layer 2 (fc2) weight gradient norm: {model.layers[2].weight.grad.norm(2).item():.4f}") + if model.layers[2].bias.grad is not None: + print(f"Layer 2 (fc2) bias gradient norm: {model.layers[2].bias.grad.norm(2).item():.4f}") + + # Third hidden layer (fc3) + if model.layers[4].weight.grad is not None: + print(f"Layer 3 (fc3) weight gradient norm: {model.layers[4].weight.grad.norm(2).item():.4f}") + if model.layers[4].bias.grad is not None: + print(f"Layer 3 (fc3) bias gradient norm: {model.layers[4].bias.grad.norm(2).item():.4f}") + + # Output layer + if model.layers[-1].weight.grad is not None: + print(f"Output layer weight gradient norm: {model.layers[-1].weight.grad.norm(2).item():.4f}") + if model.layers[-1].bias.grad is not None: + print(f"Output layer bias gradient norm: {model.layers[-1].bias.grad.norm(2).item():.4f}") + + +def clip_gradient_norms_nerf(model, iteration_num, verbose=False): + # Gradient clipping + max_norm = 1.0 + total_norm = torch.nn.utils.clip_grad_norm_(model.module.parameters(), max_norm=max_norm) + + if verbose: + print(f"Iteration {iteration_num}: Total gradient norm: {total_norm:.2f}") + if total_norm > max_norm: + print(f"Iteration {iteration_num}: Gradients clipped to norm {max_norm}") + + +def print_grad_info(volume_estimation): + if False: + print( + "Delta_n requires_grad:", + volume_estimation.Delta_n.requires_grad, + "birefringence_active requires_grad:", + volume_estimation.birefringence_active.requires_grad, + ) + if volume_estimation.Delta_n.grad is not None: + print( + "Gradient for Delta_n (up to 10 values):", + volume_estimation.Delta_n.grad[:10], + ) + else: + print("Gradient for Delta_n is None") + if volume_estimation.birefringence_active.grad is not None: + print( + "Gradient for birefringence_active (up to 10 values):", + volume_estimation.birefringence_active.grad[:10], + ) + else: + print("Gradient for birefringence_active is None") diff --git a/src/VolumeRaytraceLFM/utils/optimizer_utils.py b/src/VolumeRaytraceLFM/utils/optimizer_utils.py index 1554a58..512bcb3 100644 --- a/src/VolumeRaytraceLFM/utils/optimizer_utils.py +++ b/src/VolumeRaytraceLFM/utils/optimizer_utils.py @@ -105,14 +105,13 @@ def get_scheduler_configs(iteration_params): "eps": 1e-8 } } - scheduler_opticaxis_config = schedulers.get("opticaxis", default_sched_config) + scheduler_opticaxis_config = schedulers.get("optic_axis", default_sched_config) scheduler_birefringence_config = schedulers.get("birefringence", default_sched_config) return scheduler_opticaxis_config, scheduler_birefringence_config -def get_scheduler_configs_nerf(iteration_params): +def get_scheduler_configs_nerf(nerf_params): """Get the schedulers for the optimizer.""" - schedulers = iteration_params.get("schedulers", {}) default_sched_config = { "type": "ReduceLROnPlateau", "params": { @@ -124,7 +123,7 @@ def get_scheduler_configs_nerf(iteration_params): "eps": 1e-8 } } - scheduler_nerf_config = schedulers.get("nerf", default_sched_config) + scheduler_nerf_config = nerf_params.get("scheduler", default_sched_config) return scheduler_nerf_config diff --git a/src/VolumeRaytraceLFM/utils/orientation_utils.py b/src/VolumeRaytraceLFM/utils/orientation_utils.py new file mode 100644 index 0000000..06c7c50 --- /dev/null +++ b/src/VolumeRaytraceLFM/utils/orientation_utils.py @@ -0,0 +1,22 @@ +import numpy as np +import torch + + +def transpose_and_flip(data): + if isinstance(data, np.ndarray): + result = np.flip(data.T, axis=1).copy() + elif isinstance(data, torch.Tensor): + result = torch.flip(data.transpose(0, 1), dims=[1]).clone() + else: + raise TypeError("Input must be either a NumPy array or PyTorch tensor.") + return result + + +def undo_transpose_and_flip(data): + if isinstance(data, np.ndarray): + result = np.flip(data.T, axis=1).copy() + elif isinstance(data, torch.Tensor): + result = torch.flip(data.transpose(0, 1), dims=[1]).clone() + else: + raise TypeError("Input must be either a NumPy array or PyTorch tensor.") + return result diff --git a/src/VolumeRaytraceLFM/visualization/plot_ellipsoid.py b/src/VolumeRaytraceLFM/visualization/plot_ellipsoid.py deleted file mode 100644 index c870047..0000000 --- a/src/VolumeRaytraceLFM/visualization/plot_ellipsoid.py +++ /dev/null @@ -1,115 +0,0 @@ -import numpy as np -import plotly.graph_objects as go - - -def generate_ellipsoid_volume( - volume_shape, center=[0.5, 0.5, 0.5], radius=[10, 10, 10], alpha=0.1, delta_n=0.1 -): - """generate_ellipsoid_volume: Creates an ellipsoid with optical axis normal to the ellipsoid surface. - Args: - Center [3]: [cz,cy,cx] from 0 to 1 where 0.5 is the center of the volume_shape. - radius [3]: in voxels, the radius in z,y,x for this ellipsoid. - alpha float: Border thickness. - delta_n float: Delta_n value of birefringence in the volume - """ - # Grabbed from https://math.stackexchange.com/questions/2931909/normal-of-a-point-on-the-surface-of-an-ellipsoid - vol = np.zeros( - [ - 4, - ] - + volume_shape - ) - - kk, jj, ii = np.meshgrid( - np.arange(volume_shape[0]), - np.arange(volume_shape[1]), - np.arange(volume_shape[2]), - indexing="ij", - ) - # shift to center - kk = np.floor(center[0] * volume_shape[0]) - kk.astype(float) - jj = np.floor(center[1] * volume_shape[1]) - jj.astype(float) - ii = np.floor(center[2] * volume_shape[2]) - ii.astype(float) - - # DEBUG: checking the indicies - # np.argwhere(ellipsoid_border == np.min(ellipsoid_border)) - # plt.imshow(ellipsoid_border_mask[int(volume_shape[0] / 2),:,:]) - ellipsoid_border = ( - (kk**2) / (radius[0] ** 2) - + (jj**2) / (radius[1] ** 2) - + (ii**2) / (radius[2] ** 2) - ) - ellipsoid_border_mask = np.abs(ellipsoid_border - alpha) <= 1 - vol[0, ...] = ellipsoid_border_mask.astype(float) - # Compute normals - kk_normal = 2 * kk / radius[0] - jj_normal = 2 * jj / radius[1] - ii_normal = 2 * ii / radius[2] - norm_factor = np.sqrt(kk_normal**2 + jj_normal**2 + ii_normal**2) - # Avoid division by zero - norm_factor[norm_factor == 0] = 1 - vol[1, ...] = (kk_normal / norm_factor) * vol[0, ...] - vol[2, ...] = (jj_normal / norm_factor) * vol[0, ...] - vol[3, ...] = (ii_normal / norm_factor) * vol[0, ...] - vol[0, ...] *= delta_n - # vol = vol.permute(0,2,1,3) - - inner_alpha = 0.5 - outer_mask = np.abs(ellipsoid_border - alpha) <= 1 - inner_mask = ellipsoid_border < inner_alpha - - # Hollowing out the ellipsoid - combined_mask = np.logical_and(outer_mask, ~inner_mask) - - vol[0, ...] = combined_mask.astype(float) - return vol - - -def plot_ellipsoid(vol): - """Plots the ellipsoid using Plotly. - Args: - - vol (numpy array): The output from generate_ellipsoid_volume. - """ - - # Extract the ellipsoid's border mask (which is in the first channel of vol) - ellipsoid_mask = vol[0, ...] > 0 - - # Extract the x, y, z coordinates of the surface voxels - z, y, x = np.where(ellipsoid_mask) - - # Create a scatter plot of the surface voxels - scatter = go.Scatter3d(x=x, y=y, z=z, mode="markers", marker=dict(size=2)) - - # Plot - fig = go.Figure(data=[scatter]) - fig.show() - - -# def plot_ellipsoid(vol): -# """ Plots the ellipsoid using Plotly. -# Args: -# - vol (numpy array): The output from generate_ellipsoid_volume. -# """ - -# fig = go.Figure(data=go.Volume( -# x=vol[3, ...].flatten(), -# y=vol[2, ...].flatten(), -# z=vol[1, ...].flatten(), -# value=vol[0, ...].flatten(), -# isomin=0.1, -# isomax=0.8, -# opacity=0.1, # adjust this for visualization clarity -# surface_count=17, # adjust this based on preference -# colorscale='Viridis' -# )) - -# fig.show() - - -# Example usage -volume_shape = [50, 50, 50] -myshape = [15, 51, 51] -radius = [5.5, 5.5, 3.5] -myradius = [5.5, 9.5, 5.5] -vol = generate_ellipsoid_volume(myshape, radius=myradius, center=[0.5, 0.5, 0.5]) -plot_ellipsoid(vol) diff --git a/src/VolumeRaytraceLFM/visualization/plotting_iterations.py b/src/VolumeRaytraceLFM/visualization/plotting_iterations.py index 0a5b701..db803e9 100644 --- a/src/VolumeRaytraceLFM/visualization/plotting_iterations.py +++ b/src/VolumeRaytraceLFM/visualization/plotting_iterations.py @@ -1,6 +1,8 @@ import torch +import numpy as np import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec +import matplotlib.ticker as ticker def plot_iteration_update( @@ -17,7 +19,7 @@ def plot_iteration_update( ): if streamlit_purpose: fig = plt.figure(figsize=(18, 9)) - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" # Plot measurements plt.subplot(2, 4, 1) @@ -86,7 +88,7 @@ def plot_est_iteration_update( ): if streamlit_purpose: fig = plt.figure(figsize=(18, 9)) - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" # Plot predictions plt.subplot(2, 4, 5) @@ -144,6 +146,8 @@ def plot_image_subplot(ax, image, title, cmap="plasma"): ax.set_title(title, fontsize=8) ax.axis("off") # Hide the axis for a cleaner look ax.xaxis.set_visible(False) # Hide the x-axis if not needed + if title == "Orientation": + im.set_clim(0, np.pi) def plot_combined_loss_subplot( @@ -165,15 +169,19 @@ def plot_combined_loss_subplot( ax.set_xlabel("iteration") ax.set_ylabel("loss") ax.legend(loc="upper right") + ax.grid(True) # Set y-axis limit to zoom in on the lower range of loss values if max_y_limit is not None: ax.set_ylim([0, max_y_limit]) + # Use scientific notation for the y-axis + ax.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) + def calculate_dynamic_max_y_limit(losses, window_size=10, scale_factor=1.1): - """ - Calculate the dynamic max_y_limit based on the recent loss values. + """Calculate the dynamic max_y_limit based on the recent loss values. Args: losses (list or np.array): List of total loss values. @@ -192,6 +200,33 @@ def calculate_dynamic_max_y_limit(losses, window_size=10, scale_factor=1.1): return max_recent_loss * scale_factor +def plot_discrepancy_loss_subplot(ax, discrepancy_losses, max_y_limit=None): + """Plots discrepancy losses on a separate subplot.""" + iterations = list(range(len(discrepancy_losses))) + ax.plot(iterations, discrepancy_losses, label="discrepancy", color='purple', linestyle="-") + ax.set_xlim(left=0) + ax.set_xlabel("iteration") + ax.set_ylabel("discrepancy") + ax.yaxis.set_label_position('right') + ax.yaxis.tick_right() + ax.grid(True) + + # Set dynamic y-axis limit + min_discrepancy = min(discrepancy_losses) + max_discrepancy = max(discrepancy_losses) + y_min = min(min_discrepancy * 0.6, max_discrepancy * 0.5) + + # Set y-axis limit to zoom in on the lower range of loss values + if max_y_limit is not None: + ax.set_ylim([0, max_y_limit]) + else: + ax.set_ylim([y_min, max_discrepancy * 1.05]) + + # Use scientific notation for the y-axis + ax.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) + + def plot_iteration_update_gridspec( vol_meas, ret_meas, @@ -202,11 +237,12 @@ def plot_iteration_update_gridspec( losses, data_term_losses, regularization_term_losses, + discrepancy_losses=None, figure=None, streamlit_purpose=False, ): """Plots measured and predicted volumes, retardance, orientation, - and combined losses using GridSpec for layout. + and combined losses using GridSpec for layout. Optionally plots discrepancy loss in a separate subplot. """ # If a figure is provided, use it; otherwise, use the current figure if figure is not None: @@ -215,42 +251,58 @@ def plot_iteration_update_gridspec( fig = plt.gcf() # Get the current figure # Clear the current figure to ensure we're not plotting over old data fig.clf() - # Create GridSpec layout - gs = gridspec.GridSpec(3, 3, figure=fig, hspace=0.2, wspace=0.2) + + # Adjust GridSpec layout: Add extra rows for subheaders + nrows = 10 if discrepancy_losses is not None else 7 + # Define the height ratios for the rows: smaller for the text rows, larger for the plot rows + height_ratios = [0.2, 1, 0.2, 1, 0.05, 0.1, 1] + if discrepancy_losses is not None: + height_ratios.append(0.15) + height_ratios.append(0.1) + height_ratios.append(1) + + # Create GridSpec layout with custom height ratios + gs = gridspec.GridSpec(nrows, 3, figure=fig, hspace=0.2, wspace=0.2, height_ratios=height_ratios) + titles = ["Birefringence (MIP)", "Retardance", "Orientation"] cmaps = ["plasma", "plasma", "twilight"] - # Plot measured data and predictions - for i, (meas, pred, title, cmap) in enumerate( - zip( - [vol_meas, ret_meas, azim_meas], - [vol_current, ret_current, azim_current], - titles, - cmaps, - ) - ): - ax_meas = fig.add_subplot(gs[0, i]) + text_params = { + "ha": "center", + "va": "center", + "fontsize": 10, + "weight": "bold" + } + + # Plot the 'Measurements' header across all three columns + ax_measurements_header = fig.add_subplot(gs[0, :]) + ax_measurements_header.text(0.5, 0.5, "Measurements", **text_params) + ax_measurements_header.axis("off") + + # Plot measured data + for i, (meas, title, cmap) in enumerate(zip([vol_meas, ret_meas, azim_meas], titles, cmaps)): + ax_meas = fig.add_subplot(gs[1, i]) plot_image_subplot(ax_meas, meas, f"{title}", cmap=cmap) - ax_pred = fig.add_subplot(gs[1, i]) + # Plot the 'Predictions' header across all three columns + ax_predictions_header = fig.add_subplot(gs[2, :]) + ax_predictions_header.text(0.5, 0.5, "Predictions", **text_params) + ax_predictions_header.axis("off") + + # Plot predicted data + for i, (pred, title, cmap) in enumerate(zip([vol_current, ret_current, azim_current], titles, cmaps)): + ax_pred = fig.add_subplot(gs[3, i]) plot_image_subplot(ax_pred, pred, f"{title}", cmap=cmap) - # Add row titles - fig.text( - 0.5, 0.96, "Measurements", ha="center", va="center", fontsize=10, weight="bold" - ) - fig.text( - 0.5, 0.645, "Predictions", ha="center", va="center", fontsize=10, weight="bold" - ) - fig.text( - 0.5, 0.33, "Loss Function", ha="center", va="center", fontsize=10, weight="bold" - ) + + # Plot the 'Loss Function' header + ax_loss_header = fig.add_subplot(gs[5, :]) + ax_loss_header.text(0.5, 0.5, "Loss Function", **text_params) + ax_loss_header.axis("off") # Calculate dynamic max_y_limit based on recent loss values - max_y_limit = calculate_dynamic_max_y_limit( - losses, window_size=50, scale_factor=1.1 - ) + max_y_limit = calculate_dynamic_max_y_limit(losses, window_size=50, scale_factor=1.1) - # Plot combined losses across the entire bottom row - ax_combined = fig.add_subplot(gs[2, :]) + # Plot combined losses across the entire row + ax_combined = fig.add_subplot(gs[6, :]) plot_combined_loss_subplot( ax_combined, losses, @@ -258,9 +310,25 @@ def plot_iteration_update_gridspec( regularization_term_losses, max_y_limit=max_y_limit, ) - # Adjust layout to prevent overlap, leave space for row titles - plt.subplots_adjust(left=0.05, right=0.91, bottom=0.07, top=0.92) - # Return the figure object if in Streamlit, else show the plot + + # If discrepancy losses are provided, create a new subplot for them + if discrepancy_losses is not None and len(discrepancy_losses) > 0: + ax_discrepancy_header = fig.add_subplot(gs[8, :]) + ax_discrepancy_header.text(0.5, 0.5, "Discrepancy from Ground Truth", **text_params) + ax_discrepancy_header.axis("off") + ax_discrepancy = fig.add_subplot(gs[9, :]) # New subplot in the final row + # max_y_limit_discrepancy = calculate_dynamic_max_y_limit( + # discrepancy_losses, window_size=500, scale_factor=1.1 + # ) + max_y_limit_discrepancy = None + plot_discrepancy_loss_subplot( + ax_discrepancy, + discrepancy_losses, + max_y_limit=max_y_limit_discrepancy, + ) + + plt.subplots_adjust(left=0.05, right=0.91, bottom=0.07, top=0.98) + if streamlit_purpose: return fig else: diff --git a/src/VolumeRaytraceLFM/visualization/plotting_ret_azim.py b/src/VolumeRaytraceLFM/visualization/plotting_ret_azim.py index 4ad4c89..027d891 100644 --- a/src/VolumeRaytraceLFM/visualization/plotting_ret_azim.py +++ b/src/VolumeRaytraceLFM/visualization/plotting_ret_azim.py @@ -8,7 +8,7 @@ def plot_birefringence_lines( retardance_img, azimuth_img, - origin="lower", + origin="upper", upscale=1, cmap="Wistia_r", line_color="blue", @@ -32,12 +32,12 @@ def plot_birefringence_lines( lc_data = [[(l_ii[ix], l_jj[ix]), (h_ii[ix], h_jj[ix])] for ix in range(len(l_ii))] colors = retardance_img.flatten() cmap = matplotlib.cm.get_cmap(cmap) - rgba = cmap(colors / (2 * np.pi)) + rgba = cmap(colors / np.pi) lc = matplotlib.collections.LineCollection(lc_data, colors=line_color, linewidths=1) if ax is None: fig, ax = plt.subplots() - im = ax.imshow(retardance_img, origin="lower", cmap=cmap) + im = ax.imshow(retardance_img, origin="upper", cmap=cmap) ax.add_collection(lc) ax.autoscale() ax.margins(0.1) @@ -177,7 +177,7 @@ def plot_retardance_orientation( ): plt.ioff() # Prevents plots from popping up fig = plt.figure(figsize=(12, 3)) - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" # Retardance subplot plt.subplot(1, 3, 1) plt.imshow(ret_image, cmap="plasma") # viridis @@ -194,6 +194,7 @@ def plot_retardance_orientation( plt.title("Orientation") plt.xticks([]) plt.yticks([]) + plt.clim(0, np.pi) # Combined retardance and orientation subplot ax = plt.subplot(1, 3, 3) if azimuth_plot_type == "lines": diff --git a/src/VolumeRaytraceLFM/visualization/plotting_volume.py b/src/VolumeRaytraceLFM/visualization/plotting_volume.py index a3ef2e9..921319e 100644 --- a/src/VolumeRaytraceLFM/visualization/plotting_volume.py +++ b/src/VolumeRaytraceLFM/visualization/plotting_volume.py @@ -91,6 +91,7 @@ def convert_volume_to_2d_mip( normalize=False, border_thickness=1, add_view_separation_lines=True, + transpose_and_flip=True ): """ Convert a 3D volume to a single 2D Maximum Intensity Projection (MIP) image. @@ -165,6 +166,14 @@ def convert_volume_to_2d_mip( out_img[:, :, :, scaled_vol_size[1] : scaled_vol_size[1] + border_thickness] = ( line_color ) + + # Transpose and flip the MIP image if specified + if transpose_and_flip: + # Transpose the image (swap axes 2 and 3) + out_img = torch.transpose(out_img, 2, 3) + # Flip along the 3rd axis (axis=2 after transpose) + out_img = torch.flip(out_img, dims=[2]) + return out_img diff --git a/src/VolumeRaytraceLFM/visualization/plt_util.py b/src/VolumeRaytraceLFM/visualization/plt_util.py index 626aebf..00def8e 100644 --- a/src/VolumeRaytraceLFM/visualization/plt_util.py +++ b/src/VolumeRaytraceLFM/visualization/plt_util.py @@ -1,14 +1,13 @@ import matplotlib.pyplot as plt -def setup_visualization(window_title, plot_live=True): +def setup_visualization(window_title, plot_live=True, fig_size=(10, 9)): if plot_live: plt.ion() else: plt.ioff() - fig_size = (10, 9) figure = plt.figure(figsize=fig_size) - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" manager = plt.get_current_fig_manager() manager.set_window_title(window_title) if False: diff --git a/src/VolumeRaytraceLFM/volumes/compare.py b/src/VolumeRaytraceLFM/volumes/compare.py index e69de29..42e0b09 100644 --- a/src/VolumeRaytraceLFM/volumes/compare.py +++ b/src/VolumeRaytraceLFM/volumes/compare.py @@ -0,0 +1,41 @@ +import torch +import torch.nn.functional as F +from VolumeRaytraceLFM.birefringence_implementations import BirefringentVolume + + +def compare_volumes(volume1: BirefringentVolume, volume2: BirefringentVolume, mask=None, only_nonzero=False): + if mask is not None: + Delta_n1 = volume1.get_delta_n().detach() + Delta_n2 = volume2.get_delta_n().detach() + Delta_n1 = Delta_n1[mask] + Delta_n2 = Delta_n2[mask] + else: + Delta_n1 = volume1.get_delta_n().detach() + Delta_n2 = volume2.get_delta_n().detach() + optic_axis1 = volume1.get_optic_axis().detach() + optic_axis2 = volume2.get_optic_axis().detach() + + if only_nonzero: + non_zero_mask = (Delta_n1 != 0) | (Delta_n2 != 0) + Delta_n1 = Delta_n1[non_zero_mask] + Delta_n2 = Delta_n2[non_zero_mask] + optic_axis1 = optic_axis1[:, non_zero_mask] + optic_axis2 = optic_axis2[:, non_zero_mask] + + # Combine Delta_n and optic_axis into a single tensor for both volumes + # Stack Delta_n and optic_axis such that Delta_n corresponds to index [0, ...] + predicted_volume = torch.cat((Delta_n1.unsqueeze(0), optic_axis1), dim=0) # (4, H, W, D) + target_volume = torch.cat((Delta_n2.unsqueeze(0), optic_axis2), dim=0) # (4, H, W, D) + + # Compute the loss + loss = mse_sum(predicted_volume, target_volume) + + return loss + + +def mse_sum(predicted_volume, target_volume): + """Compute the birefringence field loss""" + vector_field1 = predicted_volume[0, ...] * predicted_volume[1:, ...] + vector_field2 = target_volume[0, ...] * target_volume[1:, ...] + loss = F.mse_loss(vector_field1, vector_field2, reduction='sum') + return loss diff --git a/src/VolumeRaytraceLFM/volumes/optic_axis.py b/src/VolumeRaytraceLFM/volumes/optic_axis.py index 159ded8..f6c9c0b 100644 --- a/src/VolumeRaytraceLFM/volumes/optic_axis.py +++ b/src/VolumeRaytraceLFM/volumes/optic_axis.py @@ -101,5 +101,5 @@ def unit_vector_to_spherical(vector): """ z, y, x = vector phi = np.arccos(z) - theta = np.arctan2(y, x) + theta = np.atan2(y, x) return theta, phi diff --git a/src/VolumeRaytraceLFM/volumes/volume_args.py b/src/VolumeRaytraceLFM/volumes/volume_args.py index 58e3479..23ac685 100644 --- a/src/VolumeRaytraceLFM/volumes/volume_args.py +++ b/src/VolumeRaytraceLFM/volumes/volume_args.py @@ -230,13 +230,39 @@ }, } -shell_big_args = { +shell_wide_args = { + "init_mode": "shell", + "init_args": { + "radius": [2.5, 3.5, 5.5], + "center": [0.5, 0.5, 0.5], + "delta_n": -0.01, + "border_thickness": 1, + "tallness": 3, + "highness": 3, + "flip": True, + }, +} + +shell_widewide_args = { "init_mode": "shell", "init_args": { - "radius": [32, 40, 24], + "radius": [2.5, 2.5, 6.5], "center": [0.5, 0.5, 0.5], "delta_n": -0.01, "border_thickness": 1, + "tallness": 3, + "highness": 3, + "flip": True, + }, +} + +shell_big_args = { + "init_mode": "shell", + "init_args": { + "radius": [32, 24, 40], + "center": [0.5, 0.5, 0.5], + "delta_n": -0.01, + "border_thickness": 2, "tallness": 6, # "highness": 8, # "flip": False, diff --git a/src/streamlit_app/pages/2_Reconstruction OG.py b/src/streamlit_app/pages/2_Reconstruction OG.py deleted file mode 100644 index 1789ee3..0000000 --- a/src/streamlit_app/pages/2_Reconstruction OG.py +++ /dev/null @@ -1,586 +0,0 @@ -import streamlit as st - -st.set_page_config( - page_title="Reconstructions", - page_icon="👋", - layout="wide", -) - -st.title("Reconstructions") -st.write("Let's try to reconstruct a volume based on our images!") - -import time -import os -import io -import json -import copy -import numpy as np -import torch -from PIL import Image -import h5py -from tqdm import tqdm -import matplotlib.pyplot as plt -import matplotlib -from VolumeRaytraceLFM.abstract_classes import BackEnds -from VolumeRaytraceLFM.birefringence_implementations import ( - BirefringentVolume, - BirefringentRaytraceLFM, -) -from VolumeRaytraceLFM.visualization.plotting_volume import volume_2_projections -from VolumeRaytraceLFM.visualization.plotting_iterations import plot_iteration_update -from VolumeRaytraceLFM.loss_functions import * - -st.header("Choose our parameters") - -columns = st.columns(2) -# first Column -with columns[0]: - ############ Optical Params ################# - # Get optical parameters template - optical_info = BirefringentVolume.get_optical_info_template() - # Alter some of the optical parameters - st.subheader("Optical") - optical_info["n_micro_lenses"] = st.slider( - "Number of microlenses", min_value=1, max_value=51, value=9 - ) - optical_info["pixels_per_ml"] = st.slider( - "Pixels per microlens", min_value=1, max_value=33, value=17, step=2 - ) - # GT volume simulation - optical_info["n_voxels_per_ml"] = st.slider( - "Number of voxels per microlens (volume sampling)", - min_value=1, - max_value=7, - value=1, - ) - - ############ Reconstruction settings ################# - backend = BackEnds.PYTORCH - st.subheader("Iterative reconstruction parameters") - num_iterations = st.slider("Number of iterations", min_value=1, max_value=500, value=500) - # See loss_functions.py for more details - loss_function = st.selectbox( - "Loss function", ["vonMisses", "vector", "L1_cos", "L1all"], 1 - ) - regularization_function1 = st.selectbox( - "Volume regularization function 1", ["L1", "L2", "unit", "TV", "none"], 2 - ) - regularization_function2 = st.selectbox( - "Volume regularization function 2", ["L1", "L2", "unit", "TV", "none"], 4 - ) - reg_weight1 = st.number_input( - "Regularization weight 1", min_value=0.0, max_value=0.5, value=0.5 - ) - # st.write('The current regularization weight 1 is ', reg_weight1) - reg_weight2 = st.number_input( - "Regularization weight 2", min_value=0.0, max_value=0.5, value=0.5 - ) - # st.write('The current regularization weight 2 is ', reg_weight2) - ret_azim_weight = st.number_input( - "Retardance-Orientation weight", min_value=0.0, max_value=1.0, value=0.5 - ) - # st.write('The current retardance/orientation weight is ', ret_azim_weight) - st.subheader("Initial estimated volume") - volume_init_type = st.selectbox("Initial volume type", ["random", "upload"], 0) - if volume_init_type == "upload": - h5file_init = st.file_uploader("Upload the initial volume h5 Here", type=["h5"]) - delta_n_init_magnitude = 1 - mask_bool = False - else: - mask_bool = st.checkbox("Mask out area unreachable by light rays") - delta_n_init_magnitude = st.number_input( - "Volume Delta_n initial magnitude", - min_value=0.0, - max_value=1.0, - value=0.0001, - format="%0.5f", - ) - # st.write('The current Volume Delta_n initial magnitude is ', delta_n_init_magnitude) - - st.subheader("Learning rate") - learning_rate_delta_n = st.number_input( - "Learning rate for Delta_n", - min_value=0.0, - max_value=10.0, - value=0.001, - format="%0.5f", - ) - # st.write('The current LR is ', learning_rate_delta_n) - learning_rate_optic_axis = st.number_input( - "Learning rate for optic_axis", - min_value=0.0, - max_value=10.0, - value=0.001, - format="%0.5f", - ) - # st.write('The current optic axis LR is ', learning_rate_optic_axis) - - -def key_investigator(key_home, my_str="", prefix="- "): - if hasattr(key_home, "keys"): - for my_key in key_home.keys(): - my_str = my_str + prefix + my_key + "\n" - my_str = key_investigator(key_home[my_key], my_str, "\t" + prefix) - return my_str - - -# Second Column -with columns[1]: - ############ Volume ################# - st.subheader("Ground truth volume") - volume_container = st.container() # set up a home for other volume selections to go - with volume_container: - how_get_vol = st.radio( - "Volume can be created or uploaded as an h5 file", - ["h5 upload", "Create a new volume", "Upload experimental images"], - index=1, - ) - if how_get_vol == "h5 upload": - h5file = st.file_uploader("Upload Volume h5 Here", type=["h5"]) - optical_info["n_voxels_per_ml_volume"] = st.slider( - "Number of voxels per microlens in volume", - min_value=1, - max_value=21, - value=1, - ) - if h5file is not None: - with h5py.File(h5file) as file: - try: - vol_shape = file["optical_info"]["volume_shape"][()] - except KeyError: - st.error("This file does specify the volume shape.") - except Exception as e: - st.error(e) - vol_shape_default = [int(v) for v in vol_shape] - optical_info["volume_shape"] = vol_shape_default - st.markdown( - f"Using a cube volume shape with the dimension of the" - + f" loaded volume: {vol_shape_default}." - ) - - display_h5 = st.checkbox("Display h5 file contents") - if display_h5: - with h5py.File(h5file) as file: - st.markdown("**File Structure:**\n" + key_investigator(file)) - try: - st.markdown( - "**Description:** " - + str(file["optical_info"]["description"][()])[2:-1] - ) - except KeyError: - st.error("This file does not have a description.") - except Exception as e: - st.error(e) - try: - vol_shape = file["optical_info"]["volume_shape"][()] - # optical_info['volume_shape'] = vol_shape - st.markdown(f"**Volume Shape:** {vol_shape}") - except KeyError: - st.error("This file does specify the volume shape.") - except Exception as e: - st.error(e) - try: - voxel_size = file["optical_info"]["voxel_size_um"][()] - st.markdown(f"**Voxel Size (um):** {voxel_size}") - except KeyError: - st.error( - "This file does specify the voxel size. Voxels are likely to be cubes." - ) - except Exception as e: - st.error(e) - elif how_get_vol == "Upload experimental images": - retardance_path = st.file_uploader( - "Upload retardance tif", type=["png", "tif", "tiff"] - ) - azimuth_path = st.file_uploader( - "Upload orientation tif", type=["png", "tif", "tiff"] - ) - metadata_file = st.file_uploader( - "Upload metadata from Napari-LF", type=["txt"] - ) - plot_cropped_imgs = st.empty() # set up a place holder for the plot - - # Load files - if retardance_path is not None: - ret_img_raw = torch.from_numpy( - np.array(Image.open(retardance_path)).astype(np.float32) - ) - if azimuth_path is not None: - azim_img_raw = torch.from_numpy( - np.array(Image.open(azimuth_path)).astype(np.float32) - ) - - if metadata_file is not None: - # Lets load metadata - metadata = metadata_file.read() - metadata = json.loads(metadata) - # MLA data - optical_info["pixels_per_ml"] = ( - metadata["calibrate"]["pixels_per_ml"] - if "pixels_per_ml" in metadata["calibrate"].keys() - else optical_info["pixels_per_ml"] - ) - # optical_info['n_micro_lenses'] = 11 - # optical_info['n_voxels_per_ml'] = 1 - # optical_info['axial_voxel_size_um'] = 1 - # Optics data - optical_info["M_obj"] = metadata["calibrate"]["objective_magnification"] - optical_info["na_obj"] = metadata["calibrate"]["objective_na"] - optical_info["n_medium"] = metadata["calibrate"]["medium_index"] - optical_info["wavelength"] = metadata["calibrate"]["center_wavelength"] - optical_info["camera_pix_pitch"] = metadata["calibrate"]["pixel_size"] - - ### Which part to crop from the images? - - if retardance_path and azimuth_path and metadata_file: - # Crop data based on n_micro_lenses and n_voxels_per_ml - n_mls_y = ret_img_raw.shape[0] // optical_info["pixels_per_ml"] - n_mls_x = ret_img_raw.shape[1] // optical_info["pixels_per_ml"] - - crop_pos_y = st.slider( - "Image region center Y", min_value=1, max_value=n_mls_y, value=55 - ) - crop_pos_x = st.slider( - "Image region center X", min_value=1, max_value=n_mls_x, value=54 - ) - start_ml = [crop_pos_y, crop_pos_x] - start_coords = [sc * optical_info["pixels_per_ml"] for sc in start_ml] - end_coords = [ - sc + optical_info["n_micro_lenses"] * optical_info["pixels_per_ml"] - for sc in start_coords - ] - - st.session_state["ret_image_measured"] = ret_img_raw[ - start_coords[0] : end_coords[0], start_coords[1] : end_coords[1] - ] - st.session_state["azim_image_measured"] = azim_img_raw[ - start_coords[0] : end_coords[0], start_coords[1] : end_coords[1] - ] - - # Plot images - fig = plt.figure(figsize=(6, 3)) - plt.rcParams["image.origin"] = "lower" - plt.subplot(1, 2, 1) - plt.imshow(st.session_state["ret_image_measured"].numpy(), cmap="gray") - plt.subplot(1, 2, 2) - plt.imshow(st.session_state["azim_image_measured"].numpy(), cmap="gray") - plot_cropped_imgs.pyplot(fig) - - st.subheader("Volume shape") - optical_info["volume_shape"][0] = st.slider( - "Axial volume dimension", min_value=1, max_value=50, value=15 - ) - # y will follow x if x is changed. x will not follow y if y is changed - optical_info["volume_shape"][1] = st.slider( - "Y-Z volume dimension", min_value=1, max_value=100, value=51 - ) - optical_info["volume_shape"][2] = optical_info["volume_shape"][1] - else: - optical_info["n_voxels_per_ml_volume"] = st.slider( - "Number of voxels per microlens in volume space", - min_value=1, - max_value=21, - value=1, - ) - volume_type = st.selectbox( - "Volume type", ["ellipsoid", "shell", "2ellipsoids", "single_voxel"], 3 - ) - st.subheader("Volume shape") - optical_info["volume_shape"][0] = st.slider( - "Axial volume dimension", min_value=1, max_value=50, value=5 - ) - # y will follow x if x is changed. x will not follow y if y is changed - optical_info["volume_shape"][1] = st.slider( - "Y-Z volume dimension", min_value=1, max_value=100, value=51 - ) - optical_info["volume_shape"][2] = optical_info["volume_shape"][1] - shift_from_center = st.slider( - "Axial shift from center [voxels]", - min_value=-int(optical_info["volume_shape"][0] / 2), - max_value=int(optical_info["volume_shape"][0] / 2), - value=-1, - ) - volume_axial_offset = ( - optical_info["volume_shape"][0] // 2 + shift_from_center - ) # for center - # Create the volume based on the selections. - with volume_container: - if how_get_vol == "h5 upload": - # Upload ground truth volume from an h5 file - if h5file is not None: - # Lets create a new optical info for volume space, as the sampling might be higher than the reconstruction - # DEBUG: might not be using the 'n_voxels_per_ml_volume' selected above - # deepcopy may disregard the selection made above int he ground truth volume - optical_info_volume = copy.deepcopy(optical_info) - optical_info_volume["n_voxels_per_ml"] = optical_info_volume[ - "n_voxels_per_ml_volume" - ] - # optical_info_volume['n_voxels_per_ml'] = 3 - # optical_info_volume['volume_shape'][1] = 501 - # optical_info_volume['volume_shape'][2] = 501 - st.session_state["my_volume"] = BirefringentVolume.init_from_file( - h5file, backend=backend, optical_info=optical_info_volume - ) - test_vol = st.session_state["my_volume"] = ( - BirefringentVolume.init_from_file( - h5file, backend=backend, optical_info=optical_info_volume - ) - ) - elif how_get_vol == "Upload experimental images": - with torch.no_grad(): - st.session_state["my_volume"] = BirefringentVolume.create_dummy_volume( - backend=backend, - optical_info=optical_info, - vol_type="zeros", - volume_axial_offset=0, - ) - else: - # Lets create a new optical info for volume space, as the sampling might be higher than the reconstruction - optical_info_volume = copy.deepcopy(optical_info) - optical_info_volume["n_voxels_per_ml"] = optical_info_volume[ - "n_voxels_per_ml_volume" - ] - # optical_info_volume['volume_shape'][1] = 501 - # optical_info_volume['volume_shape'][2] = 501 - with torch.no_grad(): - st.session_state["my_volume"] = BirefringentVolume.create_dummy_volume( - backend=backend, - optical_info=optical_info_volume, - vol_type=volume_type, - volume_axial_offset=volume_axial_offset, - ) - -###################################################################### - -# want learning rate to be multiple choice -# lr = st.slider('Learning rate', min_value=1, max_value=5, value=3) -# filename_message = st.text_input('Message to add to the filename (not currently saving anyway..)') -training_params = { - "num_iterations": num_iterations, # How long to train for - "azimuth_weight": ret_azim_weight, # Azimuth loss weight - "regularization_weight": [reg_weight1, reg_weight2], # Regularization weight - "lr": learning_rate_delta_n, # Learning rate for delta_n - "lr_optic_axis": learning_rate_optic_axis, # Learning rate for optic axis - "output_posfix": "", # Output file name posfix - "loss": loss_function, # Loss function - "reg": [ - regularization_function1, - regularization_function2, - ], # Regularization function -} - -if st.button("Reconstruct!"): - my_volume = st.session_state["my_volume"] - - # Create a Birefringent Raytracer - # DEBUG - # Force the volume shape to be smaller for the reconstructed volume - st.subheader("Volume shape for the estimated volume") - estimated_volume_shape = optical_info["volume_shape"].copy() - estimated_volume_shape[0] = st.slider( - "Axial volume dimension for estimated volume", - min_value=1, - max_value=50, - value=5, - ) - # y will follow x if x is changed. x will not follow y if y is changed - estimated_volume_shape[1] = st.slider( - "Y-Z volume dimension for estimated volume", - min_value=1, - max_value=100, - value=31, - ) - estimated_volume_shape[2] = estimated_volume_shape[1] - optical_info["volume_shape"] = estimated_volume_shape - rays = BirefringentRaytraceLFM(backend=backend, optical_info=optical_info) - rays.compute_rays_geometry() - if backend == BackEnds.PYTORCH: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # Force cpu, as for now cpu is faster - device = "cpu" - print(f"Using computing device: {device}") - rays = rays.to_device(device) - - # Generate images using the forward model - with torch.no_grad(): - if how_get_vol == "Upload experimental images": - ret_image_measured = st.session_state["ret_image_measured"] - azim_image_measured = st.session_state["azim_image_measured"] - # Normalize data - ret_image_measured /= ret_image_measured.max() - ret_image_measured *= 0.01 - azim_image_measured *= torch.pi / azim_image_measured.max() - else: - # We need a raytracer with different number of voxels per ml for higher sampling measurements - rays_higher_sampling = BirefringentRaytraceLFM( - backend=backend, optical_info=optical_info_volume - ) - rays_higher_sampling.compute_rays_geometry() - # Perform same calculation with torch - start_time = time.time() - [ret_image_measured, azim_image_measured] = ( - rays_higher_sampling.ray_trace_through_volume(my_volume) - ) - execution_time = time.time() - start_time - print("Warmup time in seconds with Torch: " + str(execution_time)) - - # Store GT images - Delta_n_GT = my_volume.get_delta_n().detach().clone() - optic_axis_GT = my_volume.get_optic_axis().detach().clone() - ret_image_measured = ret_image_measured.detach() - azim_image_measured = azim_image_measured.detach() - - ############# - # Let's create an optimizer - # Initial guess - if volume_init_type == "upload" and h5file_init is not None: - volume_estimation = BirefringentVolume.init_from_file( - h5file_init, backend=backend, optical_info=optical_info - ) - else: - volume_estimation = BirefringentVolume( - backend=backend, - optical_info=optical_info, - volume_creation_args={"init_mode": "random"}, - ) - # Let's rescale the random to initialize the volume - volume_estimation.Delta_n.requires_grad = False - volume_estimation.optic_axis.requires_grad = False - volume_estimation.Delta_n *= delta_n_init_magnitude - if mask_bool: - # And mask out volume that is outside FOV of the microscope - mask = rays.get_volume_reachable_region() - volume_estimation.Delta_n[mask.view(-1) == 0] = 0 - volume_estimation.Delta_n.requires_grad = True - volume_estimation.optic_axis.requires_grad = True - - # Indicate to this object that we are going to optimize Delta_n and optic_axis - volume_estimation.members_to_learn.append("Delta_n") - volume_estimation.members_to_learn.append("optic_axis") - volume_estimation = volume_estimation.to(device) - - trainable_parameters = volume_estimation.get_trainable_variables() - - # As delta_n has much lower values than optic_axis, we might need 2 different learning rates - parameters = [ - { - "params": trainable_parameters[0], - "lr": training_params["lr_optic_axis"], - }, # Optic axis - {"params": trainable_parameters[1], "lr": training_params["lr"]}, - ] # Delta_n - - # Create optimizer - optimizer = torch.optim.Adam(parameters, lr=training_params["lr"]) - - # To test differentiability let's define a loss function L = |ret_image_torch|, and minimize it - losses = [] - data_term_losses = [] - regularization_term_losses = [] - - # Create weight mask for the azimuth - # as the azimuth is irrelevant when the retardance is low, lets scale error with a mask - azimuth_damp_mask = (ret_image_measured / ret_image_measured.max()).detach() - - # width = st.sidebar.slider("Plot width", 1, 25, 15) - # height = st.sidebar.slider("Plot height", 1, 25, 8) - - my_plot = st.empty() # set up a place holder for the plot - my_3D_plot = st.empty() # set up a place holder for the 3D plot - - st.write("Working on these ", num_iterations, "iterations...") - my_bar = st.progress(0) - for ep in tqdm(range(training_params["num_iterations"]), "Minimizing"): - optimizer.zero_grad() - - # Forward projection - [ret_image_current, azim_image_current] = rays.ray_trace_through_volume( - volume_estimation - ) - - # Conpute loss and regularization - L, data_term, regularization_term = apply_loss_function_and_reg( - training_params["loss"], - training_params["reg"], - ret_image_measured, - azim_image_measured, - ret_image_current, - azim_image_current, - training_params["azimuth_weight"], - volume_estimation, - training_params["regularization_weight"], - ) - - # Calculate update of the my_volume (Compute gradients of the L with respect to my_volume) - L.backward() - - # Apply gradient updates to the volume - optimizer.step() - with torch.no_grad(): - num_nan_vecs = torch.sum(torch.isnan(volume_estimation.optic_axis[0, :])) - replacement_vecs = torch.nn.functional.normalize( - torch.rand(3, int(num_nan_vecs)), p=2, dim=0 - ) - volume_estimation.optic_axis[ - :, torch.isnan(volume_estimation.optic_axis[0, :]) - ] = replacement_vecs - if ep == 0 and num_nan_vecs != 0: - st.write( - f"Replaced {num_nan_vecs} NaN optic axis vectors with random unit vectors, " - + "likely on every iteration." - ) - # print(f'Ep:{ep} loss: {L.item()}') - losses.append(L.item()) - data_term_losses.append(data_term.item()) - regularization_term_losses.append(regularization_term.item()) - - azim_image_out = azim_image_current.detach() - azim_image_out[azimuth_damp_mask == 0] = 0 - - percent_complete = int(ep / training_params["num_iterations"] * 100) - my_bar.progress(percent_complete + 1) - - if ep % 2 == 0: - matplotlib.pyplot.close() - fig = plot_iteration_update( - volume_2_projections(Delta_n_GT.unsqueeze(0))[0, 0] - .detach() - .cpu() - .numpy(), - ret_image_measured.detach().cpu().numpy(), - azim_image_measured.detach().cpu().numpy(), - volume_2_projections(volume_estimation.get_delta_n().unsqueeze(0))[0, 0] - .detach() - .cpu() - .numpy(), - ret_image_current.detach().cpu().numpy(), - azim_image_current.detach().cpu().numpy(), - losses, - data_term_losses, - regularization_term_losses, - streamlit_purpose=True, - ) - - my_plot.pyplot(fig) - - st.success("Done reconstructing! How does it look?", icon="✅") - st.session_state["my_volume"] = volume_estimation - st.write("Scroll over image to zoom in and out.") - # Todo: use a slider to filter the volume - volume_ths = 0.05 # st.slider('volume ths', min_value=0., max_value=1., value=0.1) - matplotlib.pyplot.close() - my_fig = st.session_state["my_volume"].plot_lines_plotly(delta_n_ths=volume_ths) - st.plotly_chart(my_fig, use_container_width=True) - - st.subheader("Download results") - # print(ret_image_current.detach().cpu().numpy().shape) - # st.download_button('Download estimated retardance', ret_image_current.detach().cpu().numpy().tobytes(), mime="image/jpeg") - # st.download_button('Download estimated orientation', azim_image_current.detach().cpu().numpy().tobytes(), mime="image/jpeg") - - # Save volume to h5 - h5_file = io.BytesIO() - st.download_button( - "Download estimated volume as HDF5 file", - volume_estimation.save_as_file(h5_file), - mime="application/x-hdf5", - ) diff --git a/src/streamlit_app/pages/2_Reconstruction.py b/src/streamlit_app/pages/2_Reconstruction.py index d5869c2..bccaccd 100644 --- a/src/streamlit_app/pages/2_Reconstruction.py +++ b/src/streamlit_app/pages/2_Reconstruction.py @@ -316,8 +316,7 @@ def generate_random_vol(mask=False): volume_creation_args={"init_mode": "random"}, ) # Let's rescale the random to initialize the volume - volume.Delta_n.requires_grad = False - volume.optic_axis.requires_grad = False + volume.set_requires_grad(False) volume.Delta_n *= delta_n_init_magnitude if mask: # And mask out volume that is outside FOV of the microscope diff --git a/tests/fixtures_optical_info.py b/tests/fixtures_optical_info.py index ecbf154..4a529ec 100644 --- a/tests/fixtures_optical_info.py +++ b/tests/fixtures_optical_info.py @@ -18,6 +18,7 @@ def optical_info_vol11(): optical_info["wavelength"] = 0.550 optical_info["n_micro_lenses"] = 1 optical_info["n_voxels_per_ml"] = 1 + optical_info["aperture_radius_px"] = 7.5 optical_info["polarizer"] = np.array([[1, 0], [0, 1]]) optical_info["analyzer"] = np.array([[1, 0], [0, 1]]) @@ -39,6 +40,7 @@ def set_optical_info(vol_shape, pixels_per_ml, num_lenslets): optical_info["M_obj"] = 60 optical_info["cube_voxels"] = True optical_info["camera_pix_pitch"] = 6.5 + optical_info["aperture_radius_px"] = 7.5 optical_info["polarizer"] = np.array([[1, 0], [0, 1]]) optical_info["analyzer"] = np.array([[1, 0], [0, 1]]) optical_info["polarizer_swing"] = 0.03 diff --git a/tests/speed_speed.py b/tests/speed_speed.py index b42dbaf..6bc8aac 100644 --- a/tests/speed_speed.py +++ b/tests/speed_speed.py @@ -825,7 +825,7 @@ def plot_azimuth(img): fig = plt.figure(figsize=(13, 4)) fig.subplots_adjust(bottom=0.025, left=0.025, top=0.975, right=0.975) - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" plt.clf() sub1 = plt.subplot(1, 3, 1) sub1.imshow(dist_from_ctr) @@ -880,7 +880,7 @@ def plot_ret_azi_image_comparison( if "PYTEST_CURRENT_TEST" in os.environ: return - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" plt.clf() plt.subplot(3, 2, 1) plt.imshow(ret_img_numpy) diff --git a/tests/test_abstract.py b/tests/test_abstract.py index 562cb95..03be4ff 100644 --- a/tests/test_abstract.py +++ b/tests/test_abstract.py @@ -15,10 +15,11 @@ def ray_trace_lfm_instance(backend_fixture, optical_info_vol11): def test_rays_through_vol( - pixels_per_ml=5, naObj=1.4, nMedium=1.52, volume_ctr_um=np.array([0.5, 0.5, 0.5]) + pixels_per_ml=5, naObj=1.4, nMedium=1.52, + volume_ctr_um=np.array([0.5, 0.5, 0.5]), aperture_radius_px=7.5 ): ray_enter, ray_exit, ray_diff = RayTraceLFM.rays_through_vol( - pixels_per_ml, naObj, nMedium, volume_ctr_um + pixels_per_ml, naObj, nMedium, volume_ctr_um, aperture_radius_px ) rays_shape = ray_enter.shape assert rays_shape == ray_exit.shape == ray_diff.shape @@ -27,8 +28,7 @@ def test_rays_through_vol( @pytest.mark.parametrize("backend_fixture", ["numpy", "pytorch"], indirect=True) def test_compute_lateral_ray_length_and_voxel_span(ray_trace_lfm_instance): - """ - Test that the voxel span is computed correctly. + """Test that the voxel span is computed correctly. The sample ray_diff is created with pixels_per_ml = 5. This function is called by compute_rays_geometry, where the axial_volume_dim is rays.optical_info['volume_shape'][0]. diff --git a/tests/test_all.py b/tests/test_all.py index b42f5f3..a0b69d0 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -390,6 +390,7 @@ def test_forward_projection_lenslet_grid_random_volumes(global_data, volume_shap # Gather global data local_data = copy.deepcopy(global_data) optical_info = local_data["optical_info"] + optical_info["aperture_radius_px"] = optical_info["pixels_per_ml"] / 2 # Volume shape volume_shape = volume_shape_in @@ -398,7 +399,7 @@ def test_forward_projection_lenslet_grid_random_volumes(global_data, volume_shap # The n_micro_lenses defines the active volume area, and it should be # smaller than the volume_shape. # This as some rays go beyond the volume in front of a single micro-lens - optical_info["n_micro_lenses"] = volume_shape[1] - 4 + optical_info["n_micro_lenses"] = volume_shape[1] - 6 optical_info["n_voxels_per_ml"] = 1 # Create Ray-tracing objects @@ -444,18 +445,10 @@ def test_forward_projection_lenslet_grid_random_volumes(global_data, volume_shap ret_img_numpy, azi_img_numpy, ret_img_torch, azi_img_torch ) - assert np.all( - np.isnan(ret_img_numpy) == False - ), "Error in numpy retardance computations nan found" - assert np.all( - np.isnan(azi_img_numpy) == False - ), "Error in numpy azimuth computations nan found" - assert torch.all( - torch.isnan(ret_img_torch) == False - ), "Error in torch retardance computations nan found" - assert torch.all( - torch.isnan(azi_img_torch) == False - ), "Error in torch azimuth computations nan found" + assert not np.any(np.isnan(ret_img_numpy)), "Error in numpy retardance computations nan found" + assert not np.any(np.isnan(azi_img_numpy)), "Error in numpy azimuth computations nan found" + assert not torch.any(torch.isnan(ret_img_torch)), "Error in torch retardance computations nan found" + assert not torch.any(torch.isnan(azi_img_torch)), "Error in torch azimuth computations nan found" assert np.all( np.isclose(ret_img_numpy.astype(np.float32), ret_img_torch.numpy(), atol=1e-5) @@ -1045,9 +1038,6 @@ def main(): # Objective configuration -# Check azimuth images - - def check_azimuth_images( img1, img2, message="Error when comparing azimuth computations" ): @@ -1056,9 +1046,27 @@ def check_azimuth_images( if not np.all(np.isclose(img1, img2, atol=1e-5)): # Check if the difference is a multiple of pi diff = np.abs(img1 - img2) - assert np.all( - np.isclose(diff[~np.isclose(diff, 0.0, atol=1e-5)], np.pi, atol=1e-5) - ), message + # assert np.all( + # np.isclose(diff[~np.isclose(diff, 0.0, atol=1e-5)], np.pi, atol=1e-5) + # ), message + + # Debugging: print the max difference for better diagnosis + max_diff = np.max(diff) + print(f"Max difference: {max_diff:.8f}") + + # Check if the difference is a multiple of pi + # Only consider differences that are not close to zero + non_zero_diff = diff[~np.isclose(diff, 0.0, atol=1e-5)] + + # plot_azimuth(diff) + + # If there are non-zero differences, check if they are close to pi + if non_zero_diff.size > 0: + is_multiple_of_pi = np.all(np.isclose(non_zero_diff, np.pi, atol=1e-5)) + if not is_multiple_of_pi: + print(f"Non-zero differences: {non_zero_diff}") + print(f"Expected differences close to pi, but got max: {np.max(non_zero_diff)}") + assert is_multiple_of_pi, message def plot_azimuth(img): @@ -1070,9 +1078,9 @@ def plot_azimuth(img): dist_from_ctr = np.sqrt((iv - ctr[0]) ** 2 + (jv - ctr[1]) ** 2) fig = plt.figure(figsize=(13, 4)) - fig.subplots_adjust(bottom=0.025, left=0.025, top=0.975, right=0.975) + fig.subplots_adjust(bottom=0, left=0.025, top=0.925, right=0.975) - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" plt.clf() sub1 = plt.subplot(1, 3, 1) sub1.imshow(dist_from_ctr) @@ -1128,7 +1136,7 @@ def plot_ret_azi_image_comparison( if "PYTEST_CURRENT_TEST" in os.environ: return - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" plt.clf() plt.subplot(3, 2, 1) plt.imshow(ret_img_numpy) @@ -1165,7 +1173,7 @@ def plot_ret_azi_flipsign_image_comparison( if "PYTEST_CURRENT_TEST" in os.environ: return - plt.rcParams["image.origin"] = "lower" + plt.rcParams["image.origin"] = "upper" plt.clf() plt.subplot(3, 2, 1) plt.imshow(ret_img_pos) diff --git a/tests/test_data/precomputed_images_plane_4_8_8_16_4.pt b/tests/test_data/precomputed_images_plane_4_8_8_16_4.pt index 80d4604..a340f88 100644 Binary files a/tests/test_data/precomputed_images_plane_4_8_8_16_4.pt and b/tests/test_data/precomputed_images_plane_4_8_8_16_4.pt differ diff --git a/tests/test_data/precomputed_images_shell_small_7_18_18_16_9.pt b/tests/test_data/precomputed_images_shell_small_7_18_18_16_9.pt index a89c84f..24dc5ba 100644 Binary files a/tests/test_data/precomputed_images_shell_small_7_18_18_16_9.pt and b/tests/test_data/precomputed_images_shell_small_7_18_18_16_9.pt differ diff --git a/tests/test_data/precomputed_images_sphere2_11_30_30_16_11.pt b/tests/test_data/precomputed_images_sphere2_11_30_30_16_11.pt index 3c2fd77..798181f 100644 Binary files a/tests/test_data/precomputed_images_sphere2_11_30_30_16_11.pt and b/tests/test_data/precomputed_images_sphere2_11_30_30_16_11.pt differ diff --git a/tests/test_data/precomputed_images_voxel_3_5_5_16_1.pt b/tests/test_data/precomputed_images_voxel_3_5_5_16_1.pt index 53a9218..62ffaf1 100644 Binary files a/tests/test_data/precomputed_images_voxel_3_5_5_16_1.pt and b/tests/test_data/precomputed_images_voxel_3_5_5_16_1.pt differ diff --git a/tests/test_data/precomputed_images_voxel_3_9_9_16_5.pt b/tests/test_data/precomputed_images_voxel_3_9_9_16_5.pt index d2f07f2..f82963e 100644 Binary files a/tests/test_data/precomputed_images_voxel_3_9_9_16_5.pt and b/tests/test_data/precomputed_images_voxel_3_9_9_16_5.pt differ diff --git a/tests/test_orientation.py b/tests/test_orientation.py new file mode 100644 index 0000000..f9b46ad --- /dev/null +++ b/tests/test_orientation.py @@ -0,0 +1,163 @@ +"""Test the azimuth image for different optic axis configurations.""" + +import pytest +import torch +import math +from VolumeRaytraceLFM.abstract_classes import BackEnds +from VolumeRaytraceLFM.simulations import ForwardModel +from VolumeRaytraceLFM.birefringence_implementations import BirefringentVolume + +# Setting up constants +BACKEND = BackEnds.PYTORCH + +@pytest.fixture +def optical_system_fixture(): + optical_info = { + "volume_shape": [1, 3, 3], + "axial_voxel_size_um": 1.0, + "cube_voxels": True, + "pixels_per_ml": 1, + "n_micro_lenses": 1, + "n_voxels_per_ml": 1, + "M_obj": 60, + "na_obj": 1.2, + "n_medium": 1.35, + "wavelength": 0.550, + "aperture_radius_px": 1, + "camera_pix_pitch": 6.5, + "polarizer": [[1, 0], [0, 1]], + "analyzer": [[1, 0], [0, 1]], + "polarizer_swing": 0.03 + } + return {"optical_info": optical_info} + + +@pytest.fixture(scope="function") +def setup_simulator(optical_system_fixture): + simulator = ForwardModel(optical_system_fixture, backend=BACKEND) + simulator.rays.prepare_for_all_rays_at_once() + return simulator + + +# Function to test the azimuth image for a given volume +def compare_azimuth_image(simulator, volume, expected_output): + simulator.forward_model(volume, all_lenslets=True) + azim = simulator.azim_img + print(f" - Computed Azimuth: {azim.item():.4f}\n" + f" - Expected Output: {expected_output.item():.4f}") + azim = torch.where( + torch.isclose(azim, torch.tensor(torch.pi, dtype=azim.dtype), atol=1e-8), + torch.tensor(0.0, dtype=azim.dtype), + azim + ) + # Check if the azimuth image matches the expected output + assert torch.allclose(azim, expected_output.float(), atol=1e-6), f"Azimuth image is not as expected. Got {azim}, expected {expected_output}" + + +# Function to create a birefringent volume with given optic axis and delta_n +def create_birefringent_volume(optical_system_fixture, optic_axis, delta_n=-0.05): + voxel_args = { + "init_mode": "single_voxel", + "init_args": {"delta_n": delta_n, "offset": [0, 0, 0], "optic_axis": optic_axis}, + } + volume = BirefringentVolume( + backend=BACKEND, + optical_info=optical_system_fixture["optical_info"], + volume_creation_args=voxel_args, + ) + volume.set_requires_grad(False) + return volume + + +# Parametrize the test with different configurations +@pytest.mark.parametrize( + "optic_axis, delta_n, expected_output", [ + # Positive Birefringence (delta_n > 0) + ([1, 0, 0], 0.05, torch.tensor([[0]])), # azim = 0 + ([-1, 0, 0], 0.05, torch.tensor([[0]])), # azim = 0 + ([0, 1, 0], 0.05, torch.tensor([[math.pi/2]])), # azim = pi/2 + ([0, -1, 0], 0.05, torch.tensor([[math.pi/2]])), # azim = pi/2 + ([0, 0, 1], 0.05, torch.tensor([[0]])), # azim = 0 + ([0, 0, -1], 0.05, torch.tensor([[0]])), # azim = 0 + ([0, 1, 1], 0.05, torch.tensor([[math.pi/4]])), # azim = pi/4 + ([0, -1, -1], 0.05, torch.tensor([[math.pi/4]])), # azim = pi/4 + ([0, -1, 1], 0.05, torch.tensor([[3*math.pi/4]])), # azim = 3pi/4 + ([0, 1, -1], 0.05, torch.tensor([[3*math.pi/4]])), # azim = 3pi/4 + + # Negative Birefringence (delta_n < 0) + ([1, 0, 0], -0.05, torch.tensor([[0]])), # azim = 0 + ([-1, 0, 0], -0.05, torch.tensor([[0]])), # azim = 0 + ([0, 1, 0], -0.05, torch.tensor([[0]])), # azim flipped to 0 + ([0, -1, 0], -0.05, torch.tensor([[0]])), # azim flipped to 0 + ([0, 0, 1], -0.05, torch.tensor([[math.pi/2]])), # azim = pi/2 + ([0, 0, -1], -0.05, torch.tensor([[math.pi/2]])), # azim = pi/2 + ([0, 1, 1], -0.05, torch.tensor([[3*math.pi/4]])), # azim flipped to 3pi/4 + ([0, -1, -1], -0.05, torch.tensor([[3*math.pi/4]])), # azim flipped to 3pi/4 + ([0, -1, 1], -0.05, torch.tensor([[math.pi/4]])), # azim flipped to pi/4 + ([0, 1, -1], -0.05, torch.tensor([[math.pi/4]])), # azim flipped to pi/4 + ] +) +def test_azimuth_images(optic_axis, delta_n, expected_output, setup_simulator, optical_system_fixture): + """Test different optic axis configurations with positive and negative birefringence.""" + # Create the simulator (only once for all tests) + simulator = setup_simulator + + # Create a birefringent volume for the test + volume = create_birefringent_volume(optical_system_fixture,optic_axis, delta_n=delta_n) + + # Run the test on the azimuth image + compare_azimuth_image(simulator, volume, expected_output) + + +if __name__ == "__main__": + # Define 10 different configurations for volumes + volumes_config_positive_birefringent = [ + {"optic_axis": [1, 0, 0], "expected_output": torch.tensor([[0]])}, # azim = 0 + {"optic_axis": [-1, 0, 0], "expected_output": torch.tensor([[0]])}, # azim = 0 + {"optic_axis": [0, 1, 0], "expected_output": torch.tensor([[math.pi/2]])}, # azim = pi/2 + {"optic_axis": [0, -1, 0], "expected_output": torch.tensor([[math.pi/2]])}, # azim = pi/2 + {"optic_axis": [0, 0, 1], "expected_output": torch.tensor([[0]])}, # azim = 0 + {"optic_axis": [0, 0, -1], "expected_output": torch.tensor([[0]])}, # azim = 0 + {"optic_axis": [0, 1, 1], "expected_output": torch.tensor([[math.pi/4]])}, # azim = pi/4 + {"optic_axis": [0, -1, -1], "expected_output": torch.tensor([[math.pi/4]])}, # azim = pi/4 + {"optic_axis": [0, -1, 1], "expected_output": torch.tensor([[3*math.pi/4]])}, # azim = 3pi/4 + {"optic_axis": [0, 1, -1], "expected_output": torch.tensor([[3*math.pi/4]])}, # azim = 3pi/4 + ] + + volumes_config_negative_birefringent = [ + {"optic_axis": [1, 0, 0], "expected_output": torch.tensor([[0]])}, + {"optic_axis": [-1, 0, 0], "expected_output": torch.tensor([[0]])}, + {"optic_axis": [0, 1, 0], "expected_output": torch.tensor([[0]])}, + {"optic_axis": [0, -1, 0], "expected_output": torch.tensor([[0]])}, + {"optic_axis": [0, 0, 1], "expected_output": torch.tensor([[math.pi/2]])}, + {"optic_axis": [0, 0, -1], "expected_output": torch.tensor([[math.pi/2]])}, + {"optic_axis": [0, 1, 1], "expected_output": torch.tensor([[3*math.pi/4]])}, + {"optic_axis": [0, -1, -1], "expected_output": torch.tensor([[3*math.pi/4]])}, + {"optic_axis": [0, -1, 1], "expected_output": torch.tensor([[math.pi/4]])}, + {"optic_axis": [0, 1, -1], "expected_output": torch.tensor([[math.pi/4]])}, + ] + + + # optical_system["optical_info"]["pixels_per_ml"] = 4 + # optical_system["optical_info"]["aperture_radius_px"] = 3 + simulator = setup_simulator(optical_system_fixture) + + birefringence = 0.05 + print("Birefringence: ", birefringence) + if birefringence > 0: + volumes_config = volumes_config_positive_birefringent[::2] + else: + volumes_config = volumes_config_negative_birefringent[::2] + + single_only = False + if not single_only: + # Iterate through the 10 volume configurations, run the tests, and check results + for idx, config in enumerate(volumes_config): + volume = create_birefringent_volume(config["optic_axis"], delta_n=birefringence) + print(f"Testing volume {idx + 1} with optic_axis: {config['optic_axis']}") + compare_azimuth_image(simulator, volume, config["expected_output"]) + # print(f"Test {idx + 1} passed.") + else: + volume = create_birefringent_volume([1, 0, 0], delta_n=birefringence) + simulator.forward_model(volume, all_lenslets=True) + simulator.view_images() diff --git a/tests/test_simulation_results.py b/tests/test_simulation_results.py index 02c38b2..8cdc5fe 100644 --- a/tests/test_simulation_results.py +++ b/tests/test_simulation_results.py @@ -17,7 +17,16 @@ from tests.fixtures_optical_info import set_optical_info BACKEND = BackEnds.PYTORCH +PARAMETER_SETS = [ + ("voxel", [3, 5, 5], 16, 1), + ("voxel", [3, 9, 9], 16, 5), + ("sphere2", [11, 30, 30], 16, 11), + ("plane", [4, 8, 8], 16, 4), + ("shell_small", [7, 18, 18], 16, 9), +] +# Unit test idea: images of a (shifted) voxel should be the same for +# all odd axial dimension volumes def create_simulator(optical_info, backend): optical_system = {"optical_info": optical_info} @@ -73,13 +82,7 @@ def compare_images(generated_images, saved_images): @pytest.mark.parametrize( "vol_type, vol_shape, pixels_per_ml, n_lenslets", - [ - ("voxel", [3, 5, 5], 16, 1), - ("voxel", [3, 9, 9], 16, 5), - ("sphere2", [11, 30, 30], 16, 11), - ("plane", [4, 8, 8], 16, 4), - ("shell_small", [7, 18, 18], 16, 9), - ], + PARAMETER_SETS, ) @pytest.mark.slow def test_simulation(vol_type, vol_shape, pixels_per_ml, n_lenslets): @@ -99,6 +102,10 @@ def test_simulation(vol_type, vol_shape, pixels_per_ml, n_lenslets): if __name__ == "__main__": from fixtures_optical_info import set_optical_info - images = run_simulation("shell_small", [7, 18, 18], 16, 9) - filename = generate_filename("shell_small", [7, 18, 18], 16, 9) - save_images(images, filename) + # images = run_simulation("shell_small", [7, 18, 18], 16, 9) + + # Loop through the parameter sets and execute functions + for vol_type, vol_shape, pixels_per_ml, n_lenslets in PARAMETER_SETS: + images = run_simulation(vol_type, vol_shape, pixels_per_ml, n_lenslets) + filename = generate_filename(vol_type, vol_shape, pixels_per_ml, n_lenslets) + save_images(images, filename)