Skip to content

Commit

Permalink
Merge pull request #93 from PolarizedLightFieldMicroscopy/guv-xylem
Browse files Browse the repository at this point in the history
MLA indexing and supersampling voxel size

Fixes #89
  • Loading branch information
gschlafly authored Mar 19, 2024
2 parents 615432d + 9fa8bd3 commit 3e2275d
Show file tree
Hide file tree
Showing 9 changed files with 309 additions and 39 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pytest-action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ jobs:
matrix:
python-version: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
4 changes: 2 additions & 2 deletions VolumeRaytraceLFM/abstract_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ def __init__(self, backend : BackEnds = BackEnds.NUMPY, torch_args={},
optical_info['voxel_size_um'] = (
[optical_info['axial_voxel_size_um'],]
+ 2*[optical_info['pixels_per_ml'] * optical_info['camera_pix_pitch']
/ optical_info['M_obj']]
/ optical_info['M_obj'] / optical_info['n_voxels_per_ml']]
)
else:
# Option to make voxel size uniform
optical_info['voxel_size_um'] = (
3*[optical_info['pixels_per_ml'] * optical_info['camera_pix_pitch']
/ optical_info['M_obj']]
/ optical_info['M_obj'] / optical_info['n_voxels_per_ml']]
)
# Check if back-end is torch and overwrite self with an optic block, for Waveblocks
# compatibility.
Expand Down
8 changes: 4 additions & 4 deletions VolumeRaytraceLFM/birefringence_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,7 @@ def store_shifted_vox_indices(self):
current_offset = self._calculate_current_offset(
ml_ii, ml_jj, n_voxels_per_ml, n_micro_lenses
)
mla_index = (ml_ii_idx, ml_jj_idx)
mla_index = (ml_jj_idx, ml_ii_idx)
vox_list = self._gather_voxels_of_rays_pytorch(
current_offset, collision_indices
)
Expand Down Expand Up @@ -1131,12 +1131,12 @@ def ray_trace_through_volume(self, volume_in : BirefringentVolume = None,
# Generate (intensity or ret/azim) images for the current microlens,
# by passing an offset to this function
# depending on the microlens and the super resolution
current_mla_index = (ml_ii_idx, ml_jj_idx)
current_mla_index = (ml_jj_idx, ml_ii_idx)
start_time = time.time()
img_list = self.generate_images(volume_in, current_offset,
intensity, mla_index=current_mla_index)
execution_time = time.time() - start_time
mla_index = (ml_ii_idx, ml_jj_idx)
mla_index = (ml_jj_idx, ml_ii_idx)
if mla_index not in self.mla_execution_times:
self.mla_execution_times[mla_index] = 0
self.mla_execution_times[mla_index] += execution_time
Expand Down Expand Up @@ -1567,7 +1567,7 @@ def _count_vox_raytrace_occurrences(self, zero_retardance_voxels=False):
count = Counter()
for ml_ii_idx in range(n_micro_lenses):
for ml_jj_idx in range(n_micro_lenses):
mla_index = (ml_ii_idx, ml_jj_idx)
mla_index = (ml_jj_idx, ml_ii_idx)
vox_indices = self.vox_indices_by_mla_idx[mla_index]

if zero_retardance_voxels:
Expand Down
23 changes: 17 additions & 6 deletions VolumeRaytraceLFM/reconstructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,9 @@ def _compute_loss(self, retardance_pred: torch.Tensor, azimuth_pred: torch.Tenso
if not torch.is_tensor(azimuth_meas):
azimuth_meas = torch.tensor(azimuth_meas)
# Vector difference GT
co_gt, ca_gt = retardance_meas * torch.cos(azimuth_meas), retardance_meas * torch.sin(azimuth_meas)
co_gt, ca_gt = retardance_meas * torch.cos(2*azimuth_meas), retardance_meas * torch.sin(2*azimuth_meas)
# Compute data term loss
co_pred, ca_pred = retardance_pred * torch.cos(azimuth_pred), retardance_pred * torch.sin(azimuth_pred)
co_pred, ca_pred = retardance_pred * torch.cos(2*azimuth_pred), retardance_pred * torch.sin(2*azimuth_pred)
data_term = ((co_gt - co_pred) ** 2 + (ca_gt - ca_pred) ** 2).mean()

# Compute regularization term
Expand All @@ -478,6 +478,16 @@ def _compute_loss(self, retardance_pred: torch.Tensor, azimuth_pred: torch.Tenso
(delta_n[:, :, 1:] - delta_n[:, :, :-1]).pow(2).sum()
)

# Try a scaled TV regularization
# avg_scale_1 = torch.abs((delta_n[1:, ...] + delta_n[:-1, ...])) / 2.0
# avg_scale_2 = torch.abs((delta_n[:, 1:, ...] + delta_n[:, :-1, ...])) / 2.0
# avg_scale_3 = torch.abs((delta_n[:, :, 1:] + delta_n[:, :, :-1])) / 2.0
# TV_reg_scaled = (
# ((delta_n[1:, ...] - delta_n[:-1, ...]) * avg_scale_1).pow(2).sum() +
# ((delta_n[:, 1:, ...] - delta_n[:, :-1, ...]) * avg_scale_2).pow(2).sum() +
# ((delta_n[:, :, 1:] - delta_n[:, :, :-1]) * avg_scale_3).pow(2).sum()
# )

cos_sim_loss = weighted_local_cosine_similarity_loss(
vol_pred.get_optic_axis(), vol_pred.get_delta_n()
)
Expand All @@ -487,6 +497,7 @@ def _compute_loss(self, retardance_pred: torch.Tensor, azimuth_pred: torch.Tenso
# regularization_term = TV_reg + 1000 * (volume_estimation.Delta_n ** 2).mean() + TV_reg_axis_x / 100000

TV_term = params['TV_weight'] * TV_reg
# TV_term = params['TV_scaled_weight'] * TV_reg_scaled
L1_norm_term = params['L1_norm_weight'] * (vol_pred.Delta_n ** 2).mean()
cos_sim_term = params['cos_sim_weight'] * cos_sim_loss
regularization_term = params['regularization_weight'] * (TV_term + L1_norm_term + cos_sim_term)
Expand Down Expand Up @@ -733,11 +744,11 @@ def reconstruct(self, output_dir=None, use_streamlit=False):
self.save_loss_lists_to_csv()
my_description = "Volume estimation after " + \
str(ep) + " iterations."
vol_save_path = os.path.join(output_dir, f"volume_ep_{'{:03d}'.format(ep)}.h5")
self.volume_pred.save_as_file(
os.path.join(
output_dir, f"volume_ep_{'{:03d}'.format(ep)}.h5"),
description=my_description
vol_save_path, description=my_description
)
print("Saved the final volume estimation to", vol_save_path)
# Final visualizations after training completes
plt.savefig(os.path.join(output_dir, "optim_final.pdf"))
plt.show()
plt.show()
7 changes: 5 additions & 2 deletions VolumeRaytraceLFM/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@ def save_as_tif(file_path, data, metadata):
return


def create_unique_directory(base_output_dir):
def create_unique_directory(base_output_dir, postfix=None):
# Get the current date and time
now = datetime.now()
# Format the date and time as a string in the desired format, e.g., 'YYYY-MM-DD_HH-MM-SS'
dir_name = now.strftime("%Y-%m-%d_%H-%M-%S")
if postfix is not None:
dir_name = now.strftime("%Y-%m-%d_%H-%M-%S") + '_' + postfix
else:
dir_name = now.strftime("%Y-%m-%d_%H-%M-%S")
unique_output_dir = os.path.join(base_output_dir, dir_name)
os.makedirs(unique_output_dir, exist_ok=True)
print(f"Created the unique output directory {unique_output_dir}")
Expand Down
10 changes: 10 additions & 0 deletions VolumeRaytraceLFM/volumes/volume_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@
}
}

sphere_args6_thick_ss3 = {
'init_mode': 'ellipsoid',
'init_args': {
'radius': [19.5, 19.5, 19.5],
'center': [0.5, 0.5, 0.5],
'delta_n': 0.01,
'border_thickness': 6
}
}

sphere_shifted = {
'init_mode': 'ellipsoid',
'init_args': {
Expand Down
2 changes: 1 addition & 1 deletion recon_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def recon_sphere6_thick():
optical_info=recon_optical_info,
volume_creation_args=volume_args.random_args
)
recon_directory = create_unique_directory("reconstructions")
recon_directory = create_unique_directory("reconstructions", postfix='sphere6_thick2')
if not simulate:
volume_GT = initial_volume
recon_config = ReconstructionConfig(recon_optical_info, ret_image_meas,
Expand Down
Loading

0 comments on commit 3e2275d

Please sign in to comment.