diff --git a/VolumeRaytraceLFM/birefringence_implementations.py b/VolumeRaytraceLFM/birefringence_implementations.py index 15d2675..a14f440 100644 --- a/VolumeRaytraceLFM/birefringence_implementations.py +++ b/VolumeRaytraceLFM/birefringence_implementations.py @@ -909,13 +909,19 @@ def get_volume_reachable_region(self): n_voxels_per_ml = self.optical_info['n_voxels_per_ml'] n_ml_half = floor(n_micro_lenses * n_voxels_per_ml / 2.0) mask = torch.zeros(self.optical_info['volume_shape']) - mask[:, - self.vox_ctr_idx[1]-n_ml_half+1 : self.vox_ctr_idx[1]+n_ml_half, - self.vox_ctr_idx[2]-n_ml_half+1 : self.vox_ctr_idx[2]+n_ml_half] = 1.0 + include_ray_angle_reach = True + if include_ray_angle_reach: + vox_span_half = int(self.voxel_span_per_ml + (n_micro_lenses * n_voxels_per_ml) / 2) + mask[:, + self.vox_ctr_idx[1]-vox_span_half+1 : self.vox_ctr_idx[1]+vox_span_half, + self.vox_ctr_idx[2]-vox_span_half+1 : self.vox_ctr_idx[2]+vox_span_half] = 1.0 + else: + mask[:, + self.vox_ctr_idx[1]-n_ml_half+1 : self.vox_ctr_idx[1]+n_ml_half, + self.vox_ctr_idx[2]-n_ml_half+1 : self.vox_ctr_idx[2]+n_ml_half] = 1.0 # mask_volume = BirefringentVolume(backend=self.backend, # optical_info=self.optical_info, Delta_n=0.01, optic_axis=[0.5,0.5,0]) # [r,a] = self.ray_trace_through_volume(mask_volume) - # # Check gradients to see what is affected # L = r.mean() + a.mean() # L.backward() @@ -942,7 +948,7 @@ def precompute_MLA_volume_geometry(self): # border_size_around_mla = np.ceil((volume_shape[1]-(n_micro_lenses*n_voxels_per_ml)) / 2) min_needed_volume_size = int(self.voxel_span_per_ml + (n_micro_lenses*n_voxels_per_ml)) assert min_needed_volume_size <= volume_shape[1] and min_needed_volume_size <= volume_shape[2], "The volume in front of the microlenses" + \ - f"({n_micro_lenses},{n_micro_lenses}) is to large for a volume_shape: {self.optical_info['volume_shape'][1:]}. " + \ + f"({n_micro_lenses},{n_micro_lenses}) is too large for a volume_shape: {self.optical_info['volume_shape'][1:]}. " + \ f"Increase the volume_shape to at least [{min_needed_volume_size+1},{min_needed_volume_size+1}]" odd_mla_shift = np.mod(n_micro_lenses,2) @@ -956,12 +962,14 @@ def precompute_MLA_volume_geometry(self): np.array([n_voxels_per_ml * ml_ii, n_voxels_per_ml*ml_jj]) + np.array(self.vox_ctr_idx[1:]) - n_voxels_per_ml_half ) - self.vox_indices_ml_shifted_all += [[RayTraceLFM.ravel_index((vox[ix][0], - vox[ix][1]+current_offset[0], - vox[ix][2]+current_offset[1]), - self.optical_info['volume_shape']) for ix in range(len(vox))] - for vox in self.ray_vol_colli_indices - ] + self.vox_indices_ml_shifted_all += [ + [ + RayTraceLFM.ravel_index( + (vox[ix][0], vox[ix][1]+current_offset[0], vox[ix][2]+current_offset[1]), + self.optical_info['volume_shape']) for ix in range(len(vox)) + ] + for vox in self.ray_vol_colli_indices + ] # Shift ray-pixel indices if self.ray_valid_indices_all is None: self.ray_valid_indices_all = self.ray_valid_indices.clone() @@ -1408,6 +1416,43 @@ def voxRayJM(self, Delta_n, opticAxis, rayDir, ell, wavelength): assert not torch.isnan(JM).any(), "A Jones matrix contains NaN values." return JM + def vox_ray_ret_azim(self, Delta_n, opticAxis, rayDir, ell, wavelength): + '''Calculate the effective retardance and azimuth of a ray passing through a voxel''' + if self.backend == BackEnds.NUMPY: + # Azimuth is the angle of the slow axis of retardance. + azim = np.arctan2(np.dot(opticAxis, rayDir[1]), np.dot(opticAxis, rayDir[2])) + if Delta_n == 0: + azim = 0 + elif Delta_n < 0: + azim = azim + np.pi / 2 + # print(f"Azimuth angle of index ellipsoid is + # {np.around(np.rad2deg(azim), decimals=0)} degrees.") + normAxis = np.linalg.norm(opticAxis) + proj_along_ray = np.dot(opticAxis, rayDir[0]) + # np.divide(my_arr, my_arr1, out=np.ones_like(my_arr, dtype=np.float32), where=my_arr1 != 0) + ret = abs(Delta_n) * (1 - np.dot(opticAxis, rayDir[0]) ** 2) * 2 * np.pi * ell / wavelength + else: + raise NotImplementedError("Not implemented for pytorch yet.") + return ret, azim + + def vox_ray_matrix(self, ret, azim): + '''Calculate the Jones matrix associated with a particular ray and voxel combination''' + if self.backend == BackEnds.NUMPY: + JM = JonesMatrixGenerators.linear_retarder(ret, azim) + pass + else: + raise NotImplementedError("Not implemented for pytorch yet.") + offdiag = 1j * torch.sin(azim) * torch.sin(ret) + diag1 = torch.cos(ret) + 1j * torch.cos(azim) * torch.sin(ret) + diag2 = torch.conj(diag1) + # Construct Jones Matrix + JM = torch.zeros([Delta_n.shape[0], 2, 2], dtype=torch.complex64, device=Delta_n.device) + JM[:,0,0] = diag1 + JM[:,0,1] = offdiag + JM[:,1,0] = offdiag + JM[:,1,1] = diag2 + return JM + def clone(self): # Code to create a copy of this instance new_instance = BirefringentVolume(...) diff --git a/VolumeRaytraceLFM/metrics/data_fidelity.py b/VolumeRaytraceLFM/metrics/data_fidelity.py new file mode 100644 index 0000000..0650a3d --- /dev/null +++ b/VolumeRaytraceLFM/metrics/data_fidelity.py @@ -0,0 +1,15 @@ +'''Data-fidelity metrics''' +import torch +import torch.nn.functional as F + +def von_mises_loss(angle_pred, angle_gt, kappa=1.0): + '''Von Mises loss function for orientation''' + diff = angle_pred - angle_gt + loss = 1 - torch.exp(kappa * torch.cos(diff)) + return loss.mean() + +def cosine_similarity_loss(vector_pred, vector_gt): + '''Cosine similarity loss function for orientation''' + cos_sim = F.cosine_similarity(vector_pred, vector_gt, dim=-1) + loss = 1 - cos_sim + return loss.mean() diff --git a/VolumeRaytraceLFM/metrics/metric.py b/VolumeRaytraceLFM/metrics/metric.py new file mode 100644 index 0000000..ef55ca7 --- /dev/null +++ b/VolumeRaytraceLFM/metrics/metric.py @@ -0,0 +1,75 @@ +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +from regularization import L1Regularization, L2Regularization + +REGULARIZATION_FNS = { + 'L1Regularization': L1Regularization, + 'L2Regularization': L2Regularization, + # Add more functions here if needed +} + +class PolarimetricLossFunction: + def __init__(self, json_file=None): + if json_file: + with open(json_file, 'r') as f: + params = json.load(f) + self.weight_retardance = params.get('weight_retardance', 1.0) + self.weight_orientation = params.get('weight_orientation', 1.0) + self.weight_datafidelity = params.get('weight_datafidelity', 1.0) + self.weight_regularization = params.get('weight_regularization', 0.1) + # Initialize any specific loss functions you might need + self.mse_loss = nn.MSELoss() + # Initialize regularization functions + self.regularization_fns = [(REGULARIZATION_FNS[fn_name], weight) for fn_name, weight in params.get('regularization_fns', [])] + else: + self.weight_retardance = 1.0 + self.weight_orientation = 1.0 + self.weight_datafidelity = 1.0 + self.weight_regularization = 0.1 + self.mse_loss = nn.MSELoss() + self.regularization_fns = [] + + def set_retardance_target(self, target): + self.target_retardance = target + + def set_orientation_target(self, target): + self.target_orientation = target + + def compute_retardance_loss(self, prediction): + # Add logic to transform data and compute retardance loss + pass + + def compute_orientation_loss(self, prediction): + # Add logic to transform data and compute orientation loss + pass + + def transform_input_data(self, data): + # Transform the input data into a vector form + pass + + def compute_datafidelity_term(self, pred_retardance, pred_orientation): + '''Incorporates the retardance and orientation losses''' + retardance_loss = self.compute_retardance_loss(pred_retardance) + orientation_loss = self.compute_orientation_loss(pred_orientation) + data_loss = (self.weight_retardance * retardance_loss + + self.weight_regularization * orientation_loss) + return data_loss + + def compute_regularization_term(self, data): + '''Compute regularization term''' + regularization_loss = torch.tensor(0.) + for reg_fn, weight in self.regularization_fns: + regularization_loss += weight * reg_fn(data) + return regularization_loss + + def compute_total_loss(self, pred_retardance, pred_orientation, data): + # Compute individual losses + datafidelity_loss = self.compute_datafidelity_term(pred_retardance, pred_orientation) + regularization_loss = self.compute_regularization_term(data) + + # Compute total loss with weighted sum + total_loss = (self.weight_datafidelity * datafidelity_loss + + self.weight_regularization * regularization_loss) + return total_loss diff --git a/VolumeRaytraceLFM/metrics/regularization.py b/VolumeRaytraceLFM/metrics/regularization.py new file mode 100644 index 0000000..b75e4e5 --- /dev/null +++ b/VolumeRaytraceLFM/metrics/regularization.py @@ -0,0 +1,28 @@ +'''Regularization metrics for a birefringent volume''' +from VolumeRaytraceLFM.birefringence_implementations import BirefringentVolume +from regularization_fundamentals import * + +def l2_bir(volume: BirefringentVolume): + birefringence = volume.get_delta_n() + return l2(birefringence) + +def total_variation_bir(volume: BirefringentVolume): + birefringence = volume.get_delta_n() + return total_variation_3d_volumetric(birefringence) + +class AnisotropyAnalysis: + def __init__(self, volume: BirefringentVolume): + self.volume = volume + self.birefringence = volume.get_delta_n() + self.optic_axis = volume.get_optic_axis() + + def l2_regularization(self): + return l2(self.birefringence) + + def total_variation_regularization(self): + return total_variation_3d_volumetric(self.birefringence) + + def process_optic_axis(self): + # Example method to process optic axis data + # Implement the specific logic you need for optic axis data + pass diff --git a/VolumeRaytraceLFM/metrics/regularization_fundamentals.py b/VolumeRaytraceLFM/metrics/regularization_fundamentals.py new file mode 100644 index 0000000..4410e54 --- /dev/null +++ b/VolumeRaytraceLFM/metrics/regularization_fundamentals.py @@ -0,0 +1,34 @@ +'''Regularization functions that can use used in the optimization process.''' +import torch + +def l1(data, weight=1.0): + return weight * torch.abs(data).mean() + +def l2(data, weight=1.0): + return weight * torch.pow(data, 2).mean() + +def linfinity(data, weight=1.0): + return weight * torch.max(torch.abs(data)) + +def elastic_net(data, weight1=1.0, weight2=1.0): + l1_term = torch.abs(data).sum() + l2_term = torch.pow(data, 2).sum() + return weight1 * l1_term + weight2 * l2_term + +def total_variation_3d_volumetric(data, weight=1.0): + """ + Computes the Total Variation regularization for a 4D tensor representing volumetric data. + Args: + data (torch.Tensor): The input 3D tensor with shape [depth, height, width]. + weight (float): Weighting factor for the regularization term. + Returns: + torch.Tensor: The computed Total Variation regularization term. + """ + # Calculate the differences between adjacent elements along each spatial dimension + diff_depth = torch.pow(data[1:, :, :] - data[:-1, :, :], 2).sum() + diff_height = torch.pow(data[:, 1:, :] - data[:, :-1, :], 2).sum() + diff_width = torch.pow(data[:, :, 1:] - data[:, :, :-1], 2).sum() + + # Sum up the differences and apply the weight + tv_reg = weight * (diff_depth + diff_height + diff_width) + return tv_reg diff --git a/VolumeRaytraceLFM/reconstructions.py b/VolumeRaytraceLFM/reconstructions.py index 1c77aea..39cfe2b 100644 --- a/VolumeRaytraceLFM/reconstructions.py +++ b/VolumeRaytraceLFM/reconstructions.py @@ -20,7 +20,11 @@ from VolumeRaytraceLFM.visualization.plt_util import setup_visualization from VolumeRaytraceLFM.visualization.plotting_iterations import plot_iteration_update_gridspec from VolumeRaytraceLFM.utils.file_utils import create_unique_directory - +from VolumeRaytraceLFM.utils.dimensions_utils import ( + get_region_of_ones_shape, + reshape_and_crop, + store_as_pytorch_parameter + ) class ReconstructionConfig: def __init__(self, optical_info, ret_image, azim_image, initial_vol, iteration_params, loss_fcn=None, gt_vol=None): """ @@ -34,8 +38,6 @@ def __init__(self, optical_info, ret_image, azim_image, initial_vol, iteration_p assert isinstance(optical_info, dict), "Expected optical_info to be a dictionary" assert isinstance(ret_image, (torch.Tensor, np.ndarray)), "Expected ret_image to be a PyTorch Tensor or a numpy array" assert isinstance(azim_image, (torch.Tensor, np.ndarray)), "Expected azim_image to be a PyTorch Tensor or a numpy array" - # assert isinstance(ret_image, torch.Tensor), "Expected ret_image to be a PyTorch Tensor" - # assert isinstance(azim_image, torch.Tensor), "Expected azim_image to be a PyTorch Tensor" assert isinstance(initial_vol, BirefringentVolume), "Expected initial_volume to be of type BirefringentVolume" assert isinstance(iteration_params, dict), "Expected iteration_params to be a dictionary" if loss_fcn: @@ -46,8 +48,6 @@ def __init__(self, optical_info, ret_image, azim_image, initial_vol, iteration_p self.optical_info = optical_info self.retardance_image = self._to_numpy(ret_image) self.azimuth_image = self._to_numpy(azim_image) - # self.retardance_image = ret_image.detach() - # self.azimuth_image = azim_image.detach() self.initial_volume = initial_vol self.interation_parameters = iteration_params self.loss_function = loss_fcn @@ -70,6 +70,7 @@ def save(self, parent_directory): # Save the retardance and azimuth images np.save(os.path.join(directory, 'ret_image.npy'), self.retardance_image) np.save(os.path.join(directory, 'azim_image.npy'), self.azimuth_image) + plt.ioff() my_fig = plot_retardance_orientation(self.retardance_image, self.azimuth_image, 'hsv', include_labels=True) my_fig.savefig(directory + '/ret_azim.png', bbox_inches='tight', dpi=300) plt.close(my_fig) @@ -113,13 +114,14 @@ def load(cls, parent_directory): class Reconstructor: + backend = BackEnds.PYTORCH + def __init__(self, recon_info: ReconstructionConfig, device='cpu'): """ Initialize the Reconstructor with the provided parameters. - iteration_params (class): containing reconstruction parameters + recon_info (class): containing reconstruction parameters """ - self.backend = BackEnds.PYTORCH self.optical_info = recon_info.optical_info self.ret_img_meas = recon_info.retardance_image self.azim_img_meas = recon_info.azimuth_image @@ -170,7 +172,7 @@ def to_device(self, device): def setup_raytracer(self, device='cpu'): """Initialize Birefringent Raytracer.""" print(f'For raytracing, using computing device {device}') - rays = BirefringentRaytraceLFM(backend=self.backend, optical_info=self.optical_info) + rays = BirefringentRaytraceLFM(backend=Reconstructor.backend, optical_info=self.optical_info) rays.to(device) # Move the rays to the specified device start_time = time.time() rays.compute_rays_geometry() @@ -198,20 +200,51 @@ def setup_initial_volume(self): return initial_volume def mask_outside_rays(self): - """Mask out volume that is outside FOV of the microscope""" - self.volume_pred.Delta_n.requires_grad = False - self.volume_pred.optic_axis.requires_grad = False + """ + Mask out volume that is outside FOV of the microscope. + Original shapes of the volume are preserved. + """ mask = self.rays.get_volume_reachable_region() - self.volume_pred.Delta_n[mask.view(-1)==0] = 0 - self.volume_pred.Delta_n.requires_grad = True - self.volume_pred.optic_axis.requires_grad = True + with torch.no_grad(): + self.volume_pred.Delta_n[mask.view(-1)==0] = 0 + # Masking the optic axis caused NaNs in the Jones Matrix. So, we don't mask it. + # self.volume_pred.optic_axis[:, mask.view(-1)==0] = 0 + + def crop_pred_volume_to_reachable_region(self): + """Crop the predicted volume to the region that is reachable by the microscope. + Note: This method modifies the volume_pred attribute. The voxel indices of the predetermined ray tracing are no longer valid. + """ + mask = self.rays.get_volume_reachable_region() + region_shape = get_region_of_ones_shape(mask).tolist() + original_shape = self.optical_info["volume_shape"] + self.optical_info["volume_shape"] = region_shape + self.volume_pred.optical_info["volume_shape"] = region_shape + birefringence = self.volume_pred.Delta_n + optic_axis = self.volume_pred.optic_axis + with torch.no_grad(): + cropped_birefringence = reshape_and_crop(birefringence, original_shape, region_shape) + self.volume_pred.Delta_n = store_as_pytorch_parameter(cropped_birefringence, 'scalar') + cropped_optic_axis = reshape_and_crop(optic_axis, [3, *original_shape], region_shape) + self.volume_pred.optic_axis = store_as_pytorch_parameter(cropped_optic_axis, 'vector') + + def restrict_volume_to_reachable_region(self): + """Restrict the volume to the region that is reachable by the microscope. + This includes cropping the volume are creating a new ray geometry + """ + self.crop_pred_volume_to_reachable_region() + self.rays = self.setup_raytracer() + + def _turn_off_initial_volume_gradients(self): + """Turn off the gradients for the initial volume guess.""" + self.volume_initial_guess.Delta_n.requires_grad = False + self.volume_initial_guess.optic_axis.requires_grad = False def specify_variables_to_learn(self, learning_vars=None): """ Specify which variables of the initial volume object should be considered for learning. This method updates the 'members_to_learn' attribute of the initial volume object, ensuring no duplicates are added. - Parameters: + Args: learning_vars (list): Variable names to be appended for learning. Defaults to ['Delta_n', 'optic_axis']. """ @@ -252,18 +285,55 @@ def compute_losses(self, ret_image_measured, azim_image_measured, ret_image_curr (axis_x[:, :, 1:] - axis_x[:, :, :-1]).pow(2).sum() ) # regularization_term = TV_reg + 1000 * (volume_estimation.Delta_n ** 2).mean() + TV_reg_axis_x / 100000 - regularization_term = training_params['regularization_weight'] * (TV_reg + 1000 * (volume_estimation.Delta_n ** 2).mean()) + regularization_term = training_params['regularization_weight'] * (0.5 * TV_reg + 1000 * (volume_estimation.Delta_n ** 2).mean()) # Total loss loss = data_term + regularization_term return loss, data_term, regularization_term + def _compute_loss(self, retardance_pred: torch.Tensor, azimuth_pred: torch.Tensor): + """ + Compute the loss for the current iteration after the forward model is applied. + + Note: If ep is a class attibrute, then the loss function can depend on the current epoch. + """ + vol_pred = self.volume_pred + params = self.iteration_params + retardance_meas = self.ret_img_meas + azimuth_meas = self.azim_img_meas + + loss_fcn_name = params.get('loss_fcn', 'L1_cos') + if not torch.is_tensor(retardance_meas): + retardance_meas = torch.tensor(retardance_meas) + if not torch.is_tensor(azimuth_meas): + azimuth_meas = torch.tensor(azimuth_meas) + # Vector difference GT + co_gt, ca_gt = retardance_meas * torch.cos(azimuth_meas), retardance_meas * torch.sin(azimuth_meas) + # Compute data term loss + co_pred, ca_pred = retardance_pred * torch.cos(azimuth_pred), retardance_pred * torch.sin(azimuth_pred) + data_term = ((co_gt - co_pred) ** 2 + (ca_gt - ca_pred) ** 2).mean() + + # Compute regularization term + delta_n = vol_pred.get_delta_n() + TV_reg = ( + (delta_n[1:, ...] - delta_n[:-1, ...]).pow(2).sum() + + (delta_n[:, 1:, ...] - delta_n[:, :-1, ...]).pow(2).sum() + + (delta_n[:, :, 1:] - delta_n[:, :, :-1]).pow(2).sum() + ) + # regularization_term = TV_reg + 1000 * (volume_estimation.Delta_n ** 2).mean() + TV_reg_axis_x / 100000 + regularization_term = params['regularization_weight'] * (0.5 * TV_reg + 1000 * (vol_pred.Delta_n ** 2).mean()) + + # Total loss + loss = data_term + regularization_term + + return loss, data_term, regularization_term + def one_iteration(self, optimizer, volume_estimation): optimizer.zero_grad() # Apply forward model [ret_image_current, azim_image_current] = self.rays.ray_trace_through_volume(volume_estimation) - loss, data_term, regularization_term = self.compute_losses(self.ret_img_meas, self.azim_img_meas, ret_image_current, azim_image_current, volume_estimation, self.iteration_params) + loss, data_term, regularization_term = self._compute_loss(ret_image_current, azim_image_current) loss.backward() optimizer.step() @@ -310,11 +380,10 @@ def reconstruct(self, output_dir=None): """ if output_dir is None: output_dir = create_unique_directory("reconstructions") - self.mask_outside_rays() + # self.restrict_volume_to_reachable_region() self.specify_variables_to_learn() # Turn off the gradients for the initial volume guess - self.volume_initial_guess.Delta_n.requires_grad = False - self.volume_initial_guess.optic_axis.requires_grad = False + self._turn_off_initial_volume_gradients() optimizer = self.optimizer_setup(self.volume_pred, self.iteration_params) figure = setup_visualization() # Iterations diff --git a/VolumeRaytraceLFM/utils/dimensions_utils.py b/VolumeRaytraceLFM/utils/dimensions_utils.py new file mode 100644 index 0000000..36a13ba --- /dev/null +++ b/VolumeRaytraceLFM/utils/dimensions_utils.py @@ -0,0 +1,85 @@ +import torch + +def get_region_of_ones_shape(mask): + """ + Computes the shape of the smallest bounding box that contains all the ones in the input mask. + Args: + mask (torch.Tensor): binary mask tensor. + Returns: + shape (torch.Tensor): shape of the smallest bounding box that contains all the ones in the input mask. + """ + indices = torch.nonzero(mask) + if indices.numel() == 0: + raise ValueError("Mask contains no ones.") + min_indices = indices.min(dim=0)[0] + max_indices = indices.max(dim=0)[0] + shape = max_indices - min_indices + 1 + return shape + +def crop_3d_tensor(tensor, new_shape): + """ + Crops a 3D tensor to a specified new shape, keeping the central part of the original tensor. + """ + D, H, W = tensor.shape + new_D, new_H, new_W = new_shape + start_D = (D - new_D) // 2 + end_D = start_D + new_D + start_H = (H - new_H) // 2 + end_H = start_H + new_H + start_W = (W - new_W) // 2 + end_W = start_W + new_W + return tensor[start_D:end_D, start_H:end_H, start_W:end_W] + +def reshape_crop_and_flatten_parameter(flattened_param, original_shape, new_shape): + # Reshape the flattened parameter + reshaped_param = flattened_param.view(original_shape) + + # Crop the tensor + *_, D, H, W = original_shape + new_D, new_H, new_W = new_shape + start_D = (D - new_D) // 2 + end_D = start_D + new_D + start_H = (H - new_H) // 2 + end_H = start_H + new_H + start_W = (W - new_W) // 2 + end_W = start_W + new_W + cropped_tensor = reshaped_param[..., start_D:end_D, start_H:end_H, start_W:end_W] + + # Flatten and convert back to a Parameter + cropped_flattened_parameter = torch.nn.Parameter(cropped_tensor.flatten()) + return cropped_flattened_parameter + +def reshape_and_crop(flattened_param, original_shape, new_shape): + """ + Reshapes a flattened tensor to its original shape and crops it to a new shape. + Args: + flattened_param (torch.Tensor): Flattened tensor to be reshaped and cropped. + original_shape (list): Original shape of the tensor before flattening. + new_shape (list): Desired shape of the cropped tensor. + Returns: + torch.Tensor: Cropped tensor with the desired shape. + """ + # Reshape the flattened parameter + reshaped_param = flattened_param.view(original_shape) + # Crop the tensor + *_, D, H, W = original_shape + new_D, new_H, new_W = new_shape + start_D = (D - new_D) // 2 + end_D = start_D + new_D + start_H = (H - new_H) // 2 + end_H = start_H + new_H + start_W = (W - new_W) // 2 + end_W = start_W + new_W + cropped_tensor = reshaped_param[..., start_D:end_D, start_H:end_H, start_W:end_W] + return cropped_tensor + +def store_as_pytorch_parameter(tensor, var_type: str): + ''' + Converts a tensor to a PyTorch parameter and flattens appropriately. + Note: possibly .type(torch.get_default_dtype()) is needed. + ''' + if var_type == 'scalar': + parameter = torch.nn.Parameter(tensor.flatten()) + elif var_type == 'vector': + parameter = torch.nn.Parameter(tensor.reshape(3, -1)) + return parameter diff --git a/VolumeRaytraceLFM/visualization/plotting_ret_azim.py b/VolumeRaytraceLFM/visualization/plotting_ret_azim.py index 547c4a4..1ccc656 100644 --- a/VolumeRaytraceLFM/visualization/plotting_ret_azim.py +++ b/VolumeRaytraceLFM/visualization/plotting_ret_azim.py @@ -115,9 +115,12 @@ def plot_hue_map(retardance_img, azimuth_img, ax=None): cb2.set_ticks([0, 0.5, 1]) # Assuming the data for azimuth ranges from 0 to 2π, we normalize this range to 0-1 for the colorbar. cb2.set_ticklabels(['0', round(retardance_img.max()/2, 1), round(retardance_img.max(), 1)]) axins2.set_title("Value", fontsize=8) - plt.show() + display_plot = False + if display_plot: + plt.show() def plot_retardance_orientation(ret_image, azim_image, azimuth_plot_type='hsv', include_labels=False): + plt.ioff() # Prevents plots from popping up fig = plt.figure(figsize=(12,2.5)) plt.rcParams['image.origin'] = 'lower' # Retardance subplot diff --git a/config_settings/iter_config.json b/config_settings/iter_config.json index 3257c69..e4ff734 100644 --- a/config_settings/iter_config.json +++ b/config_settings/iter_config.json @@ -3,5 +3,6 @@ "azimuth_weight": 0.5, "regularization_weight": 0.1, "lr": 1e-3, - "output_posfix": "" + "output_posfix": "", + "loss_function": "" } \ No newline at end of file diff --git a/config_settings/optical_config2.json b/config_settings/optical_config2.json index 5ca26db..68acfcc 100644 --- a/config_settings/optical_config2.json +++ b/config_settings/optical_config2.json @@ -1,5 +1,5 @@ { - "volume_shape" : [9, 41, 41], + "volume_shape" : [9, 35, 35], "axial_voxel_size_um" : 1.0, "cube_voxels" : true, "pixels_per_ml" : 17, diff --git a/run_recon.py b/run_recon.py index e3c429f..345f47f 100644 --- a/run_recon.py +++ b/run_recon.py @@ -58,7 +58,7 @@ def recon_gpu(): visualize_volume(reconstructor.volume_pred, reconstructor.optical_info) def main(): - optical_info = setup_optical_parameters("config_settings\optical_config3.json") + optical_info = setup_optical_parameters("config_settings\optical_config_largemla.json") optical_system = {'optical_info': optical_info} # Initialize the forward model. Raytracing is performed as part of the initialization. simulator = ForwardModel(optical_system, backend=BACKEND) @@ -66,15 +66,15 @@ def main(): volume_GT = BirefringentVolume( backend=BACKEND, optical_info=optical_info, - volume_creation_args=volume_args.ellipsoid_args2 + volume_creation_args=volume_args.sphere_args5 #ellipsoid_args2 #voxel_args ) - visualize_volume(volume_GT, optical_info) + # visualize_volume(volume_GT, optical_info) simulator.forward_model(volume_GT) # simulator.view_images() ret_image_meas = simulator.ret_img azim_image_meas = simulator.azim_img - recon_optical_info = optical_info + recon_optical_info = optical_info.copy() iteration_params = setup_iteration_parameters("config_settings\iter_config.json") initial_volume = BirefringentVolume( backend=BackEnds.PYTORCH, diff --git a/tests/test_dimensions_utils.py b/tests/test_dimensions_utils.py new file mode 100644 index 0000000..8b8d1a9 --- /dev/null +++ b/tests/test_dimensions_utils.py @@ -0,0 +1,51 @@ +import torch +import pytest +from VolumeRaytraceLFM.utils.dimensions_utils import ( + get_region_of_ones_shape, + crop_3d_tensor, + reshape_crop_and_flatten_parameter, + reshape_and_crop, + store_as_pytorch_parameter +) + +def test_get_region_of_ones_shape(): + # Test with a simple case + mask = torch.tensor([[0, 1], [1, 0]]) + expected_shape = torch.tensor([2, 2]) + assert torch.all(get_region_of_ones_shape(mask) == expected_shape) + + # Test with no ones in the mask + mask = torch.zeros((2, 2)) + with pytest.raises(ValueError): + get_region_of_ones_shape(mask) + +def test_crop_3d_tensor(): + tensor = torch.randn(4, 4, 4) + new_shape = (2, 2, 2) + cropped_tensor = crop_3d_tensor(tensor, new_shape) + assert cropped_tensor.shape == torch.Size(new_shape) + +def test_reshape_crop_and_flatten_parameter(): + flattened_param = torch.randn(4*4*4) + original_shape = (4, 4, 4) + new_shape = (2, 2, 2) + parameter = reshape_crop_and_flatten_parameter(flattened_param, original_shape, new_shape) + assert parameter.shape == torch.Size([8]) + +def test_reshape_and_crop(): + flattened_param = torch.randn(4*4*4) + original_shape = (4, 4, 4) + new_shape = (2, 2, 2) + tensor = reshape_and_crop(flattened_param, original_shape, new_shape) + assert tensor.shape == torch.Size(new_shape) + +def test_store_as_pytorch_parameter(): + tensor = torch.tensor([1.0, 2.0, 3.0]) + scalar_param = store_as_pytorch_parameter(tensor, 'scalar') + vector_param = store_as_pytorch_parameter(tensor, 'vector') + + assert isinstance(scalar_param, torch.nn.Parameter) + assert scalar_param.shape == torch.Size([3]) + + assert isinstance(vector_param, torch.nn.Parameter) + assert vector_param.shape == torch.Size([3, 1])