From 577a73a4c317dfc7297adb35d93b44457036e276 Mon Sep 17 00:00:00 2001 From: Geneva Schlafly Date: Wed, 6 Dec 2023 14:29:34 -0600 Subject: [PATCH 1/6] Establish reconstruction method of a voxel --- .gitignore | 3 +- config_settings/optical_config_voxel.json | 16 ++++++++++ run_recon.py | 39 +++++++++++++++++++++-- 3 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 config_settings/optical_config_voxel.json diff --git a/.gitignore b/.gitignore index 7e99e36..94e8a67 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -*.pyc \ No newline at end of file +*.pyc +reconstructions/* \ No newline at end of file diff --git a/config_settings/optical_config_voxel.json b/config_settings/optical_config_voxel.json new file mode 100644 index 0000000..a01c953 --- /dev/null +++ b/config_settings/optical_config_voxel.json @@ -0,0 +1,16 @@ +{ + "volume_shape" : [3, 7, 7], + "axial_voxel_size_um" : 1.0, + "cube_voxels" : true, + "pixels_per_ml" : 17, + "n_micro_lenses" : 1, + "n_voxels_per_ml" : 1, + "M_obj" : 60, + "na_obj" : 1.2, + "n_medium" : 1.35, + "wavelength" : 0.550, + "camera_pix_pitch" : 6.5, + "polarizer" : [[1, 0], [0, 1]], + "analyzer" : [[1, 0], [0, 1]], + "polarizer_swing" : 0.03 +} \ No newline at end of file diff --git a/run_recon.py b/run_recon.py index 345f47f..1802049 100644 --- a/run_recon.py +++ b/run_recon.py @@ -55,7 +55,41 @@ def recon_gpu(): reconstructor.to_device(DEVICE) # Move the reconstructor to the GPU reconstructor.reconstruct(output_dir=recon_directory) - visualize_volume(reconstructor.volume_pred, reconstructor.optical_info) + visualize_volume(reconstructor.volume_pred, reconstructor.optical_info) + +def recon(): + optical_info = setup_optical_parameters("config_settings\optical_config_voxel.json") + optical_system = {'optical_info': optical_info} + # Initialize the forward model. Raytracing is performed as part of the initialization. + simulator = ForwardModel(optical_system, backend=BACKEND) + # Volume creation + volume_GT = BirefringentVolume( + backend=BACKEND, + optical_info=optical_info, + volume_creation_args=volume_args.voxel_args + ) + visualize_volume(volume_GT, optical_info) + + simulator.forward_model(volume_GT) + # simulator.view_images() + ret_image_meas = simulator.ret_img + azim_image_meas = simulator.azim_img + + recon_optical_info = optical_info.copy() + iteration_params = setup_iteration_parameters("config_settings\iter_config.json") + initial_volume = BirefringentVolume( + backend=BackEnds.PYTORCH, + optical_info=recon_optical_info, + volume_creation_args = volume_args.random_args + ) + recon_directory = create_unique_directory("reconstructions") + recon_config = ReconstructionConfig(recon_optical_info, ret_image_meas, azim_image_meas, + initial_volume, iteration_params, gt_vol=volume_GT) + recon_config.save(recon_directory) + # recon_config_recreated = ReconstructionConfig.load(recon_directory) + reconstructor = Reconstructor(recon_config) + reconstructor.reconstruct(output_dir=recon_directory) + visualize_volume(reconstructor.volume_pred, reconstructor.optical_info) def main(): optical_info = setup_optical_parameters("config_settings\optical_config_largemla.json") @@ -91,5 +125,4 @@ def main(): visualize_volume(reconstructor.volume_pred, reconstructor.optical_info) if __name__ == '__main__': - main() - # recon_gpu() + recon() From b15939451b7907cf3c0280127122beef6e9bd6bc Mon Sep 17 00:00:00 2001 From: Geneva Schlafly Date: Wed, 6 Dec 2023 19:33:03 -0600 Subject: [PATCH 2/6] Create function to modify the volume before recon Added simulated images to test the reconstruction process. Currently, the new function are not in use within the Reconstructor class. This can be changed by replacing 'False' with 'True'. All tests pass. --- VolumeRaytraceLFM/optic_config.py | 23 +++++++++-- VolumeRaytraceLFM/reconstructions.py | 54 ++++++++++++++++++++++++- forward_images/azim_voxel_1mla.npy | Bin 0 -> 1284 bytes forward_images/azim_voxel_pos_1mla.npy | Bin 0 -> 1284 bytes forward_images/ret_voxel_1mla.npy | Bin 0 -> 1284 bytes forward_images/ret_voxel_pos_1mla.npy | Bin 0 -> 1284 bytes run_recon.py | 43 +++++++++++++------- 7 files changed, 99 insertions(+), 21 deletions(-) create mode 100644 forward_images/azim_voxel_1mla.npy create mode 100644 forward_images/azim_voxel_pos_1mla.npy create mode 100644 forward_images/ret_voxel_1mla.npy create mode 100644 forward_images/ret_voxel_pos_1mla.npy diff --git a/VolumeRaytraceLFM/optic_config.py b/VolumeRaytraceLFM/optic_config.py index a31b637..a37fea8 100644 --- a/VolumeRaytraceLFM/optic_config.py +++ b/VolumeRaytraceLFM/optic_config.py @@ -6,26 +6,41 @@ except: pass -class OpticBlock(nn.Module): # pure virtual class +class OpticBlock(nn.Module): """Base class containing all the basic functionality of an optic block""" def __init__( self, optic_config=None, members_to_learn=None, - ): # Contains a list of members which should be optimized (In case none are provided members are created without gradients) + ): + """ + Initialize the OpticBlock. + Args: + optic_config (optional): Configuration for the optic block. Defaults to None. + members_to_learn (optional): List of members to be optimized. Defaults to None. + """ super(OpticBlock, self).__init__() self.optic_config = optic_config self.members_to_learn = [] if members_to_learn is None else members_to_learn self.device_dummy = nn.Parameter(torch.tensor([1.0])) - def get_trainable_variables(self): + """ + Get the trainable variables of the optic block. + Returns: + list: List of trainable variables. + """ trainable_vars = [] for name, param in self.named_parameters(): if name in self.members_to_learn: trainable_vars.append(param) return list(trainable_vars) - + def get_device(self): + """ + Get the device of the optic block. + Returns: + torch.device: The device of the optic block. + """ return self.device_dummy.device diff --git a/VolumeRaytraceLFM/reconstructions.py b/VolumeRaytraceLFM/reconstructions.py index 39cfe2b..53d57f8 100644 --- a/VolumeRaytraceLFM/reconstructions.py +++ b/VolumeRaytraceLFM/reconstructions.py @@ -244,6 +244,7 @@ def specify_variables_to_learn(self, learning_vars=None): Specify which variables of the initial volume object should be considered for learning. This method updates the 'members_to_learn' attribute of the initial volume object, ensuring no duplicates are added. + The variable names must be attributes of the BirefringentVolume class. Args: learning_vars (list): Variable names to be appended for learning. Defaults to ['Delta_n', 'optic_axis']. @@ -331,11 +332,22 @@ def _compute_loss(self, retardance_pred: torch.Tensor, azimuth_pred: torch.Tenso def one_iteration(self, optimizer, volume_estimation): optimizer.zero_grad() + # The in-place operation may cause problems with the gradient tracking + if False: + Delta_n_combined = torch.cat([volume_estimation.Delta_n_first_half, volume_estimation.Delta_n_second_half], dim=0) + volume_estimation.Delta_n = torch.nn.Parameter(Delta_n_combined) # Apply forward model [ret_image_current, azim_image_current] = self.rays.ray_trace_through_volume(volume_estimation) loss, data_term, regularization_term = self._compute_loss(ret_image_current, azim_image_current) loss.backward() + + # One method would be to set the gradients of the second half to zero + if False: + half_length = volume_estimation.Delta_n.size(0) // 2 + volume_estimation.Delta_n.grad[half_length:] = 0 + + # Note: This is where volume_estimation.Delta_n.grad becomes non-zero optimizer.step() self.ret_img_pred = ret_image_current.detach().cpu().numpy() @@ -348,6 +360,9 @@ def one_iteration(self, optimizer, volume_estimation): def visualize_and_save(self, ep, fig, output_dir): volume_estimation = self.volume_pred + # Delta_n_combined = torch.cat([volume_estimation.Delta_n_first_half, volume_estimation.Delta_n_second_half], dim=0) + # Delta_n_combined.retain_grad() + # volume_estimation.Delta_n = torch.nn.Parameter(Delta_n_combined) if ep % 1 == 0: # plt.clf() mip_image = convert_volume_to_2d_mip(volume_estimation.get_delta_n().detach().unsqueeze(0)) @@ -374,16 +389,51 @@ def visualize_and_save(self, ep, fig, output_dir): volume_estimation.save_as_file(os.path.join(output_dir, f"volume_ep_{'{:02d}'.format(ep)}.h5")) return + def modify_volume(self): + """ + Method to modify the initial volume guess. + """ + volume = self.volume_pred + Delta_n = volume.Delta_n + length = Delta_n.size(0) + half_length = length // 2 + + # Split Delta_n into two parts + # volume.Delta_n_first_half = Delta_n[:half_length].clone().detach().requires_grad_(True) + # volume.Delta_n_second_half = Delta_n[half_length:].clone().detach().requires_grad_(False) + volume.Delta_n_first_half = torch.nn.Parameter(Delta_n[:half_length].clone()) + # volume.Delta_n_second_half = Delta_n[half_length:].clone().detach() # This remains as a tensor attribute + volume.Delta_n_second_half = torch.nn.Parameter(Delta_n[half_length:].clone()) + + # Below is false + # Now, Delta_n_first_half has requires_grad=True and Delta_n_second_half has requires_grad=False + + # # During the forward pass, combine them + Delta_n_combined = torch.cat([volume.Delta_n_first_half, volume.Delta_n_second_half], dim=0) + # volume.Delta_n = torch.nn.Parameter(Delta_n_combined) + + # Update Delta_n of BirefringentVolume directly + with torch.no_grad(): # Temporarily disable gradient tracking + volume.Delta_n[:] = Delta_n_combined # Update the value in-place + return + def reconstruct(self, output_dir=None): """ Method to perform the actual reconstruction based on the provided parameters. """ if output_dir is None: output_dir = create_unique_directory("reconstructions") - # self.restrict_volume_to_reachable_region() - self.specify_variables_to_learn() + # Turn off the gradients for the initial volume guess self._turn_off_initial_volume_gradients() + + # Adjust the estimated volume variable + # self.restrict_volume_to_reachable_region() + param_list = ['Delta_n_first_half'] + self.specify_variables_to_learn() + if False: + self.modify_volume() + optimizer = self.optimizer_setup(self.volume_pred, self.iteration_params) figure = setup_visualization() # Iterations diff --git a/forward_images/azim_voxel_1mla.npy b/forward_images/azim_voxel_1mla.npy new file mode 100644 index 0000000000000000000000000000000000000000..cb0b630b8391b7a9c8fe60237f722da33d658dd4 GIT binary patch literal 1284 zcmcJOJ!lhQ7{{X`LTET$I<3YDN?PGii35~ z#U)M^)WNBO+KS0g)I~bgO#>2IDGpV{O}qH~4z5(Ziio~&T=FLW-}8UIZXZ9{kvqLE zb|E%u6}+PBTl-C`y(ewirqwg(5BvFn?m@rcaepY^SM=n*cs@Vm$@kr7&A>tXRLbPX>mU} zUDq4c?^qvZ8}(wnSx;QtrSCm$(a-za14eB75oc!Zh4ZB=uI4JXdTV4OIWLoOJwIh= zIF9gVEnwWOT?#y0t?mr;b=Om8%cH+=I@9ZQa_c6Ym`(&7aw4Cg9+C$+5Wm&! zSA~UZs$u%(PmH~jOX|DxRb^j&Q?T%aAL;^wI-t)>Y_xr6%BotQP~$@nRya4>YRcdD zUNv{ls}sE+10F0mFwmn#S9tPJm1if_!KKH_UYuHC)Qs0^t%n}Gli$~R!XCjpAly5` z_ash5;w2=mco1zi@PjY%&56$@@r&1gAmzM*m=;TYZSVAbm$)k~zS`Bbf^r%n4qY8#v4n zo>uF_&*0T1NyO!fo@dqX~WA*U!^)ZH=J S(-GNM80;^59QD6{b@v1J20Eeu literal 0 HcmV?d00001 diff --git a/forward_images/azim_voxel_pos_1mla.npy b/forward_images/azim_voxel_pos_1mla.npy new file mode 100644 index 0000000000000000000000000000000000000000..40c658c665ec454640b54bbd1d5e3f713203c322 GIT binary patch literal 1284 zcmcJOPe>GD7{;e4f`x7BA|14bil(I@vKF=t`QGmkftF!lWOT4?c9tOjF)Lw-B6KK% zx^%HFUhL4xpwh@WM0E*I(anWeOr(Q>bgM3XzYkXWV;V&R!|u%P^LwB7eV^~n$lb)M&_54iTikyz?b@FyP z5IDK8+j=QIeqb4Gmp)hz>8Z5UFH{@!#-3fja1Mtb)lx@FMv~o@Lr?U>Jm`ZSs6SoZ zp~tT5(f#4R9~dWwZmW@GTqSQ_C|GBVC;TuM7|a1Zwb710IiQaZWc5LvTjmV3ZdK3g zGiuw;7WK4p)$(A$fq@<^I`2YWKkIJN?y%AuzREJgah&Wbw=?TD>)+i*XD@hg;Zuit z)TMsK_`nao==Y)VxyAU!H~wo3`a(yaXz3e0bAXXHbAiL0U@pF literal 0 HcmV?d00001 diff --git a/forward_images/ret_voxel_1mla.npy b/forward_images/ret_voxel_1mla.npy new file mode 100644 index 0000000000000000000000000000000000000000..960bb53d013b7eed9bbe2962440089de8fa8ec8e GIT binary patch literal 1284 zcmcJOPe@cz6vkhx%0-Clk_RrH!3YgTM$FukP%Q$7f?DWC2OSZLCTGGL2cw0nmX2*y zh-%X^sAUTmqArZe3ThwGqD71TOru!ePv^p8`X^FvF=x&>-~H}6-#NFEIp1@pcULeL zl%t_saj*~_3!~KV;V2PC!}-FMLiSRBzA%&{ekyyhm{YtslD(Xh-=6GBgvqXhiEtv^ z;nx)W$8YZbsJ*BS+Vq8fdnt>*-RqV!H*IYoZ{L>A+0^O@TfAYL7Qd$WlJbAzTDx`I zrazYK+Qdm~ecER49yQzJ>rI;$pE!ANzvAJ3@*C;1Ha^;FpPp3Ro4)66c4Wc5lC8=o z{#|)+;2U|Ub4>Su%e3t__i)}FO;()im~=<5@`;y~zbf3i@NteA=~Y#~)XV+Hg}+k{ z+};m!>6Xq?dXMZ~`NY2}e?U0!obx=Z(jEQYtN+)koBKI;dg^oP!LwrOkSsoN@^6K! z2+zDc2lGRJ^rL@M{rhzPiq3<9hkEFXJUH;2^BnNS>!5%7@O5;n4t(lBXXXW-`aK7I z)#`eDsq1ksrU&Qo(F48E4SZY=xX_uoleI*6EQF6nS2H8dnd>Ru8I5)BCh3r$;*5Y<)$O~qg`Dhm>W#VBDk z7z6|H5d*uh7z_xl!T)#ZC6|kjMxx8p_w)a}&--3B)Zg7RP@+w1X(Jj-MUqCdZUjdf z48LxS#FJCW@WfC&8I9rZ43DQ`%AXnyPsWs86KM47fyP?Dp3(pK^JxG0TiZ^^buJ<+ zV?!2V7yI0CXi0LzamgL`N%pE;juu3g6=HuV|A~tK)GNQbDJxI2GLz|$s&k**?^VeD zyvM?7W#cDK-dBC-Ph2y5WiH{B=d(BScJR_%A3d}X8$WUK;Op)KPu?Lf>lh7{n3y4+Z2D^r=A=To4JUSM?Q6SkMDSNsrOG+ zJ?P{CUwAuta6j*AKKAF Date: Wed, 6 Dec 2023 21:07:31 -0600 Subject: [PATCH 3/6] Optimize birefringence subsets The forward model can now be based on a new Delta_n variable. This variable can be a torch concatenation of two torch arrays, one with gradients and one not. A similar approach should be done for optic_axis. To use, change OPTIMIZING_MODE to True at the top of birefringence_implementations.py. All tests pass. --- .../birefringence_implementations.py | 6 +- VolumeRaytraceLFM/reconstructions.py | 76 +++++++++++++------ config_settings/iter_config.json | 2 +- run_recon.py | 4 +- 4 files changed, 59 insertions(+), 29 deletions(-) diff --git a/VolumeRaytraceLFM/birefringence_implementations.py b/VolumeRaytraceLFM/birefringence_implementations.py index a14f440..d3a9dd2 100644 --- a/VolumeRaytraceLFM/birefringence_implementations.py +++ b/VolumeRaytraceLFM/birefringence_implementations.py @@ -6,6 +6,7 @@ from tifffile import imsave NORM_PROJ = False # normalize the projection of the ray onto the optic axis +OPTIMIZING_MODE = False # use the birefringence stored in Delta_n_combined class BirefringentElement(OpticalElement): ''' Birefringent element, such as voxel, raytracer, etc, extending optical element, @@ -1178,7 +1179,10 @@ def calc_cummulative_JM_of_ray_torch(self, volume_in : BirefringentVolume, # Extract the information from the volume # Birefringence try: - Delta_n = volume_in.Delta_n[vox] + if OPTIMIZING_MODE: + Delta_n = volume_in.Delta_n_combined[vox] + else: + Delta_n = volume_in.Delta_n[vox] # And axis opticAxis = volume_in.optic_axis[:,vox].permute(1,0) # Grab the subset of precomputed ray directions that have voxels in this step diff --git a/VolumeRaytraceLFM/reconstructions.py b/VolumeRaytraceLFM/reconstructions.py index 53d57f8..13ad52f 100644 --- a/VolumeRaytraceLFM/reconstructions.py +++ b/VolumeRaytraceLFM/reconstructions.py @@ -25,6 +25,10 @@ reshape_and_crop, store_as_pytorch_parameter ) + +COMBINING_DELTA_N = True +DEBUG = False + class ReconstructionConfig: def __init__(self, optical_info, ret_image, azim_image, initial_vol, iteration_params, loss_fcn=None, gt_vol=None): """ @@ -331,16 +335,30 @@ def _compute_loss(self, retardance_pred: torch.Tensor, azimuth_pred: torch.Tenso def one_iteration(self, optimizer, volume_estimation): optimizer.zero_grad() - - # The in-place operation may cause problems with the gradient tracking - if False: - Delta_n_combined = torch.cat([volume_estimation.Delta_n_first_half, volume_estimation.Delta_n_second_half], dim=0) - volume_estimation.Delta_n = torch.nn.Parameter(Delta_n_combined) + + if COMBINING_DELTA_N: + Delta_n_combined = torch.cat([volume_estimation.Delta_n_first_part, volume_estimation.Delta_n_second_part], dim=0) + # Attempt to update Delta_n of BirefringentVolume directly + # The in-place operation causes problems with the gradient tracking + # with torch.no_grad(): # Temporarily disable gradient tracking + # volume_estimation.Delta_n[:] = Delta_n_combined # Update the value in-place + volume_estimation.Delta_n_combined = Delta_n_combined # Apply forward model [ret_image_current, azim_image_current] = self.rays.ray_trace_through_volume(volume_estimation) loss, data_term, regularization_term = self._compute_loss(ret_image_current, azim_image_current) + # Verify the gradients before and after the backward pass + if DEBUG: + print("\nBefore backward pass:") + print("requires_grad:", volume_estimation.Delta_n_first_part.requires_grad) + print("Gradient for Delta_n_first_part:", volume_estimation.Delta_n_first_part.grad) + print("Gradient for Delta_n_second_part:", volume_estimation.Delta_n_second_part.grad) loss.backward() + if DEBUG: + print("\nAfter backward pass:") + print("requires_grad:", volume_estimation.Delta_n_first_part.requires_grad) + print("Gradient for Delta_n_first_part:", volume_estimation.Delta_n_first_part.grad) + print("Gradient for Delta_n_second_part:", volume_estimation.Delta_n_second_part.grad) # One method would be to set the gradients of the second half to zero if False: @@ -365,7 +383,11 @@ def visualize_and_save(self, ep, fig, output_dir): # volume_estimation.Delta_n = torch.nn.Parameter(Delta_n_combined) if ep % 1 == 0: # plt.clf() - mip_image = convert_volume_to_2d_mip(volume_estimation.get_delta_n().detach().unsqueeze(0)) + if COMBINING_DELTA_N: + Delta_n = volume_estimation.Delta_n_combined.view(self.optical_info['volume_shape']).detach().unsqueeze(0) + else: + Delta_n = volume_estimation.get_delta_n().detach().unsqueeze(0) + mip_image = convert_volume_to_2d_mip(Delta_n) mip_image_np = prepare_plot_mip(mip_image, plot=False) plot_iteration_update_gridspec( self.birefringence_mip_sim, @@ -399,22 +421,24 @@ def modify_volume(self): half_length = length // 2 # Split Delta_n into two parts - # volume.Delta_n_first_half = Delta_n[:half_length].clone().detach().requires_grad_(True) - # volume.Delta_n_second_half = Delta_n[half_length:].clone().detach().requires_grad_(False) - volume.Delta_n_first_half = torch.nn.Parameter(Delta_n[:half_length].clone()) - # volume.Delta_n_second_half = Delta_n[half_length:].clone().detach() # This remains as a tensor attribute - volume.Delta_n_second_half = torch.nn.Parameter(Delta_n[half_length:].clone()) - - # Below is false - # Now, Delta_n_first_half has requires_grad=True and Delta_n_second_half has requires_grad=False - - # # During the forward pass, combine them - Delta_n_combined = torch.cat([volume.Delta_n_first_half, volume.Delta_n_second_half], dim=0) - # volume.Delta_n = torch.nn.Parameter(Delta_n_combined) - - # Update Delta_n of BirefringentVolume directly - with torch.no_grad(): # Temporarily disable gradient tracking - volume.Delta_n[:] = Delta_n_combined # Update the value in-place + # volume.Delta_n_first_half = torch.nn.Parameter(Delta_n[:half_length].clone()) + # volume.Delta_n_second_half = torch.nn.Parameter(Delta_n[half_length:].clone(), requires_grad=False) + + Delta_n_reshaped = Delta_n.clone().view(3, 7, 7) + + # Extract the middle row of each plane + # The middle row index in each 7x7 plane is 3 + Delta_n_first_part = Delta_n_reshaped[:, 3, :] # Shape: (3, 7) + volume.Delta_n_first_part = torch.nn.Parameter(Delta_n_first_part.flatten()) + + # Concatenate slices before and after the middle row for each plane + Delta_n_second_part = torch.cat([Delta_n_reshaped[:, :3, :], # Rows before the middle + Delta_n_reshaped[:, 4:, :]], # Rows after the middle + dim=1) # Concatenate along the row dimension + volume.Delta_n_second_part = torch.nn.Parameter(Delta_n_second_part.flatten(), requires_grad=False) + + # Unsure the affect of turning off the gradients for Delta_n + Delta_n.requires_grad = False return def reconstruct(self, output_dir=None): @@ -429,10 +453,12 @@ def reconstruct(self, output_dir=None): # Adjust the estimated volume variable # self.restrict_volume_to_reachable_region() - param_list = ['Delta_n_first_half'] - self.specify_variables_to_learn() - if False: + if COMBINING_DELTA_N: self.modify_volume() + param_list = ['Delta_n_first_part', 'optic_axis'] # 'Delta_n_second_part' + else: + param_list = ['Delta_n', 'optic_axis'] + self.specify_variables_to_learn(param_list) optimizer = self.optimizer_setup(self.volume_pred, self.iteration_params) figure = setup_visualization() diff --git a/config_settings/iter_config.json b/config_settings/iter_config.json index e4ff734..0bc6b11 100644 --- a/config_settings/iter_config.json +++ b/config_settings/iter_config.json @@ -2,7 +2,7 @@ "n_epochs": 31, "azimuth_weight": 0.5, "regularization_weight": 0.1, - "lr": 1e-3, + "lr": 1e-2, "output_posfix": "", "loss_function": "" } \ No newline at end of file diff --git a/run_recon.py b/run_recon.py index 1799f6e..6a45a8e 100644 --- a/run_recon.py +++ b/run_recon.py @@ -1,7 +1,6 @@ import os import numpy as np import torch -import time from VolumeRaytraceLFM.abstract_classes import BackEnds from VolumeRaytraceLFM.simulations import ForwardModel from VolumeRaytraceLFM.birefringence_implementations import BirefringentVolume @@ -73,12 +72,13 @@ def recon(): optical_info=optical_info, volume_creation_args=volume_args.voxel_args ) - # visualize_volume(volume_GT, optical_info) + visualize_volume(volume_GT, optical_info) simulator.forward_model(volume_GT) # simulator.view_images() ret_image_meas = simulator.ret_img azim_image_meas = simulator.azim_img + # Save the images as numpy arrays # ret_numpy = ret_image_meas.detach().numpy() # np.save('forward_images/ret_voxel_pos_1mla.npy', ret_numpy) # azim_numpy = azim_image_meas.detach().numpy() From b08fb8e2355e52ee0be4f9c1ad6315385f2690d0 Mon Sep 17 00:00:00 2001 From: Geneva Schlafly Date: Thu, 7 Dec 2023 17:11:48 -0600 Subject: [PATCH 4/6] Document raytracing through volume functions I explained the steps involved in some processes using a pytorch backend. I also modularized the method ray_trace_through_volume() for clarity purposes. All tests pass. --- .../birefringence_implementations.py | 311 ++++++++++++------ VolumeRaytraceLFM/reconstructions.py | 4 +- 2 files changed, 206 insertions(+), 109 deletions(-) diff --git a/VolumeRaytraceLFM/birefringence_implementations.py b/VolumeRaytraceLFM/birefringence_implementations.py index d3a9dd2..8e2bf54 100644 --- a/VolumeRaytraceLFM/birefringence_implementations.py +++ b/VolumeRaytraceLFM/birefringence_implementations.py @@ -896,7 +896,7 @@ def __init__( backend=backend, torch_args=torch_args, optical_info=optical_info ) - # Ray-voxel colisions for different micro-lenses, + # Ray-voxel collisions for different microlenses, # this dictionary gets filled in: calc_cummulative_JM_of_ray_torch self.vox_indices_ml_shifted = {} self.vox_indices_ml_shifted_all = [] @@ -932,7 +932,7 @@ def get_volume_reachable_region(self): return mask.detach() def precompute_MLA_volume_geometry(self): - """ Expand the ray-voxel interactions from a single micro-lens to an nxn MLA""" + """ Expand the ray-voxel interactions from a single microlens to an nxn MLA""" if self.MLA_volume_geometry_ready: return # volume_shape defines the size of the workspace @@ -945,7 +945,7 @@ def precompute_MLA_volume_geometry(self): n_voxels_per_ml_half = floor(self.optical_info['n_voxels_per_ml'] * n_micro_lenses / 2.0) # Check if the volume_size can fit these micro_lenses. - # # considering that some rays go beyond the volume in front of the micro-lens + # # considering that some rays go beyond the volume in front of the microlens # border_size_around_mla = np.ceil((volume_shape[1]-(n_micro_lenses*n_voxels_per_ml)) / 2) min_needed_volume_size = int(self.voxel_span_per_ml + (n_micro_lenses*n_voxels_per_ml)) assert min_needed_volume_size <= volume_shape[1] and min_needed_volume_size <= volume_shape[2], "The volume in front of the microlenses" + \ @@ -953,12 +953,12 @@ def precompute_MLA_volume_geometry(self): f"Increase the volume_shape to at least [{min_needed_volume_size+1},{min_needed_volume_size+1}]" odd_mla_shift = np.mod(n_micro_lenses,2) - # Iterate micro-lenses in y direction + # Iterate microlenses in y direction for iix,ml_ii in tqdm(enumerate(range(-n_ml_half, n_ml_half+odd_mla_shift)), - f'Computing rows of micro-lens ret+azim {self.backend}'): - # Iterate micro-lenses in x direction + f'Computing rows of microlens ret+azim {self.backend}'): + # Iterate microlenses in x direction for jjx,ml_jj in enumerate(range(-n_ml_half, n_ml_half+odd_mla_shift)): - # Compute offset to top corner of the volume in front of the micro-lens (ii,jj) + # Compute offset to top corner of the volume in front of the microlens (ii,jj) current_offset = ( np.array([n_voxels_per_ml * ml_ii, n_voxels_per_ml*ml_jj]) + np.array(self.vox_ctr_idx[1:]) - n_voxels_per_ml_half @@ -980,7 +980,7 @@ def precompute_MLA_volume_geometry(self): self.ray_valid_indices + torch.tensor([jjx * n_pixels_per_ml, iix * n_pixels_per_ml]).unsqueeze(1)), 1) - # Replicate ray info for all the micro-lenses + # Replicate ray info for all the microlenses self.ray_vol_colli_lengths = nn.Parameter(self.ray_vol_colli_lengths.repeat(n_micro_lenses ** 2, 1)) self.ray_direction_basis = nn.Parameter(self.ray_direction_basis.repeat(1, n_micro_lenses ** 2, 1)) @@ -990,10 +990,18 @@ def precompute_MLA_volume_geometry(self): def ray_trace_through_volume(self, volume_in : BirefringentVolume = None, all_rays_at_once=True, intensity=False): """ This function forward projects a whole volume, by iterating through - the volume in front of each micro-lens in the system. We compute an offset + the volume in front of each microlens in the system. We compute an offset (current_offset) that shifts the volume indices reached by each ray. - Then we accumulate the images generated by each micro-lens, + Then we accumulate the images generated by each microlens, and concatenate in a final image. + + Args: + volume_in (BirefringentVolume): The volume to be processed. + all_rays_at_once (bool): Flag to indicate whether all rays should be processed at once. + intensity (bool): Flag to indicate whether to generate intensity images. + + Returns: + list[ImageType]: A list of images resulting from the ray tracing process. """ # volume_shape defines the size of the workspace @@ -1003,53 +1011,104 @@ def ray_trace_through_volume(self, volume_in : BirefringentVolume = None, n_voxels_per_ml = self.optical_info['n_voxels_per_ml'] n_ml_half = floor(n_micro_lenses / 2.0) - n_voxels_per_ml_half = floor(self.optical_info['n_voxels_per_ml'] * n_micro_lenses / 2.0) - - # Check if the volume_size can fit these micro_lenses. - # # considering that some rays go beyond the volume in front of the micro-lens + # Check if the volume_size can fit these microlenses. + # # considering that some rays go beyond the volume in front of the microlenses # border_size_around_mla = np.ceil((volume_shape[1]-(n_micro_lenses*n_voxels_per_ml)) / 2) - min_needed_volume_size = int(self.voxel_span_per_ml + (n_micro_lenses*n_voxels_per_ml)) - assert min_needed_volume_size <= volume_shape[1] and min_needed_volume_size <= volume_shape[2], "The volume in front of the microlenses" + \ + min_required_volume_size = self._calculate_min_volume_size(n_micro_lenses, n_voxels_per_ml) + self._validate_volume_size(min_required_volume_size, volume_shape) + # The following assert statement is redundant, but it is kept for clarity + assert min_required_volume_size <= volume_shape[1] and min_required_volume_size <= volume_shape[2], "The volume in front of the microlenses" + \ f"({n_micro_lenses},{n_micro_lenses}) is too large for a volume_shape: {self.optical_info['volume_shape'][1:]}. " + \ - f"Increase the volume_shape to at least [{min_needed_volume_size+1},{min_needed_volume_size+1}]" + f"Increase the volume_shape to at least [{min_required_volume_size+1},{min_required_volume_size+1}]" # Traverse volume for every ray, and generate intensity images or retardance and azimuth images + + # Initialize a list to store the final concatenated images full_img_list = [None] * 5 - odd_mla_shift = np.mod(n_micro_lenses,2) - # Iterate micro-lenses in y direction - for ml_ii in tqdm(range(-n_ml_half, n_ml_half+odd_mla_shift), + + # Calculate shift for odd number of microlenses + odd_mla_shift = np.mod(n_micro_lenses, 2) + + # Iterate over each row of microlenses (y direction) + for ml_ii in tqdm(range(-n_ml_half, n_ml_half + odd_mla_shift), f'Computing rows of microlenses {self.backend}'): + + # Initialize a list for storing concatenated images of the current row full_img_row_list = [None] * 5 - # Iterate micro-lenses in x direction + + # Iterate over each column of microlenses in teh current row (x direction) for ml_jj in range(-n_ml_half, n_ml_half+odd_mla_shift): - # Compute offset to top corner of the volume in front of the micro-lens (ii,jj) - current_offset = np.array([n_voxels_per_ml * ml_ii, n_voxels_per_ml * ml_jj]) + np.array(self.vox_ctr_idx[1:]) - n_voxels_per_ml_half - # Compute images for current microlens, by passing an offset to this function depending on the micro lens and the super resolution - if intensity: - img_list = self.intensity_images(volume_in, micro_lens_offset=current_offset) - else: - img_list = self.ret_and_azim_images(volume_in, micro_lens_offset=current_offset) - # If this is the first image, create + + # Calculate the offset to the top corner of the volume in front of + # the current microlens (ml_ii, ml_jj) + current_offset = self._calculate_current_offset(ml_ii, ml_jj, n_voxels_per_ml, n_micro_lenses) + + # Generate (intensity or ret/azim) images for the current microlens, by passing an offset to this function + # depending on the microlens and the super resolution + img_list = self._generate_images(volume_in, current_offset, intensity) + + # Concatenate the generated images with the images of the current row if full_img_row_list[0] is None: full_img_row_list = img_list - else: # Concatenate to existing image otherwise - if self.backend == BackEnds.NUMPY: - full_img_row_list = [np.concatenate((img0, img1), 0) - for img0, img1 in zip(full_img_row_list, img_list)] - elif self.backend == BackEnds.PYTORCH: - full_img_row_list = [torch.concatenate((img0, img1), 0) - for img0, img1 in zip(full_img_row_list, img_list)] + else: + full_img_row_list = self._concatenate_images(full_img_row_list, img_list, axis=0) + + # Concatenate the row images with the full image list if full_img_list[0] is None: full_img_list = full_img_row_list else: - if self.backend == BackEnds.NUMPY: - full_img_list = [np.concatenate((img0, img1), 1) - for img0, img1 in zip(full_img_list, full_img_row_list)] - elif self.backend == BackEnds.PYTORCH: - full_img_list = [torch.concatenate((img0, img1), 1) - for img0, img1 in zip(full_img_list, full_img_row_list)] + full_img_list = self._concatenate_images(full_img_list, full_img_row_list, axis=1) + return full_img_list + def _calculate_min_volume_size(self, num_microlenses, num_voxels_per_ml): + return int(self.voxel_span_per_ml + (num_microlenses * num_voxels_per_ml)) + + def _validate_volume_size(self, min_required_volume_size, volume_shape): + if min_required_volume_size > volume_shape[1] or min_required_volume_size > volume_shape[2]: + raise ValueError(f"The required volume size ({min_required_volume_size}) exceeds the provided volume shape {volume_shape[1:]}.") + + def _calculate_current_offset(self, row_index, col_index, num_voxels_per_ml, num_microlenses): + """Maps the position of a microlens in its array to the corresponding position + in the volumetric data, identified by its row and column indices. This function + calculates the offset to the top corner of the volume in front of the current microlens. + + Args: + row_index (int): The row index of the current microlens in the microlens array. + col_index (int): The column index of the current microlens in the microlens array. + num_voxels_per_ml (int): The number of voxels per microlens, indicating the + size of the voxel area each microlens covers. + num_microlenses (int): The total number of microlenses in one dimension of the microlens array. + + Returns: + np.array: An array representing the calculated offset in the volumetric data for the current microlens. + """ + # Scale row and column indices to voxel space. This is important when using supersampling. + scaled_indices = np.array([num_voxels_per_ml * row_index, num_voxels_per_ml * col_index]) + + # Add central indices of the volume. This shifts the focus to the relevant part of the volume + # based on the predefined central indices (vox_ctr_idx). + central_offset = np.array(self.vox_ctr_idx[1:]) + + # Compute the midpoint of the total voxel space covered by the microlenses. This value is subtracted + # to center the offset around the middle of the microlens array + half_voxel_span = floor(num_voxels_per_ml * num_microlenses / 2.0) + + # Calculate and return the final offset for the current microlens + return scaled_indices + central_offset - half_voxel_span + + def _generate_images(self, volume, offset, intensity): + if intensity: + return self.intensity_images(volume, microlens_offset=offset) + else: + return self.ret_and_azim_images(volume, microlens_offset=offset) + + def _concatenate_images(self, img_list1, img_list2, axis): + if self.backend == BackEnds.NUMPY: + return [np.concatenate((img1, img2), axis) for img1, img2 in zip(img_list1, img_list2)] + elif self.backend == BackEnds.PYTORCH: + return [torch.concatenate((img1, img2), axis) for img1, img2 in zip(img_list1, img_list2)] + def retardance(self, JM): '''Phase delay introduced between the fast and slow axis in a Jones Matrix''' if self.backend == BackEnds.NUMPY: @@ -1095,14 +1154,14 @@ def azimuth(self, JM): # azimuth[pi_index] = 0 return azimuth - def calc_cummulative_JM_of_ray(self, volume_in : BirefringentVolume, micro_lens_offset=[0,0]): + def calc_cummulative_JM_of_ray(self, volume_in : BirefringentVolume, microlens_offset=[0,0]): if self.backend==BackEnds.NUMPY: - return self.calc_cummulative_JM_of_ray_numpy(volume_in, micro_lens_offset) + return self.calc_cummulative_JM_of_ray_numpy(volume_in, microlens_offset) elif self.backend==BackEnds.PYTORCH: - return self.calc_cummulative_JM_of_ray_torch(volume_in, micro_lens_offset) + return self.calc_cummulative_JM_of_ray_torch(volume_in, microlens_offset) def calc_cummulative_JM_of_ray_numpy(self, i, j, - volume_in : BirefringentVolume, micro_lens_offset=[0,0]): + volume_in : BirefringentVolume, microlens_offset=[0,0]): '''For the (i,j) pixel behind a single microlens''' # Fetch precomputed Siddon parameters voxels_of_segs, ell_in_voxels = self.ray_vol_colli_indices, self.ray_vol_colli_lengths @@ -1116,12 +1175,12 @@ def calc_cummulative_JM_of_ray_numpy(self, i, j, ell = ell_in_voxels[n_ray][m] vox = voxels_of_segs[n_ray][m] Delta_n = volume_in.Delta_n[vox[0], - vox[1]+micro_lens_offset[0], - vox[2]+micro_lens_offset[1]] + vox[1]+microlens_offset[0], + vox[2]+microlens_offset[1]] opticAxis = volume_in.optic_axis[:, vox[0], - vox[1]+micro_lens_offset[0], - vox[2]+micro_lens_offset[1]] + vox[1]+microlens_offset[0], + vox[2]+microlens_offset[1]] JM = self.voxRayJM(Delta_n, opticAxis, rayDir, ell, self.optical_info['wavelength']) JM_list.append(JM) except: @@ -1130,28 +1189,43 @@ def calc_cummulative_JM_of_ray_numpy(self, i, j, return material_JM def calc_cummulative_JM_of_ray_torch(self, volume_in : BirefringentVolume, - micro_lens_offset=[0,0], all_rays_at_once=False): - '''This function computes the Jones Matrices of all rays defined in this object. - It uses pytorch's batch dimension to store each ray, and process them in parallel''' - # Fetch the voxels traversed per ray and the lengths - # that each ray travels through every voxel + microlens_offset=[0,0], all_rays_at_once=False): + """ + Computes the cumulative Jones Matrices (JM) for all rays defined in a BirefringentVolume + object using PyTorch. This function can process rays either all at once or individually + based on the `all_rays_at_once` flag. It uses pytorch's batch dimension to store each ray, + and process them in parallel. + + Args: + volume_in (BirefringentVolume): The volume through which rays are passing. + microlens_offset (list, optional): Offset [x, y] for the microlens. Defaults to [0, 0]. + all_rays_at_once (bool, optional): If True, processes all rays simultaneously. Defaults to False. + + Returns: + torch.Tensor: The cumulative Jones Matrices for the rays. + torch.Size([n_rays_with_voxels, 2, 2]) + """ + # Fetch the lengths that each ray travels through every voxel ell_in_voxels = self.ray_vol_colli_lengths + + # Determine voxel indices based on the processing mode. The voxel indices correspond + # to the voxels that each ray segment traverses. if all_rays_at_once: voxels_of_segs = self.vox_indices_ml_shifted_all else: - # Compute the 1D index of each micro-lens. - # compute once and store for later. - # accessing 1D arrays increases training speed by 25% - key = str(micro_lens_offset) + # Compute the 1D index for each microlens and store for later use + # Accessing 1D arrays increases training speed by 25% + key = str(microlens_offset) if key not in self.vox_indices_ml_shifted: self.vox_indices_ml_shifted[key] = [ [RayTraceLFM.ravel_index((vox[ix][0], - vox[ix][1]+micro_lens_offset[0], - vox[ix][2]+micro_lens_offset[1]), + vox[ix][1] + microlens_offset[0], + vox[ix][2] + microlens_offset[1]), self.optical_info['volume_shape']) for ix in range(len(vox))] for vox in self.ray_vol_colli_indices ] voxels_of_segs = self.vox_indices_ml_shifted[key] + # DEBUG # print("DEBUG: making the optical info of volume and self the same") # print("vol in: ", volume_in.optical_info) @@ -1162,54 +1236,62 @@ def calc_cummulative_JM_of_ray_torch(self, volume_in : BirefringentVolume, # errors.compare_dicts(self.optical_info, volume_in.optical_info) # except ValueError as e: # print('Optical info between ray-tracer and volume mismatch. ' + \ - # 'This might cause issues on the border micro-lenses.') + # 'This might cause issues on the border microlenses.') + + # Initialize material Jones Matrix + # Note: This could allow the try statement to be removed, but it is kept for clarity + material_JM = None + + # Process interactions of all rays with each voxel # Iterate the interactions of all rays with the m-th voxel - # Some rays interact with less voxels, so we mask the rays valid - # for this step with rays_with_voxels + # Some rays interact with less voxels, so we mask the rays valid with rays_with_voxels for m in range(self.ray_vol_colli_lengths.shape[1]): - # Check which rays still have voxels to traverse - rays_with_voxels = [len(vx)>m for vx in voxels_of_segs] - # How many rays at this step + # Determine which rays have remaining voxels to traverse + rays_with_voxels = [len(vx) > m for vx in voxels_of_segs] # n_rays_with_voxels = sum(rays_with_voxels) - # The lengths these rays traveled through the current voxels - ell = ell_in_voxels[rays_with_voxels,m] - # The voxel coordinates each ray collides with - vox = [vx[m] for ix,vx in enumerate(voxels_of_segs) if rays_with_voxels[ix]] + # print(f"The number of rays with voxels to transverse at this step is {n_rays_with_voxels}") + + # Get the lengths rays traveled through the m-th voxel + ell = ell_in_voxels[rays_with_voxels, m] + + # Get the voxel coordinates each ray interacts with + vox = [vx[m] for ix, vx in enumerate(voxels_of_segs) if rays_with_voxels[ix]] - # Extract the information from the volume - # Birefringence try: + # Extract the birefringence and optic axis information from the volume if OPTIMIZING_MODE: Delta_n = volume_in.Delta_n_combined[vox] else: Delta_n = volume_in.Delta_n[vox] - # And axis opticAxis = volume_in.optic_axis[:,vox].permute(1,0) - # Grab the subset of precomputed ray directions that have voxels in this step - filtered_rayDir = self.ray_direction_basis[:,rays_with_voxels,:] + + # Subset of precomputed ray directions that interact with voxels in this step + filtered_ray_directions = self.ray_direction_basis[:, rays_with_voxels, :] # Compute the interaction from the rays with their corresponding voxels - JM = self.voxRayJM( Delta_n = Delta_n, - opticAxis = opticAxis, - rayDir = filtered_rayDir, - ell = ell, - wavelength=self.optical_info['wavelength']) - if m==0: + JM = self.voxRayJM(Delta_n=Delta_n, opticAxis=opticAxis, + rayDir=filtered_ray_directions, ell=ell, + wavelength=self.optical_info['wavelength']) + + # Combine the current Jones Matrix with the cumulative one + if m == 0: material_JM = JM else: material_JM[rays_with_voxels,...] = material_JM[rays_with_voxels,...] @ JM + except: raise Exception("Error accessing the volume, try increasing the volume size in Y-Z") + return material_JM - def ret_and_azim_images(self, volume_in : BirefringentVolume, micro_lens_offset=[0,0]): + def ret_and_azim_images(self, volume_in : BirefringentVolume, microlens_offset=[0,0]): '''Calculate retardance and azimuth values for a ray with a Jones Matrix''' if self.backend==BackEnds.NUMPY: - return self.ret_and_azim_images_numpy(volume_in, micro_lens_offset) + return self.ret_and_azim_images_numpy(volume_in, microlens_offset) elif self.backend==BackEnds.PYTORCH: - return self.ret_and_azim_images_torch(volume_in, micro_lens_offset) + return self.ret_and_azim_images_torch(volume_in, microlens_offset) - def ret_and_azim_images_numpy(self, volume_in : BirefringentVolume, micro_lens_offset=[0,0]): + def ret_and_azim_images_numpy(self, volume_in : BirefringentVolume, microlens_offset=[0,0]): '''Calculate retardance and azimuth values for a ray with a Jones Matrix''' pixels_per_ml = self.optical_info['pixels_per_ml'] ret_image = np.zeros((pixels_per_ml, pixels_per_ml)) @@ -1220,7 +1302,7 @@ def ret_and_azim_images_numpy(self, volume_in : BirefringentVolume, micro_lens_o ret_image[i, j] = 0 azim_image[i, j] = 0 else: - effective_JM = self.calc_cummulative_JM_of_ray_numpy(i, j, volume_in, micro_lens_offset) + effective_JM = self.calc_cummulative_JM_of_ray_numpy(i, j, volume_in, microlens_offset) ret_image[i, j] = self.retardance(effective_JM) if np.isclose(ret_image[i, j], 0.0): azim_image[i, j] = 0 @@ -1259,22 +1341,35 @@ def ret_and_azim_images_mla_torch(self, volume_in : BirefringentVolume): # values = azimuth, size=(pixels_per_ml, pixels_per_ml)).to_dense() return [ret_image, azim_image] - def ret_and_azim_images_torch(self, volume_in : BirefringentVolume, micro_lens_offset=[0,0]): - '''This function computes the retardance and azimuth images - of the precomputed rays going through a volume''' - # Include offset to move to the center of the volume, - # as the ray collisions are computed only for a single micro-lens + def ret_and_azim_images_torch(self, volume_in : BirefringentVolume, microlens_offset=[0,0]): + """ + Computes the retardance and azimuth images for a given volume and microlens offset using PyTorch. - # Fetch needed variables + This function calculates the retardance and azimuth values for the (precomputed) rays + passing through a specific region of the volume, as determined by the microlens offset. + It generates two images: one for retardance and one for azimuth, for a single microlens. + This offset is included to move the center of the volume, as the ray collisions are + computed only for a single microlens. + + Args: + volume_in (BirefringentVolume): The volume through which rays are passing. + microlens_offset (list): The offset [x, y] to the center of the volume for the specific microlens. + + Returns: + list: A list containing two PyTorch tensors, one for the retardance image and one for the azimuth image. + """ + + # Fetch the number of pixels per microlens array from the optic configuration pixels_per_ml = self.optic_config.mla_config.n_pixels_per_mla - # Calculate Jones Matrices for all rays - effective_JM = self.calc_cummulative_JM_of_ray(volume_in, micro_lens_offset) - # Calculate retardance and azimuth + # Calculate Jones Matrices for all rays given the volume and microlens offset + effective_JM = self.calc_cummulative_JM_of_ray(volume_in, microlens_offset) + + # Calculate retardance and azimuth from the effective Jones Matrices retardance = self.retardance(effective_JM) azimuth = self.azimuth(effective_JM) - # Create output images + # Initialize output images for retardance and azimuth on the appropriate device ret_image = torch.zeros((pixels_per_ml, pixels_per_ml), dtype=torch.float32, requires_grad=True, device=self.get_device()) azim_image = torch.zeros((pixels_per_ml, pixels_per_ml), dtype=torch.float32, @@ -1282,23 +1377,25 @@ def ret_and_azim_images_torch(self, volume_in : BirefringentVolume, micro_lens_o ret_image.requires_grad = False azim_image.requires_grad = False - # Fill the values in the images - ret_image[self.ray_valid_indices[0,:],self.ray_valid_indices[1,:]] = retardance - azim_image[self.ray_valid_indices[0,:],self.ray_valid_indices[1,:]] = azimuth - # Alternative version + # Fill the calculated values into the images at the valid ray indices + ret_image[self.ray_valid_indices[0,:], self.ray_valid_indices[1,:]] = retardance + azim_image[self.ray_valid_indices[0,:], self.ray_valid_indices[1,:]] = azimuth + + # Alternative implementation using sparse tensors (commented out) # ret_image = torch.sparse_coo_tensor(indices = self.ray_valid_indices, # values = retardance, size=(pixels_per_ml, pixels_per_ml)).to_dense() # azim_image = torch.sparse_coo_tensor(indices = self.ray_valid_indices, # values = azimuth, size=(pixels_per_ml, pixels_per_ml)).to_dense() + return [ret_image, azim_image] - def intensity_images(self, volume_in : BirefringentVolume, micro_lens_offset=[0,0]): + def intensity_images(self, volume_in : BirefringentVolume, microlens_offset=[0,0]): '''Calculate intensity images using Jones Calculus. The polarizer and analyzer are applied to the cummulated Jones matrices.''' analyzer = self.optical_info['analyzer'] swing = self.optical_info['polarizer_swing'] pixels_per_ml = self.optical_info['pixels_per_ml'] - lenslet_JM = self.calc_cummulative_JM_lenslet(volume_in, micro_lens_offset) + lenslet_JM = self.calc_cummulative_JM_lenslet(volume_in, microlens_offset) intensity_image_list = [np.zeros((pixels_per_ml, pixels_per_ml))] * 5 # if not self.MLA_volume_geometry_ready: @@ -1323,23 +1420,23 @@ def intensity_images(self, volume_in : BirefringentVolume, micro_lens_offset=[0, return intensity_image_list def calc_cummulative_JM_lenslet(self, volume_in : BirefringentVolume, - micro_lens_offset=[0,0]): + microlens_offset=[0,0]): '''Calculate the Jones matrix associated with each pixel behind a lenslet.''' pixels_per_ml = self.optical_info['pixels_per_ml'] lenslet = np.zeros((pixels_per_ml, pixels_per_ml, 2, 2), dtype=np.complex128) - is_nan = np.isnan if self.backend == BackEnds.PYTORCH: lenslet = torch.from_numpy(lenslet).to(volume_in.Delta_n.device) is_nan = torch.isnan - lenslet = self.calc_cummulative_JM_of_ray_torch(volume_in, micro_lens_offset) + lenslet = self.calc_cummulative_JM_of_ray_torch(volume_in, microlens_offset) else: + is_nan = np.isnan for i in range(pixels_per_ml): for j in range(pixels_per_ml): if not is_nan(self.ray_entry[0, i, j]): # Due to the optics, no light reaches the pixel # TODO: verify that the Jones matrix should be zeros instead of identity lenslet[i, j, :, :] = self.calc_cummulative_JM_of_ray_numpy( - i, j, volume_in, micro_lens_offset + i, j, volume_in, microlens_offset ) return lenslet diff --git a/VolumeRaytraceLFM/reconstructions.py b/VolumeRaytraceLFM/reconstructions.py index 13ad52f..4e21ca1 100644 --- a/VolumeRaytraceLFM/reconstructions.py +++ b/VolumeRaytraceLFM/reconstructions.py @@ -26,7 +26,7 @@ store_as_pytorch_parameter ) -COMBINING_DELTA_N = True +COMBINING_DELTA_N = False DEBUG = False class ReconstructionConfig: @@ -342,7 +342,7 @@ def one_iteration(self, optimizer, volume_estimation): # The in-place operation causes problems with the gradient tracking # with torch.no_grad(): # Temporarily disable gradient tracking # volume_estimation.Delta_n[:] = Delta_n_combined # Update the value in-place - volume_estimation.Delta_n_combined = Delta_n_combined + volume_estimation.Delta_n_combined = torch.nn.Parameter(Delta_n_combined) # Apply forward model [ret_image_current, azim_image_current] = self.rays.ray_trace_through_volume(volume_estimation) loss, data_term, regularization_term = self._compute_loss(ret_image_current, azim_image_current) From 2896e443b1fde81547ce8dd985b710b37dd013f8 Mon Sep 17 00:00:00 2001 From: Geneva Schlafly Date: Fri, 8 Dec 2023 13:24:51 -0600 Subject: [PATCH 5/6] Refactor birefringence_implementations.py Moved the Jones Calculus implementation to their own script. To avoid circular dependencies, I moved BirefringentElement to its own script too. All tests pass. --- VolumeRaytraceLFM/birefringence_base.py | 14 ++ .../birefringence_implementations.py | 204 +----------------- VolumeRaytraceLFM/jones_calculus.py | 195 +++++++++++++++++ forward_intensity.py | 4 +- main_forward_projection.py | 4 +- pages/1_Forward_Projection.py | 3 +- tests/test_all.py | 2 + tests/test_jones.py | 5 +- 8 files changed, 223 insertions(+), 208 deletions(-) create mode 100644 VolumeRaytraceLFM/birefringence_base.py create mode 100644 VolumeRaytraceLFM/jones_calculus.py diff --git a/VolumeRaytraceLFM/birefringence_base.py b/VolumeRaytraceLFM/birefringence_base.py new file mode 100644 index 0000000..b4df6dd --- /dev/null +++ b/VolumeRaytraceLFM/birefringence_base.py @@ -0,0 +1,14 @@ +from VolumeRaytraceLFM.abstract_classes import ( + BackEnds, OpticalElement, SimulType + ) + +class BirefringentElement(OpticalElement): + ''' Birefringent element, such as voxel, raytracer, etc, + extending optical element, so it has a back-end and optical information''' + def __init__(self, backend : BackEnds = BackEnds.NUMPY, torch_args={}, + optical_info=None): + super(BirefringentElement, self).__init__(backend=backend, + torch_args=torch_args, + optical_info=optical_info + ) + self.simul_type = SimulType.BIREFRINGENT diff --git a/VolumeRaytraceLFM/birefringence_implementations.py b/VolumeRaytraceLFM/birefringence_implementations.py index 8e2bf54..ef93f90 100644 --- a/VolumeRaytraceLFM/birefringence_implementations.py +++ b/VolumeRaytraceLFM/birefringence_implementations.py @@ -2,23 +2,14 @@ from tqdm import tqdm import h5py from VolumeRaytraceLFM.abstract_classes import * +from VolumeRaytraceLFM.birefringence_base import BirefringentElement +from VolumeRaytraceLFM.jones_calculus import JonesMatrixGenerators, JonesVectorGenerators from utils import errors from tifffile import imsave NORM_PROJ = False # normalize the projection of the ray onto the optic axis OPTIMIZING_MODE = False # use the birefringence stored in Delta_n_combined -class BirefringentElement(OpticalElement): - ''' Birefringent element, such as voxel, raytracer, etc, extending optical element, - so it has a back-end and optical information''' - def __init__(self, backend : BackEnds = BackEnds.NUMPY, torch_args={}, - optical_info=None): - super(BirefringentElement, self).__init__(backend=backend, - torch_args=torch_args, - optical_info=optical_info - ) - self.simul_type = SimulType.BIREFRINGENT - ########################################################################################### # Implementations of OpticalElement # TODO: rename to BirefringentVolume inherits @@ -1598,193 +1589,4 @@ def apply_polarizers(self, material_JM): @staticmethod def ret_and_azim_from_intensity(image_list): - raise NotImplementedError - -########################################################################################### - # Constructors for different types of elements - # This methods are constructors only, - # they don't support torch optimization of internal variables. - -class JonesMatrixGenerators(BirefringentElement): - '''2x2 Jones matrices representing various of polariztion elements''' - - def __init__(self, backend : BackEnds = BackEnds.NUMPY): - super(BirefringentElement, self).__init__(backend=backend, torch_args={}, optical_info={}) - - @staticmethod - def rotator(angle, backend=BackEnds.NUMPY): - '''2D rotation matrix - Args: - angle: angle to rotate by counterclockwise [radians] - Return: Jones matrix''' - if backend == BackEnds.NUMPY: - s = np.sin(angle) - c = np.cos(angle) - R = np.array([[c, -s], [s, c]]) - elif backend == BackEnds.PYTORCH: - s = torch.sin(angle) - c = torch.cos(angle) - R = torch.tensor([[c, -s], [s, c]]) - return R - - @staticmethod - def linear_retarder(ret, azim, backend=BackEnds.NUMPY): - '''Linear retarder - Args: - ret (float): retardance [radians] - azim (float): azimuth angle of fast axis [radians] - Return: Jones matrix - ''' - retarder_azim0 = JonesMatrixGenerators.linear_retarder_azim0(ret, backend=backend) - R = JonesMatrixGenerators.rotator(azim, backend=backend) - Rinv = JonesMatrixGenerators.rotator(-azim, backend=backend) - return R @ retarder_azim0 @ Rinv - - @staticmethod - def linear_retarder_azim0(ret, backend=BackEnds.NUMPY): - '''todo''' - if backend == BackEnds.NUMPY: - return np.array([[np.exp(1j * ret / 2), 0], [0, np.exp(-1j * ret / 2)]]) - else: - return torch.cat( - (torch.cat((torch.exp(1j * ret / 2).unsqueeze(1), torch.zeros(len(ret),1)),1).unsqueeze(2), - torch.cat((torch.zeros(len(ret),1), torch.exp(-1j * ret / 2).unsqueeze(1)),1).unsqueeze(2)), - 2 - ) - - @staticmethod - def linear_retarter_azim90(ret, backend=BackEnds.NUMPY): - '''Linear retarder, convention not establisted yet''' - # TODO: using same convention as linear_retarder_azim0 - if backend == BackEnds.NUMPY: - return np.array([[np.exp(1j * ret / 2), 0], [0, np.exp(-1j * ret / 2)]]) - else: - return torch.tensor([torch.exp(1j * ret / 2), 0], [0, torch.exp(-1j * ret / 2)]) - - @staticmethod - def quarter_waveplate(azim): - '''Quarter Waveplate - Linear retarder with lambda/4 or equiv pi/2 radians - Commonly used to convert linear polarized light to circularly polarized light''' - ret = np.pi / 2 - return JonesMatrixGenerators.linear_retarder(ret, azim) - - @staticmethod - def half_waveplate(azim): - '''Half Waveplate - Linear retarder with lambda/2 or equiv pi radians - Commonly used to rotate the plane of linear polarization''' - # Faster method - s = np.sin(2 * azim) - c = np.cos(2 * azim) - # # Alternative method - # ret = np.pi - # JM = self.LR(ret, azim) - return np.array([[c, s], [s, -c]]) - - @staticmethod - def linear_polarizer(theta): - '''Linear Polarizer - Args: - theta: angle that light can pass through - Returns: Jones matrix - ''' - c = np.cos(theta) - s = np.sin(theta) - J00 = c ** 2 - J11 = s ** 2 - J01 = s * c - J10 = J01 - return np.array([[J00, J01], [J10, J11]]) - - @staticmethod - def right_circular_polarizer(): - '''Right Circular Polarizer''' - return 1 / 2 * np.array([[1, -1j], [1j, 1]]) - - @staticmethod - def left_circular_polarizer(): - '''Left Circular Polarizer''' - return 1 / 2 * np.array([[1, 1j], [-1j, 1]]) - @staticmethod - def right_circular_retarder(ret): - '''Right Circular Retarder''' - return JonesMatrixGenerators.rotator(-ret / 2) - @staticmethod - def left_circular_retarder(ret): - '''Left Circular Retarder''' - return JonesMatrixGenerators.rotator(ret / 2) - - @staticmethod - def polscope_analyzer(): - '''Acts as a circular polarizer - Inhomogeneous elements because eigenvectors are linear (-45 deg) and - (right) circular polarization states - Source: 2010 Polarized Light pg. 224''' - return 1 / (2 * np.sqrt(2)) * np.array([[1 + 1j, 1 - 1j], [1 + 1j, 1 - 1j]]) - - @staticmethod - def universal_compensator(retA, retB): - '''Universal Polarizer - Used as the polarizer for the LC-PolScope''' - LP = JonesMatrixGenerators.linear_polarizer(0) - LCA = JonesMatrixGenerators.linear_retarder(retA, -np.pi / 4) - LCB = JonesMatrixGenerators.linear_retarder_azim0(retB) - return LCB @ LCA @ LP - - @staticmethod - def universal_compensator_modes(setting=0, swing=0): - '''Settings for the LC-PolScope polarizer - Parameters: - setting (int): LC-PolScope setting number between 0 and 4 - swing (float): proportion of wavelength, for ex 0.03 - Returns: - Jones matrix''' - swing_rad = swing * 2 * np.pi - if setting == 0: - retA = np.pi / 2 - retB = np.pi - elif setting == 1: - retA = np.pi / 2 + swing_rad - retB = np.pi - elif setting == 2: - retA = np.pi / 2 - retB = np.pi + swing_rad - elif setting == 3: - retA = np.pi / 2 - retB = np.pi - swing_rad - elif setting == 4: - retA = np.pi / 2 - swing_rad - retB = np.pi - return JonesMatrixGenerators.universal_compensator(retA, retB) - - -class JonesVectorGenerators(BirefringentElement): - '''2x1 Jones vectors representing various states of polarized light''' - def __init__(self, backend : BackEnds = BackEnds.NUMPY): - super(BirefringentElement, self).__init__(backend=backend, torch_args={}, optical_info={}) - - @staticmethod - def right_circular(): - '''Right circularly polarized light''' - return np.array([1, -1j]) / np.sqrt(2) - - @staticmethod - def left_circular(): - '''Left circularly polarized light''' - return np.array([1, 1j]) / np.sqrt(2) - - @staticmethod - def linear(angle): - '''Linearlly polarized light at an angle in radians''' - return JonesMatrixGenerators.rotator(angle) @ np.array([1, 0]) - - @staticmethod - def horizonal(): - '''Horizontally polarized light''' - return np.array([1, 0]) - - @staticmethod - def vertical(): - '''Vertically polarized light''' - return np.array([0, 1]) + raise NotImplementedError("Not implemented yet.") diff --git a/VolumeRaytraceLFM/jones_calculus.py b/VolumeRaytraceLFM/jones_calculus.py new file mode 100644 index 0000000..404787f --- /dev/null +++ b/VolumeRaytraceLFM/jones_calculus.py @@ -0,0 +1,195 @@ +'''Jones Calculus Matrices and Vector Generators + +Constructors for different types of elements. +These methods are constructors only. They don't support torch +optimization of internal variables. +''' +import numpy as np +import torch +from VolumeRaytraceLFM.abstract_classes import BackEnds +from VolumeRaytraceLFM.birefringence_base import BirefringentElement + + +class JonesMatrixGenerators(BirefringentElement): + '''2x2 Jones matrices representing various of polariztion elements''' + + def __init__(self, backend : BackEnds = BackEnds.NUMPY): + super(BirefringentElement, self).__init__(backend=backend, torch_args={}, optical_info={}) + + @staticmethod + def rotator(angle, backend=BackEnds.NUMPY): + '''2D rotation matrix + Args: + angle: angle to rotate by counterclockwise [radians] + Return: Jones matrix''' + if backend == BackEnds.NUMPY: + s = np.sin(angle) + c = np.cos(angle) + R = np.array([[c, -s], [s, c]]) + elif backend == BackEnds.PYTORCH: + s = torch.sin(angle) + c = torch.cos(angle) + R = torch.tensor([[c, -s], [s, c]]) + return R + + @staticmethod + def linear_retarder(ret, azim, backend=BackEnds.NUMPY): + '''Linear retarder + Args: + ret (float): retardance [radians] + azim (float): azimuth angle of fast axis [radians] + Return: Jones matrix + ''' + retarder_azim0 = JonesMatrixGenerators.linear_retarder_azim0(ret, backend=backend) + R = JonesMatrixGenerators.rotator(azim, backend=backend) + Rinv = JonesMatrixGenerators.rotator(-azim, backend=backend) + return R @ retarder_azim0 @ Rinv + + @staticmethod + def linear_retarder_azim0(ret, backend=BackEnds.NUMPY): + '''todo''' + if backend == BackEnds.NUMPY: + return np.array([[np.exp(1j * ret / 2), 0], [0, np.exp(-1j * ret / 2)]]) + else: + return torch.cat( + (torch.cat((torch.exp(1j * ret / 2).unsqueeze(1), torch.zeros(len(ret),1)),1).unsqueeze(2), + torch.cat((torch.zeros(len(ret),1), torch.exp(-1j * ret / 2).unsqueeze(1)),1).unsqueeze(2)), + 2 + ) + + @staticmethod + def linear_retarter_azim90(ret, backend=BackEnds.NUMPY): + '''Linear retarder, convention not establisted yet''' + # TODO: using same convention as linear_retarder_azim0 + if backend == BackEnds.NUMPY: + return np.array([[np.exp(1j * ret / 2), 0], [0, np.exp(-1j * ret / 2)]]) + else: + return torch.tensor([torch.exp(1j * ret / 2), 0], [0, torch.exp(-1j * ret / 2)]) + + @staticmethod + def quarter_waveplate(azim): + '''Quarter Waveplate + Linear retarder with lambda/4 or equiv pi/2 radians + Commonly used to convert linear polarized light to circularly polarized light''' + ret = np.pi / 2 + return JonesMatrixGenerators.linear_retarder(ret, azim) + + @staticmethod + def half_waveplate(azim): + '''Half Waveplate + Linear retarder with lambda/2 or equiv pi radians + Commonly used to rotate the plane of linear polarization''' + # Faster method + s = np.sin(2 * azim) + c = np.cos(2 * azim) + # # Alternative method + # ret = np.pi + # JM = self.LR(ret, azim) + return np.array([[c, s], [s, -c]]) + + @staticmethod + def linear_polarizer(theta): + '''Linear Polarizer + Args: + theta: angle that light can pass through + Returns: Jones matrix + ''' + c = np.cos(theta) + s = np.sin(theta) + J00 = c ** 2 + J11 = s ** 2 + J01 = s * c + J10 = J01 + return np.array([[J00, J01], [J10, J11]]) + + @staticmethod + def right_circular_polarizer(): + '''Right Circular Polarizer''' + return 1 / 2 * np.array([[1, -1j], [1j, 1]]) + + @staticmethod + def left_circular_polarizer(): + '''Left Circular Polarizer''' + return 1 / 2 * np.array([[1, 1j], [-1j, 1]]) + @staticmethod + def right_circular_retarder(ret): + '''Right Circular Retarder''' + return JonesMatrixGenerators.rotator(-ret / 2) + @staticmethod + def left_circular_retarder(ret): + '''Left Circular Retarder''' + return JonesMatrixGenerators.rotator(ret / 2) + + @staticmethod + def polscope_analyzer(): + '''Acts as a circular polarizer + Inhomogeneous elements because eigenvectors are linear (-45 deg) and + (right) circular polarization states + Source: 2010 Polarized Light pg. 224''' + return 1 / (2 * np.sqrt(2)) * np.array([[1 + 1j, 1 - 1j], [1 + 1j, 1 - 1j]]) + + @staticmethod + def universal_compensator(retA, retB): + '''Universal Polarizer + Used as the polarizer for the LC-PolScope''' + LP = JonesMatrixGenerators.linear_polarizer(0) + LCA = JonesMatrixGenerators.linear_retarder(retA, -np.pi / 4) + LCB = JonesMatrixGenerators.linear_retarder_azim0(retB) + return LCB @ LCA @ LP + + @staticmethod + def universal_compensator_modes(setting=0, swing=0): + '''Settings for the LC-PolScope polarizer + Parameters: + setting (int): LC-PolScope setting number between 0 and 4 + swing (float): proportion of wavelength, for ex 0.03 + Returns: + Jones matrix''' + swing_rad = swing * 2 * np.pi + if setting == 0: + retA = np.pi / 2 + retB = np.pi + elif setting == 1: + retA = np.pi / 2 + swing_rad + retB = np.pi + elif setting == 2: + retA = np.pi / 2 + retB = np.pi + swing_rad + elif setting == 3: + retA = np.pi / 2 + retB = np.pi - swing_rad + elif setting == 4: + retA = np.pi / 2 - swing_rad + retB = np.pi + return JonesMatrixGenerators.universal_compensator(retA, retB) + + +class JonesVectorGenerators(BirefringentElement): + '''2x1 Jones vectors representing various states of polarized light''' + def __init__(self, backend : BackEnds = BackEnds.NUMPY): + super(BirefringentElement, self).__init__(backend=backend, torch_args={}, optical_info={}) + + @staticmethod + def right_circular(): + '''Right circularly polarized light''' + return np.array([1, -1j]) / np.sqrt(2) + + @staticmethod + def left_circular(): + '''Left circularly polarized light''' + return np.array([1, 1j]) / np.sqrt(2) + + @staticmethod + def linear(angle): + '''Linearlly polarized light at an angle in radians''' + return JonesMatrixGenerators.rotator(angle) @ np.array([1, 0]) + + @staticmethod + def horizonal(): + '''Horizontally polarized light''' + return np.array([1, 0]) + + @staticmethod + def vertical(): + '''Vertically polarized light''' + return np.array([0, 1]) diff --git a/forward_intensity.py b/forward_intensity.py index 212d9eb..72c2c99 100644 --- a/forward_intensity.py +++ b/forward_intensity.py @@ -7,12 +7,12 @@ """ import time # to measure ray tracing time import matplotlib.pyplot as plt +from VolumeRaytraceLFM.jones_calculus import JonesMatrixGenerators from VolumeRaytraceLFM.visualization.plotting_intensity import plot_intensity_images from VolumeRaytraceLFM.abstract_classes import BackEnds from VolumeRaytraceLFM.birefringence_implementations import ( BirefringentVolume, - BirefringentRaytraceLFM, - JonesMatrixGenerators + BirefringentRaytraceLFM ) # Select backend method diff --git a/main_forward_projection.py b/main_forward_projection.py index c3a03a1..7c154d8 100644 --- a/main_forward_projection.py +++ b/main_forward_projection.py @@ -8,12 +8,12 @@ """ import time # to measure ray tracing time import matplotlib.pyplot as plt +from VolumeRaytraceLFM.jones_calculus import JonesMatrixGenerators from VolumeRaytraceLFM.visualization.plotting_ret_azim import plot_retardance_orientation from VolumeRaytraceLFM.abstract_classes import BackEnds from VolumeRaytraceLFM.birefringence_implementations import ( BirefringentVolume, - BirefringentRaytraceLFM, - JonesMatrixGenerators + BirefringentRaytraceLFM ) # Select backend method diff --git a/pages/1_Forward_Projection.py b/pages/1_Forward_Projection.py index bcad577..e36caf5 100644 --- a/pages/1_Forward_Projection.py +++ b/pages/1_Forward_Projection.py @@ -5,12 +5,13 @@ import h5py # for reading h5 volume files import streamlit as st import matplotlib.pyplot as plt +from VolumeRaytraceLFM.jones_calculus import JonesMatrixGenerators from VolumeRaytraceLFM.utils.file_utils import save_as_tif from VolumeRaytraceLFM.visualization.plotting_ret_azim import plot_retardance_orientation from VolumeRaytraceLFM.visualization.plotting_intensity import plot_intensity_images from VolumeRaytraceLFM.abstract_classes import BackEnds from VolumeRaytraceLFM.birefringence_implementations import ( - BirefringentVolume, BirefringentRaytraceLFM, JonesMatrixGenerators + BirefringentVolume, BirefringentRaytraceLFM ) try: import torch diff --git a/tests/test_all.py b/tests/test_all.py index 1529e76..5cae52e 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -4,6 +4,8 @@ import copy import os +from VolumeRaytraceLFM.jones_calculus import JonesMatrixGenerators + @pytest.fixture(scope = 'module') def global_data(): '''Create global optic_setting and optical_info containing all the optics and volume information diff --git a/tests/test_jones.py b/tests/test_jones.py index 64a15f2..d60156c 100644 --- a/tests/test_jones.py +++ b/tests/test_jones.py @@ -1,7 +1,8 @@ '''Test that Jones matrix conventions are consistent.''' import numpy as np -from VolumeRaytraceLFM.birefringence_implementations import ( - JonesMatrixGenerators, JonesVectorGenerators +from VolumeRaytraceLFM.jones_calculus import ( + JonesMatrixGenerators, + JonesVectorGenerators ) def test_polarizer_generators(): From 403cc747e1374eb198d53b589fb03633dd35b6a1 Mon Sep 17 00:00:00 2001 From: Geneva Schlafly Date: Fri, 8 Dec 2023 17:18:30 -0600 Subject: [PATCH 6/6] Refactor file management class methods I created the class VolumeFileManager to include the functionality of creating and saving birefringent volumes. Tests were created for the new class methods. --- .../birefringence_implementations.py | 285 ++++++++---------- VolumeRaytraceLFM/file_manager.py | 135 +++++++++ tests/test_all.py | 4 +- tests/test_birefringent_volume.py | 1 + tests/test_file_manager.py | 131 ++++++++ 5 files changed, 398 insertions(+), 158 deletions(-) create mode 100644 VolumeRaytraceLFM/file_manager.py create mode 100644 tests/test_file_manager.py diff --git a/VolumeRaytraceLFM/birefringence_implementations.py b/VolumeRaytraceLFM/birefringence_implementations.py index ef93f90..433492c 100644 --- a/VolumeRaytraceLFM/birefringence_implementations.py +++ b/VolumeRaytraceLFM/birefringence_implementations.py @@ -1,18 +1,15 @@ from math import floor from tqdm import tqdm -import h5py from VolumeRaytraceLFM.abstract_classes import * from VolumeRaytraceLFM.birefringence_base import BirefringentElement +from VolumeRaytraceLFM.file_manager import VolumeFileManager from VolumeRaytraceLFM.jones_calculus import JonesMatrixGenerators, JonesVectorGenerators from utils import errors -from tifffile import imsave NORM_PROJ = False # normalize the projection of the ray onto the optic axis OPTIMIZING_MODE = False # use the birefringence stored in Delta_n_combined ########################################################################################### -# Implementations of OpticalElement -# TODO: rename to BirefringentVolume inherits class BirefringentVolume(BirefringentElement): '''This class stores a 3D array of voxels with birefringence properties, either with a numpy or pytorch back-end.''' @@ -53,77 +50,96 @@ def __init__(self, backend=BackEnds.NUMPY, torch_args={}, #{'optic_config' : Non torch_args=torch_args, optical_info=optical_info ) + self._initialize_volume_attributes(optical_info, Delta_n, optic_axis) + # Check if a volume creation was requested + if volume_creation_args is not None: + self.init_volume(volume_creation_args['init_mode'], volume_creation_args.get('init_args', {})) + + def _initialize_volume_attributes(self, optical_info, Delta_n, optic_axis): + self.volume_shape = optical_info['volume_shape'] if self.backend == BackEnds.NUMPY: - self.volume_shape = self.optical_info['volume_shape'] - # In the case when an optic axis per voxel of a 3D volume is provided - # e.g. [3,nz,ny,nx] - if isinstance(optic_axis, np.ndarray) and len(optic_axis.shape) == 4: - self.volume_shape = optic_axis.shape[1:] - # flatten all the voxels in order to normalize them - optic_axis = optic_axis.reshape( - 3, - self.volume_shape[0] * self.volume_shape[1] * self.volume_shape[2] - ).astype(np.float64) - for n_voxel in range(len(optic_axis[0,...])): - oa_norm = np.linalg.norm(optic_axis[:,n_voxel]) - if oa_norm > 0: - optic_axis[:,n_voxel] /= oa_norm - # Set 4D shape again - self.optic_axis = optic_axis.reshape(3, *self.volume_shape) - - self.Delta_n = Delta_n - assert len(self.Delta_n.shape) == 3, \ - '3D Delta_n expected, as the optic_axis was provided as a 3D array' - # Single optic axis, we replicate it for all the voxels - elif isinstance(optic_axis, list) or isinstance(optic_axis, np.ndarray): - # Same optic axis for all voxels - optic_axis = np.array(optic_axis) - norm = np.linalg.norm(optic_axis) - if norm != 0: - optic_axis /= norm - self.optic_axis = np.expand_dims(optic_axis,[1,2,3]).repeat(self.volume_shape[0],1).repeat(self.volume_shape[1],2).repeat(self.volume_shape[2],3) - - # Create Delta_n 3D volume - self.Delta_n = Delta_n * np.ones(self.volume_shape) - - self.Delta_n[np.isnan(self.Delta_n)] = 0 - self.optic_axis[np.isnan(self.optic_axis)] = 0 + self._initialize_numpy_backend(Delta_n, optic_axis) elif self.backend == BackEnds.PYTORCH: - # Update volume shape from optic config - self.volume_shape = self.optical_info['volume_shape'] - # Normalization of optical axis, depending on input - if not isinstance(optic_axis, list) and optic_axis.ndim==4: - if isinstance(optic_axis, np.ndarray): - optic_axis = torch.from_numpy(optic_axis).type(torch.get_default_dtype()) - norm_A = (optic_axis[0,...]**2+optic_axis[1,...]**2+optic_axis[2,...]**2).sqrt() - self.optic_axis = optic_axis / norm_A.repeat(3,1,1,1) - assert len(Delta_n.shape) == 3, \ - '3D Delta_n expected, as the optic_axis was provided as a 3D torch tensor' - self.Delta_n = Delta_n - if not torch.is_tensor(Delta_n): + self._initialize_pytorch_backend(Delta_n, optic_axis) + else: + raise ValueError(f"Unsupported backend type: {self.backend}") + + def _initialize_numpy_backend(self, Delta_n, optic_axis): + # In the case when an optic axis per voxel of a 3D volume is provided, e.g. [3,nz,ny,nx] + if isinstance(optic_axis, np.ndarray) and len(optic_axis.shape) == 4: + self._handle_3d_optic_axis_numpy(optic_axis) + self.Delta_n = Delta_n + assert len(self.Delta_n.shape) == 3, '3D Delta_n expected, as the optic_axis was provided as a 3D array' + # Single optic axis, replicate for all voxels + elif isinstance(optic_axis, list) or isinstance(optic_axis, np.ndarray): + self._handle_single_optic_axis_numpy(optic_axis) + # Create Delta_n 3D volume + self.Delta_n = Delta_n * np.ones(self.volume_shape) + + self.Delta_n[np.isnan(self.Delta_n)] = 0 + self.optic_axis[np.isnan(self.optic_axis)] = 0 + + def _initialize_pytorch_backend(self, Delta_n, optic_axis): + # Normalization of optical axis, depending on input + if not isinstance(optic_axis, list) and optic_axis.ndim == 4: + self._handle_3d_optic_axis_torch(optic_axis) + assert len(Delta_n.shape) == 3, \ + '3D Delta_n expected, as the optic_axis was provided as a 3D torch tensor' + self.Delta_n = Delta_n + if not torch.is_tensor(Delta_n): self.Delta_n = torch.from_numpy(Delta_n).type(torch.get_default_dtype()) - else: - # Same optic axis for all voxels - optic_axis = np.array(optic_axis).astype(np.float32) - norm = np.linalg.norm(optic_axis) - if norm != 0: - optic_axis /= norm - self.optic_axis = torch.from_numpy(optic_axis).unsqueeze(1).unsqueeze(1).unsqueeze(1) \ - .repeat(1, self.volume_shape[0], self.volume_shape[1], self.volume_shape[2]) - self.Delta_n = Delta_n * torch.ones(self.volume_shape) - # Check for not a number, for when the voxel optic_axis is all zeros - self.Delta_n[torch.isnan(self.Delta_n)] = 0 - self.optic_axis[torch.isnan(self.optic_axis)] = 0 - # Store the data as pytorch parameters - self.optic_axis = nn.Parameter(self.optic_axis.reshape(3,-1)).type(torch.get_default_dtype()) - self.Delta_n = nn.Parameter(self.Delta_n.flatten()).type(torch.get_default_dtype()) - # Check if a volume creation was requested - if volume_creation_args is not None: - self.init_volume( - volume_creation_args['init_mode'], - volume_creation_args['init_args'] if 'init_args' in volume_creation_args.keys() else {} - ) + else: + self._handle_single_optic_axis_torch(optic_axis) + self.Delta_n = Delta_n * torch.ones(self.volume_shape) + + # Check for not a number, for when the voxel optic_axis is all zeros + self.Delta_n[torch.isnan(self.Delta_n)] = 0 + self.optic_axis[torch.isnan(self.optic_axis)] = 0 + # Store the data as pytorch parameters + self.optic_axis = nn.Parameter(self.optic_axis.reshape(3, -1)).type(torch.get_default_dtype()) + self.Delta_n = nn.Parameter(self.Delta_n.flatten()).type(torch.get_default_dtype()) + + def _handle_3d_optic_axis_numpy(self, optic_axis): + """Normalize and reshape a 3D optic axis array for Numpy backend.""" + self.volume_shape = optic_axis.shape[1:] + # Flatten all the voxels in order to normalize them + optic_axis = optic_axis.reshape( + 3, + self.volume_shape[0] * self.volume_shape[1] * self.volume_shape[2] + ).astype(np.float64) + for n_voxel in range(len(optic_axis[0,...])): + oa_norm = np.linalg.norm(optic_axis[:,n_voxel]) + if oa_norm > 0: + optic_axis[:,n_voxel] /= oa_norm + # Set 4D shape again + self.optic_axis = optic_axis.reshape(3, *self.volume_shape) + + def _handle_single_optic_axis_numpy(self, optic_axis): + """Set a single optic axis for all voxels for Numpy backend.""" + optic_axis = np.array(optic_axis) + oa_norm = np.linalg.norm(optic_axis) + if oa_norm != 0: + optic_axis /= oa_norm + self.optic_axis = np.expand_dims(optic_axis,[1,2,3]).repeat(self.volume_shape[0],1).repeat(self.volume_shape[1],2).repeat(self.volume_shape[2],3) + # self.optic_axis = np.expand_dims(optic_axis, axis=(1, 2, 3)) + # self.optic_axis = np.repeat(self.optic_axis, self.volume_shape, axis=(1, 2, 3)) + + def _handle_3d_optic_axis_torch(self, optic_axis): + """Normalize and reshape a 3D optic axis array for PyTorch backend.""" + if isinstance(optic_axis, np.ndarray): + optic_axis = torch.from_numpy(optic_axis).type(torch.get_default_dtype()) + oa_norm = torch.sqrt(torch.sum(optic_axis**2, dim=0)) + self.optic_axis = optic_axis / oa_norm.repeat(3, 1, 1, 1) + + def _handle_single_optic_axis_torch(self, optic_axis): + """Set a single optic axis for all voxels for PyTorch backend.""" + optic_axis = np.array(optic_axis).astype(np.float32) + oa_norm = np.linalg.norm(optic_axis) + if oa_norm != 0: + optic_axis /= oa_norm + optic_axis_tensor = torch.from_numpy(optic_axis).unsqueeze(1).unsqueeze(1).unsqueeze(1) + self.optic_axis = optic_axis_tensor.repeat(1, *self.volume_shape) def get_delta_n(self): '''Retrieves the birefringence as a 3D array''' @@ -368,56 +384,6 @@ def get_vox_params(self, vox_idx): axis = self.optic_axis[:, vox_idx] return self.Delta_n[vox_idx], axis -########### Generate different birefringent volumes - def save_as_file(self, h5_file_path, description="Temporary description", optical_all=False): - '''Store this volume into an h5 file''' - print(f'Saving volume to h5 file: {h5_file_path}') - # Create file - with h5py.File(h5_file_path, "w") as f: - # Save optical_info - oc_grp = f.create_group('optical_info') - try: - oc_grp.create_dataset('description', - [1], - data=description - ) - vol_shape = self.optical_info['volume_shape'] - voxel_size_um = self.optical_info['voxel_size_um'] - except: - pass - - if not optical_all: - try: - oc_grp.create_dataset('volume_shape', - np.array(vol_shape).shape if isinstance(vol_shape, list) else [1], - data=vol_shape - ) - oc_grp.create_dataset('voxel_size_um', - np.array(voxel_size_um).shape if isinstance(voxel_size_um, list) else [1], - data=voxel_size_um - ) - except: - pass - else: - for k,v in self.optical_info.items(): - try: - oc_grp.create_dataset(k, np.array(v).shape if isinstance(v,list) else [1], data=v) - # print(f'Added optical_info/{k} to {h5_file_path}') - except: - pass - # Save data (birefringence and optic_axis) - delta_n = self.get_delta_n() - optic_axis = self.get_optic_axis() - - if self.backend == BackEnds.PYTORCH: - delta_n = delta_n.detach().cpu().numpy() - optic_axis = optic_axis.detach().cpu().numpy() - - data_grp = f.create_group('data') - data_grp.create_dataset("delta_n", delta_n.shape, data=delta_n.astype(np.float32)) - data_grp.create_dataset("optic_axis", optic_axis.shape, data=optic_axis.astype(np.float32)) - return h5_file_path - @staticmethod def crop_to_region_shape(delta_n, optic_axis, volume_shape, region_shape): ''' @@ -475,16 +441,10 @@ def init_from_file(h5_file_path, backend=BackEnds.NUMPY, optical_info=None): It requires to have: optical_info/volume_shape [3]: shape of the volume in voxels [nz,ny,nx] data/delta_n [nz,ny,nx]: Birefringence volumetric information. - data/optic_axis [3,nz,ny,nx]: Optical axis per voxel.''' - - # Load volume - volume_file = h5py.File(h5_file_path, "r") - - # Fetch birefringence - delta_n = np.array(volume_file['data/delta_n']) - # Fetch optic_axis - optic_axis = np.array(volume_file['data/optic_axis']) - # TODO: adjust for when optical_info is None + data/optic_axis [3,nz,ny,nx]: Optical axis per voxel. + ''' + file_manager = VolumeFileManager() + delta_n, optic_axis = file_manager.extract_data_from_h5(h5_file_path) region_shape = np.array(optical_info['volume_shape']) if (delta_n.shape == region_shape).all(): pass @@ -496,9 +456,8 @@ def init_from_file(h5_file_path, backend=BackEnds.NUMPY, optical_info=None): err = (f"BirefringentVolume has dimensions ({delta_n.shape}) that are not all greater " + f"than or less than the volume region dimensions ({region_shape}) set for the microscope") raise ValueError(err) - # Create volume - volume_out = BirefringentVolume(backend=backend, optical_info=optical_info, Delta_n=delta_n, optic_axis=optic_axis) - return volume_out + volume = BirefringentVolume(backend=backend, optical_info=optical_info, Delta_n=delta_n, optic_axis=optic_axis) + return volume @staticmethod def load_from_file(h5_file_path, backend_type='numpy'): @@ -512,15 +471,9 @@ def load_from_file(h5_file_path, backend_type='numpy'): backend = BackEnds.NUMPY else: raise ValueError(f"Backend type {backend_type} is not an option.") - # Load volume - volume_file = h5py.File(h5_file_path, "r") - # Fetch birefringence - delta_n = np.array(volume_file['data/delta_n']) - # Fetch optic_axis - optic_axis = np.array(volume_file['data/optic_axis']) - # Fetch optical info - volume_shape = np.array(volume_file['optical_info/volume_shape']) - voxel_size_um = np.array(volume_file['optical_info/voxel_size_um']) + + file_manager = VolumeFileManager() + delta_n, optic_axis, volume_shape, voxel_size_um = file_manager.extract_all_data_from_h5(h5_file_path) cube_voxels = True # Create optical info dictionary # TODO: add the remaining variables, notably the voxel size and the cube voxels boolean @@ -532,6 +485,40 @@ def load_from_file(h5_file_path, backend_type='numpy'): volume_out = BirefringentVolume(backend=backend, optical_info=optical_info, Delta_n=delta_n, optic_axis=optic_axis) return volume_out + def save_as_file(self, h5_file_path, description="Temporary description", optical_all=False): + '''Store this volume into an h5 file''' + print(f'Saving volume to h5 file: {h5_file_path}') + + delta_n, optic_axis = self._get_data_as_numpy_arrays() + file_manager = VolumeFileManager() + file_manager.save_as_h5(h5_file_path, delta_n, optic_axis, self.optical_info, description, optical_all) + + def _get_data_as_numpy_arrays(self): + '''Converts delta_n and optic_axis based on backend''' + delta_n = self.get_delta_n() + optic_axis = self.get_optic_axis() + + if self.backend == BackEnds.PYTORCH: + delta_n = delta_n.detach().cpu().numpy() + optic_axis = optic_axis.detach().cpu().numpy() + + return delta_n, optic_axis + + def save_as_tiff(self, filename): + '''Store this volume into a tiff file''' + delta_n, optic_axis = self._get_data_as_numpy_arrays() + file_manager = VolumeFileManager() + file_manager.save_as_channel_stack_tiff(filename, delta_n, optic_axis) + + def _get_backend_str(self): + if self.backend == BackEnds.PYTORCH: + return 'pytorch' + elif self.backend == BackEnds.NUMPY: + return 'numpy' + else: + raise ValueError(f"Backend type {self.backend} is not supported.") + +########### Generate different birefringent volumes ############ def init_volume(self, init_mode='zeros', init_args={}): ''' This function creates predefined volumes and shapes, such as planes, ellipsoids, random, etc TODO: use init_args for random and planes''' @@ -861,18 +848,6 @@ def create_dummy_volume(backend=BackEnds.NUMPY, optical_info=None, vol_type="she raise NotImplementedError return volume - def save_as_tiff(self, filename): - '''Store this volume into a tiff file''' - print(f'Saving volume to file: {filename}') - delta_n = self.get_delta_n() - optic_axis = self.get_optic_axis() - combined_data = np.stack([delta_n, optic_axis[0], optic_axis[1], optic_axis[2]], axis=0, dtype=np.float32) - if self.backend == BackEnds.PYTORCH: - delta_n = delta_n.detach().cpu().numpy() - optic_axis = optic_axis.detach().cpu().numpy() - imsave(filename, combined_data) - return filename - ############ Implementations class BirefringentRaytraceLFM(RayTraceLFM, BirefringentElement): diff --git a/VolumeRaytraceLFM/file_manager.py b/VolumeRaytraceLFM/file_manager.py new file mode 100644 index 0000000..2f296ae --- /dev/null +++ b/VolumeRaytraceLFM/file_manager.py @@ -0,0 +1,135 @@ +import numpy as np +import h5py +import tifffile + +class VolumeFileManager: + def __init__(self): + """Initializes the VolumeFileManager class.""" + pass + + def extract_data_from_h5(self, file_path): + """ + Extracts birefringence (delta_n) and optic axis data from an H5 file. + + Args: + - file_path (str): Path to the H5 file from which data is to be extracted. + + Returns: + - tuple: A tuple containing numpy arrays for delta_n and optic_axis. + """ + volume_file = h5py.File(file_path, "r") + delta_n = np.array(volume_file['data/delta_n']) + optic_axis = np.array(volume_file['data/optic_axis']) + + return delta_n, optic_axis + + def extract_all_data_from_h5(self, file_path): + """ + Extracts birefringence (delta_n), optic axis data, and optical information from an H5 file. + + Args: + - file_path (str): Path to the H5 file from which data is to be extracted. + + Returns: + - tuple: A tuple containing numpy arrays for + delta_n, optic_axis, volume_shape, and voxel_size_um. + """ + volume_file = h5py.File(file_path, "r") + + # Fetch birefringence and optic axis + delta_n = np.array(volume_file['data/delta_n']) + optic_axis = np.array(volume_file['data/optic_axis']) + + # Fetch optical info + volume_shape = np.array(volume_file['optical_info/volume_shape']) + voxel_size_um = np.array(volume_file['optical_info/voxel_size_um']) + + return delta_n, optic_axis, volume_shape, voxel_size_um + + def save_as_channel_stack_tiff(self, filename, delta_n, optic_axis): + """ + Saves the provided volume data as a multi-channel TIFF file. + + Args: + - filename (str): The file path where the TIFF file will be saved. + - delta_n (np.ndarray): Numpy array containing the birefringence information of the volume. + - optic_axis (np.ndarray): Numpy array containing the optic axis data of the volume. + + The method combines delta_n and optic_axis data into a single multi-channel array + and saves it as a TIFF file. Exceptions related to file operations are caught and logged. + """ + try: + print(f'Saving volume to file: {filename}') + combined_data = np.stack([delta_n, optic_axis[0], optic_axis[1], optic_axis[2]], axis=0) + tifffile.imwrite(filename, combined_data) + print('Volume saved successfully.') + except Exception as e: + print(f"Error saving file: {e}") + + def save_as_h5(self, h5_file_path, delta_n, optic_axis, optical_info, description, optical_all): + """ + Saves the volume data, including birefringence information (delta_n) and optic axis data, + along with optical metadata into an H5 file. + + The method creates an H5 file at the specified path and writes the provided data + to this file, organizing the data into appropriate groups and datasets within the file. + + Args: + - h5_file_path (str): The file path where the H5 file will be saved. + - delta_n (np.ndarray): Numpy array containing the birefringence information of the volume. + - optic_axis (np.ndarray): Numpy array containing the optic axis data of the volume. + - optical_info (dict): Dictionary containing optical metadata about the volume. This may + include properties like volume shape, voxel size, etc. + - description (str): A brief description or note to be included in the optical information + of the H5 file. Useful for providing context or additional details about the data. + - optical_all (bool): A flag indicating whether to save all optical metadata present in + `optical_info` to the H5 file. If False, only specific predefined metadata (like volume + shape and voxel size) will be saved. + + Returns: + None. The result of this method is the creation of an H5 file with the specified data. + """ + with h5py.File(h5_file_path, "w") as f: + self._save_optical_info(f, optical_info, description, optical_all) + self._save_data(f, delta_n, optic_axis) + + def _save_optical_info(self, file_handle, optical_info, description, optical_all): + """ + Private method to save optical information to an H5 file. + + Args: + - file_handle (File): An open H5 file handle. + - optical_info (dict): Dictionary containing optical metadata. + - description (str): Description to be included in the H5 file. + - optical_all (bool): Flag indicating whether to save all optical metadata. + + This method creates a group for optical information and adds datasets to it. + """ + optics_grp = file_handle.create_group('optical_info') + optics_grp.create_dataset('description', data=np.string_(description)) + # optics_grp.create_dataset('description', data=description) + if not optical_all: + vol_shape = optical_info.get('volume_shape', None) + voxel_size_um = optical_info.get('voxel_size_um', None) + if vol_shape is not None: + optics_grp.create_dataset('volume_shape', data=np.array(vol_shape)) + if voxel_size_um is not None: + optics_grp.create_dataset('voxel_size_um', data=np.array(voxel_size_um)) + else: + for k, v in optical_info.items(): + optics_grp.create_dataset(k, data=np.array(v)) + + def _save_data(self, file_handle, delta_n, optic_axis): + """ + Private method to save delta_n and optic_axis data to an H5 file. + + Args: + - file_handle (File): An open H5 file handle. + - delta_n (np.ndarray): Numpy array of delta_n data. + - optic_axis (np.ndarray): Numpy array of optic_axis data. + + This method creates a group for volume data and adds datasets for delta_n and optic_axis. + """ + data_grp = file_handle.create_group('data') + data_grp.create_dataset("delta_n", delta_n.shape, data=delta_n) + data_grp.create_dataset("optic_axis", optic_axis.shape, data=optic_axis) diff --git a/tests/test_all.py b/tests/test_all.py index 5cae52e..feb36e3 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,9 +1,7 @@ import pytest -from VolumeRaytraceLFM.birefringence_implementations import * import matplotlib.pyplot as plt import copy -import os - +from VolumeRaytraceLFM.birefringence_implementations import * from VolumeRaytraceLFM.jones_calculus import JonesMatrixGenerators @pytest.fixture(scope = 'module') diff --git a/tests/test_birefringent_volume.py b/tests/test_birefringent_volume.py index ed46bc6..0d1a694 100644 --- a/tests/test_birefringent_volume.py +++ b/tests/test_birefringent_volume.py @@ -2,6 +2,7 @@ import numpy as np import torch import pytest +import h5py from plotly.graph_objs import Figure from VolumeRaytraceLFM.birefringence_implementations import * diff --git a/tests/test_file_manager.py b/tests/test_file_manager.py new file mode 100644 index 0000000..42413f6 --- /dev/null +++ b/tests/test_file_manager.py @@ -0,0 +1,131 @@ +'''Tests for VolumeFileManager class''' +import numpy as np +import h5py +import os +from unittest.mock import Mock +from VolumeRaytraceLFM.file_manager import VolumeFileManager + +def mock_h5_file(return_value): + def mock(*args, **kwargs): + class MockH5File: + def __getitem__(self, item): + return return_value[item] + + def __enter__(self, *args, **kwargs): + return self + + def __exit__(self, *args, **kwargs): + pass + + return MockH5File() + + return mock + +def test_extract_data_from_h5(monkeypatch): + test_file_path = 'test_file.h5' + expected_delta_n = np.array([1, 2, 3]) + expected_optic_axis = np.array([4, 5, 6]) + + # Mocking the h5py.File call + monkeypatch.setattr(h5py, 'File', mock_h5_file({ + 'data/delta_n': expected_delta_n, + 'data/optic_axis': expected_optic_axis + })) + + vfm = VolumeFileManager() + delta_n, optic_axis = vfm.extract_data_from_h5(test_file_path) + + assert np.array_equal(delta_n, expected_delta_n) + assert np.array_equal(optic_axis, expected_optic_axis) + +def test_extract_all_data_from_h5(monkeypatch): + """Verify that the extract_all_data_from_h5 method correctly + extracts data and optical information from an H5 file""" + test_file_path = 'test_file.h5' + expected_delta_n = np.array([1, 2, 3]) + expected_optic_axis = np.array([4, 5, 6]) + expected_volume_shape = np.array([7, 8, 9]) + expected_voxel_size_um = np.array([10, 11, 12]) + + # Mocking the h5py.File call + monkeypatch.setattr(h5py, 'File', mock_h5_file({ + 'data/delta_n': expected_delta_n, + 'data/optic_axis': expected_optic_axis, + 'optical_info/volume_shape': expected_volume_shape, + 'optical_info/voxel_size_um': expected_voxel_size_um + })) + + vfm = VolumeFileManager() + delta_n, optic_axis, volume_shape, voxel_size_um = vfm.extract_all_data_from_h5(test_file_path) + + assert np.array_equal(delta_n, expected_delta_n) + assert np.array_equal(optic_axis, expected_optic_axis) + assert np.array_equal(volume_shape, expected_volume_shape) + assert np.array_equal(voxel_size_um, expected_voxel_size_um) + +def test_save_as_channel_stack_tiff(monkeypatch): + filename = 'test.tiff' + shape = (3, 1, 5, 5) + delta_n = np.random.random(shape[1:]) + optic_axis = np.random.random(shape) + norms = np.linalg.norm(optic_axis, axis=0) + optic_axis /= norms + + # Create a mock for the imwrite function + mock_imwrite = Mock() + # Use monkeypatch to replace imwrite with the mock + monkeypatch.setattr('tifffile.imwrite', mock_imwrite) + + # Create an instance of VolumeFileManager and call the method + vfm = VolumeFileManager() + vfm.save_as_channel_stack_tiff(filename, delta_n, optic_axis) + + assert mock_imwrite.called, "imwrite was not called" + + # Extract the actual arguments with which imwrite was called + actual_args, _ = mock_imwrite.call_args + actual_filename, actual_data = actual_args + + assert actual_filename == filename, "Filename does not match" + + # Check if the data matches within a tolerance + expected_data = np.stack([delta_n, optic_axis[0], optic_axis[1], optic_axis[2]], axis=0) + assert np.allclose(actual_data, expected_data), "Data does not match within tolerance" + +def test_save_as_h5(): + # Mock data for testing + mock_h5_file_path = "test_saving_h5_file.h5" + mock_delta_n = np.array([0.1]) + mock_optic_axis = np.array([0, 1, 0]) + mock_optical_info = {"volume_shape": [1, 1, 1], "voxel_size_um": 1.0} + mock_description = "Test data for volume file manager" + mock_optical_all = True + + manager = VolumeFileManager() + manager.save_as_h5( + mock_h5_file_path, + mock_delta_n, + mock_optic_axis, + mock_optical_info, + mock_description, + mock_optical_all, + ) + assert os.path.exists(mock_h5_file_path) + with h5py.File(mock_h5_file_path, "r") as f: + assert "data" in f and "optical_info" in f + assert "delta_n" in f["data"] and "optic_axis" in f["data"] + assert "description" in f["optical_info"] + + # Verify the contents of the datasets + assert np.array_equal(f["data"]["delta_n"][:], mock_delta_n) + assert np.array_equal(f["data"]["optic_axis"][:], mock_optic_axis) + description_bytes = f["optical_info"]["description"][()] + description_string = description_bytes.decode('utf-8') + assert description_string == mock_description + + # Verify the presence of additional optical metadata if `optical_all` is True + if mock_optical_all: + for key, value in mock_optical_info.items(): + assert key in f["optical_info"] + assert np.array_equal(f["optical_info"][key][()], value) + os.remove(mock_h5_file_path)