Merge pull request #65 from PolarizedLightFieldMicroscopy/loss_fcns
Loss function modularity beginning
gschlafly authored Nov 27, 2023
2 parents 289e5bc + 557eec1 commit 1850e0a
Showing 12 changed files with 444 additions and 38 deletions.
67 changes: 56 additions & 11 deletions VolumeRaytraceLFM/birefringence_implementations.py
@@ -909,13 +909,19 @@ def get_volume_reachable_region(self):
n_voxels_per_ml = self.optical_info['n_voxels_per_ml']
n_ml_half = floor(n_micro_lenses * n_voxels_per_ml / 2.0)
mask = torch.zeros(self.optical_info['volume_shape'])
mask[:,
self.vox_ctr_idx[1]-n_ml_half+1 : self.vox_ctr_idx[1]+n_ml_half,
self.vox_ctr_idx[2]-n_ml_half+1 : self.vox_ctr_idx[2]+n_ml_half] = 1.0
include_ray_angle_reach = True
if include_ray_angle_reach:
vox_span_half = int(self.voxel_span_per_ml + (n_micro_lenses * n_voxels_per_ml) / 2)
mask[:,
self.vox_ctr_idx[1]-vox_span_half+1 : self.vox_ctr_idx[1]+vox_span_half,
self.vox_ctr_idx[2]-vox_span_half+1 : self.vox_ctr_idx[2]+vox_span_half] = 1.0
else:
mask[:,
self.vox_ctr_idx[1]-n_ml_half+1 : self.vox_ctr_idx[1]+n_ml_half,
self.vox_ctr_idx[2]-n_ml_half+1 : self.vox_ctr_idx[2]+n_ml_half] = 1.0
# mask_volume = BirefringentVolume(backend=self.backend,
# optical_info=self.optical_info, Delta_n=0.01, optic_axis=[0.5,0.5,0])
# [r,a] = self.ray_trace_through_volume(mask_volume)

# # Check gradients to see what is affected
# L = r.mean() + a.mean()
# L.backward()
@@ -942,7 +948,7 @@ def precompute_MLA_volume_geometry(self):
# border_size_around_mla = np.ceil((volume_shape[1]-(n_micro_lenses*n_voxels_per_ml)) / 2)
min_needed_volume_size = int(self.voxel_span_per_ml + (n_micro_lenses*n_voxels_per_ml))
assert min_needed_volume_size <= volume_shape[1] and min_needed_volume_size <= volume_shape[2], "The volume in front of the microlenses" + \
f"({n_micro_lenses},{n_micro_lenses}) is to large for a volume_shape: {self.optical_info['volume_shape'][1:]}. " + \
f"({n_micro_lenses},{n_micro_lenses}) is too large for a volume_shape: {self.optical_info['volume_shape'][1:]}. " + \
f"Increase the volume_shape to at least [{min_needed_volume_size+1},{min_needed_volume_size+1}]"

odd_mla_shift = np.mod(n_micro_lenses,2)
@@ -956,12 +962,14 @@ def precompute_MLA_volume_geometry(self):
np.array([n_voxels_per_ml * ml_ii, n_voxels_per_ml*ml_jj])
+ np.array(self.vox_ctr_idx[1:]) - n_voxels_per_ml_half
)
self.vox_indices_ml_shifted_all += [[RayTraceLFM.ravel_index((vox[ix][0],
vox[ix][1]+current_offset[0],
vox[ix][2]+current_offset[1]),
self.optical_info['volume_shape']) for ix in range(len(vox))]
for vox in self.ray_vol_colli_indices
]
self.vox_indices_ml_shifted_all += [
[
RayTraceLFM.ravel_index(
(vox[ix][0], vox[ix][1]+current_offset[0], vox[ix][2]+current_offset[1]),
self.optical_info['volume_shape']) for ix in range(len(vox))
]
for vox in self.ray_vol_colli_indices
]
# Shift ray-pixel indices
if self.ray_valid_indices_all is None:
self.ray_valid_indices_all = self.ray_valid_indices.clone()
@@ -1408,6 +1416,43 @@ def voxRayJM(self, Delta_n, opticAxis, rayDir, ell, wavelength):
assert not torch.isnan(JM).any(), "A Jones matrix contains NaN values."
return JM

def vox_ray_ret_azim(self, Delta_n, opticAxis, rayDir, ell, wavelength):
'''Calculate the effective retardance and azimuth of a ray passing through a voxel'''
if self.backend == BackEnds.NUMPY:
# Azimuth is the angle of the slow axis of retardance.
azim = np.arctan2(np.dot(opticAxis, rayDir[1]), np.dot(opticAxis, rayDir[2]))
if Delta_n == 0:
azim = 0
elif Delta_n < 0:
azim = azim + np.pi / 2
# print(f"Azimuth angle of index ellipsoid is
# {np.around(np.rad2deg(azim), decimals=0)} degrees.")
normAxis = np.linalg.norm(opticAxis)
proj_along_ray = np.dot(opticAxis, rayDir[0])
# np.divide(my_arr, my_arr1, out=np.ones_like(my_arr, dtype=np.float32), where=my_arr1 != 0)
ret = abs(Delta_n) * (1 - proj_along_ray ** 2) * 2 * np.pi * ell / wavelength
else:
raise NotImplementedError("Not implemented for pytorch yet.")
return ret, azim

def vox_ray_matrix(self, ret, azim):
    '''Calculate the Jones matrix associated with a particular ray and voxel combination'''
    if self.backend == BackEnds.NUMPY:
        JM = JonesMatrixGenerators.linear_retarder(ret, azim)
    else:
        # Batched construction of linear-retarder Jones matrices for the PyTorch backend
        offdiag = 1j * torch.sin(azim) * torch.sin(ret)
        diag1 = torch.cos(ret) + 1j * torch.cos(azim) * torch.sin(ret)
        diag2 = torch.conj(diag1)
        # Construct Jones Matrix
        JM = torch.zeros([ret.shape[0], 2, 2], dtype=torch.complex64, device=ret.device)
        JM[:,0,0] = diag1
        JM[:,0,1] = offdiag
        JM[:,1,0] = offdiag
        JM[:,1,1] = diag2
    return JM

def clone(self):
# Code to create a copy of this instance
new_instance = BirefringentVolume(...)
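As a standalone illustration of the quantities that the new vox_ray_ret_azim method computes, the sketch below evaluates the same retardance and azimuth expressions for a single voxel outside the class. All inputs (optic axis, ray basis, path length, wavelength) are made-up example values, not taken from this commit:

import numpy as np

# Hypothetical example inputs
Delta_n = 0.01                                  # voxel birefringence
opticAxis = np.array([0.0, 1.0, 0.0])           # unit optic axis
rayDir = [np.array([0.0, 0.0, 1.0]),            # ray propagation direction
          np.array([1.0, 0.0, 0.0]),            # first transverse basis vector
          np.array([0.0, 1.0, 0.0])]            # second transverse basis vector
ell = 1.0                                       # path length through the voxel
wavelength = 0.55                               # same length units as ell

# Azimuth: orientation of the optic-axis projection in the transverse plane
azim = np.arctan2(np.dot(opticAxis, rayDir[1]), np.dot(opticAxis, rayDir[2]))
if Delta_n == 0:
    azim = 0
elif Delta_n < 0:
    azim += np.pi / 2

# Retardance: vanishes when the optic axis is parallel to the ray direction
proj_along_ray = np.dot(opticAxis, rayDir[0])
ret = abs(Delta_n) * (1 - proj_along_ray ** 2) * 2 * np.pi * ell / wavelength
print(ret, azim)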
15 changes: 15 additions & 0 deletions VolumeRaytraceLFM/metrics/data_fidelity.py
@@ -0,0 +1,15 @@
'''Data-fidelity metrics'''
import torch
import torch.nn.functional as F

def von_mises_loss(angle_pred, angle_gt, kappa=1.0):
'''Von Mises loss function for orientation'''
diff = angle_pred - angle_gt
loss = 1 - torch.exp(kappa * torch.cos(diff))
return loss.mean()

def cosine_similarity_loss(vector_pred, vector_gt):
'''Cosine similarity loss function for orientation'''
cos_sim = F.cosine_similarity(vector_pred, vector_gt, dim=-1)
loss = 1 - cos_sim
return loss.mean()
75 changes: 75 additions & 0 deletions VolumeRaytraceLFM/metrics/metric.py
@@ -0,0 +1,75 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from regularization import L1Regularization, L2Regularization

REGULARIZATION_FNS = {
'L1Regularization': L1Regularization,
'L2Regularization': L2Regularization,
# Add more functions here if needed
}

class PolarimetricLossFunction:
def __init__(self, json_file=None):
if json_file:
with open(json_file, 'r') as f:
params = json.load(f)
self.weight_retardance = params.get('weight_retardance', 1.0)
self.weight_orientation = params.get('weight_orientation', 1.0)
self.weight_datafidelity = params.get('weight_datafidelity', 1.0)
self.weight_regularization = params.get('weight_regularization', 0.1)
# Initialize any specific loss functions you might need
self.mse_loss = nn.MSELoss()
# Initialize regularization functions
self.regularization_fns = [(REGULARIZATION_FNS[fn_name], weight) for fn_name, weight in params.get('regularization_fns', [])]
else:
self.weight_retardance = 1.0
self.weight_orientation = 1.0
self.weight_datafidelity = 1.0
self.weight_regularization = 0.1
self.mse_loss = nn.MSELoss()
self.regularization_fns = []

def set_retardance_target(self, target):
self.target_retardance = target

def set_orientation_target(self, target):
self.target_orientation = target

def compute_retardance_loss(self, prediction):
# Add logic to transform data and compute retardance loss
pass

def compute_orientation_loss(self, prediction):
# Add logic to transform data and compute orientation loss
pass

def transform_input_data(self, data):
# Transform the input data into a vector form
pass

def compute_datafidelity_term(self, pred_retardance, pred_orientation):
'''Incorporates the retardance and orientation losses'''
retardance_loss = self.compute_retardance_loss(pred_retardance)
orientation_loss = self.compute_orientation_loss(pred_orientation)
data_loss = (self.weight_retardance * retardance_loss +
self.weight_orientation * orientation_loss)
return data_loss

def compute_regularization_term(self, data):
'''Compute regularization term'''
regularization_loss = torch.tensor(0.)
for reg_fn, weight in self.regularization_fns:
regularization_loss += weight * reg_fn(data)
return regularization_loss

def compute_total_loss(self, pred_retardance, pred_orientation, data):
# Compute individual losses
datafidelity_loss = self.compute_datafidelity_term(pred_retardance, pred_orientation)
regularization_loss = self.compute_regularization_term(data)

# Compute total loss with weighted sum
total_loss = (self.weight_datafidelity * datafidelity_loss +
self.weight_regularization * regularization_loss)
return total_loss
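A hedged sketch of how the JSON-configured path of this class appears intended to be used. The key names mirror the params.get(...) calls above, while the weights, file name, and regularization entries are illustrative assumptions (and it presumes metric.py's own "from regularization import ..." resolves on the path):

import json
from VolumeRaytraceLFM.metrics.metric import PolarimetricLossFunction

config = {
    "weight_retardance": 1.0,
    "weight_orientation": 0.5,
    "weight_datafidelity": 1.0,
    "weight_regularization": 0.1,
    # Each entry pairs a key of REGULARIZATION_FNS with its weight
    "regularization_fns": [["L1Regularization", 0.01], ["L2Regularization", 0.1]],
}
with open("loss_params.json", "w") as f:   # hypothetical config file name
    json.dump(config, f)

loss_fn = PolarimetricLossFunction(json_file="loss_params.json")
# compute_retardance_loss and compute_orientation_loss are still stubs in this commit,
# so compute_total_loss is not yet runnable end to end; only the weights and the
# regularization function list are populated here.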
28 changes: 28 additions & 0 deletions VolumeRaytraceLFM/metrics/regularization.py
@@ -0,0 +1,28 @@
'''Regularization metrics for a birefringent volume'''
from VolumeRaytraceLFM.birefringence_implementations import BirefringentVolume
from regularization_fundamentals import *

def l2_bir(volume: BirefringentVolume):
birefringence = volume.get_delta_n()
return l2(birefringence)

def total_variation_bir(volume: BirefringentVolume):
birefringence = volume.get_delta_n()
return total_variation_3d_volumetric(birefringence)

class AnisotropyAnalysis:
def __init__(self, volume: BirefringentVolume):
self.volume = volume
self.birefringence = volume.get_delta_n()
self.optic_axis = volume.get_optic_axis()

def l2_regularization(self):
return l2(self.birefringence)

def total_variation_regularization(self):
return total_variation_3d_volumetric(self.birefringence)

def process_optic_axis(self):
# Example method to process optic axis data
# Implement the specific logic you need for optic axis data
pass
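A brief usage sketch, assuming volume is an already-constructed BirefringentVolume instance (construction arguments omitted here):

# volume: an existing BirefringentVolume with birefringence and optic-axis data
reg_l2 = l2_bir(volume)                  # L2 penalty on the birefringence (Delta_n) array
reg_tv = total_variation_bir(volume)     # total-variation penalty on the same array

analysis = AnisotropyAnalysis(volume)
print(analysis.l2_regularization(), analysis.total_variation_regularization())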
34 changes: 34 additions & 0 deletions VolumeRaytraceLFM/metrics/regularization_fundamentals.py
@@ -0,0 +1,34 @@
'''Regularization functions that can be used in the optimization process.'''
import torch

def l1(data, weight=1.0):
return weight * torch.abs(data).mean()

def l2(data, weight=1.0):
return weight * torch.pow(data, 2).mean()

def linfinity(data, weight=1.0):
return weight * torch.max(torch.abs(data))

def elastic_net(data, weight1=1.0, weight2=1.0):
l1_term = torch.abs(data).sum()
l2_term = torch.pow(data, 2).sum()
return weight1 * l1_term + weight2 * l2_term

def total_variation_3d_volumetric(data, weight=1.0):
"""
Computes the Total Variation regularization for a 3D tensor representing volumetric data.
Args:
data (torch.Tensor): The input 3D tensor with shape [depth, height, width].
weight (float): Weighting factor for the regularization term.
Returns:
torch.Tensor: The computed Total Variation regularization term.
"""
# Calculate the differences between adjacent elements along each spatial dimension
diff_depth = torch.pow(data[1:, :, :] - data[:-1, :, :], 2).sum()
diff_height = torch.pow(data[:, 1:, :] - data[:, :-1, :], 2).sum()
diff_width = torch.pow(data[:, :, 1:] - data[:, :, :-1], 2).sum()

# Sum up the differences and apply the weight
tv_reg = weight * (diff_depth + diff_height + diff_width)
return tv_reg
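For a quick sanity check of these penalties on synthetic data (example tensors only; the import path mirrors the new file's location):

import torch
from VolumeRaytraceLFM.metrics.regularization_fundamentals import (
    l1, l2, total_variation_3d_volumetric)

data = torch.rand(8, 16, 16)       # [depth, height, width] random volume
constant = torch.ones(8, 16, 16)   # spatially constant volume

print(l1(data), l2(data))                        # mean |x| and mean x**2
print(total_variation_3d_volumetric(data))       # positive for spatially varying data
print(total_variation_3d_volumetric(constant))   # exactly 0: no differences between neighbors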
