-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #65 from PolarizedLightFieldMicroscopy/loss_fcns
Loss function modularity beginning
- Loading branch information
Showing
12 changed files
with
444 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
'''Data-fidelity metrics''' | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
def von_mises_loss(angle_pred, angle_gt, kappa=1.0): | ||
'''Von Mises loss function for orientation''' | ||
diff = angle_pred - angle_gt | ||
loss = 1 - torch.exp(kappa * torch.cos(diff)) | ||
return loss.mean() | ||
|
||
def cosine_similarity_loss(vector_pred, vector_gt): | ||
'''Cosine similarity loss function for orientation''' | ||
cos_sim = F.cosine_similarity(vector_pred, vector_gt, dim=-1) | ||
loss = 1 - cos_sim | ||
return loss.mean() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import json | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from regularization import L1Regularization, L2Regularization | ||
|
||
REGULARIZATION_FNS = { | ||
'L1Regularization': L1Regularization, | ||
'L2Regularization': L2Regularization, | ||
# Add more functions here if needed | ||
} | ||
|
||
class PolarimetricLossFunction: | ||
def __init__(self, json_file=None): | ||
if json_file: | ||
with open(json_file, 'r') as f: | ||
params = json.load(f) | ||
self.weight_retardance = params.get('weight_retardance', 1.0) | ||
self.weight_orientation = params.get('weight_orientation', 1.0) | ||
self.weight_datafidelity = params.get('weight_datafidelity', 1.0) | ||
self.weight_regularization = params.get('weight_regularization', 0.1) | ||
# Initialize any specific loss functions you might need | ||
self.mse_loss = nn.MSELoss() | ||
# Initialize regularization functions | ||
self.regularization_fns = [(REGULARIZATION_FNS[fn_name], weight) for fn_name, weight in params.get('regularization_fns', [])] | ||
else: | ||
self.weight_retardance = 1.0 | ||
self.weight_orientation = 1.0 | ||
self.weight_datafidelity = 1.0 | ||
self.weight_regularization = 0.1 | ||
self.mse_loss = nn.MSELoss() | ||
self.regularization_fns = [] | ||
|
||
def set_retardance_target(self, target): | ||
self.target_retardance = target | ||
|
||
def set_orientation_target(self, target): | ||
self.target_orientation = target | ||
|
||
def compute_retardance_loss(self, prediction): | ||
# Add logic to transform data and compute retardance loss | ||
pass | ||
|
||
def compute_orientation_loss(self, prediction): | ||
# Add logic to transform data and compute orientation loss | ||
pass | ||
|
||
def transform_input_data(self, data): | ||
# Transform the input data into a vector form | ||
pass | ||
|
||
def compute_datafidelity_term(self, pred_retardance, pred_orientation): | ||
'''Incorporates the retardance and orientation losses''' | ||
retardance_loss = self.compute_retardance_loss(pred_retardance) | ||
orientation_loss = self.compute_orientation_loss(pred_orientation) | ||
data_loss = (self.weight_retardance * retardance_loss + | ||
self.weight_regularization * orientation_loss) | ||
return data_loss | ||
|
||
def compute_regularization_term(self, data): | ||
'''Compute regularization term''' | ||
regularization_loss = torch.tensor(0.) | ||
for reg_fn, weight in self.regularization_fns: | ||
regularization_loss += weight * reg_fn(data) | ||
return regularization_loss | ||
|
||
def compute_total_loss(self, pred_retardance, pred_orientation, data): | ||
# Compute individual losses | ||
datafidelity_loss = self.compute_datafidelity_term(pred_retardance, pred_orientation) | ||
regularization_loss = self.compute_regularization_term(data) | ||
|
||
# Compute total loss with weighted sum | ||
total_loss = (self.weight_datafidelity * datafidelity_loss + | ||
self.weight_regularization * regularization_loss) | ||
return total_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
'''Regularization metrics for a birefringent volume''' | ||
from VolumeRaytraceLFM.birefringence_implementations import BirefringentVolume | ||
from regularization_fundamentals import * | ||
|
||
def l2_bir(volume: BirefringentVolume): | ||
birefringence = volume.get_delta_n() | ||
return l2(birefringence) | ||
|
||
def total_variation_bir(volume: BirefringentVolume): | ||
birefringence = volume.get_delta_n() | ||
return total_variation_3d_volumetric(birefringence) | ||
|
||
class AnisotropyAnalysis: | ||
def __init__(self, volume: BirefringentVolume): | ||
self.volume = volume | ||
self.birefringence = volume.get_delta_n() | ||
self.optic_axis = volume.get_optic_axis() | ||
|
||
def l2_regularization(self): | ||
return l2(self.birefringence) | ||
|
||
def total_variation_regularization(self): | ||
return total_variation_3d_volumetric(self.birefringence) | ||
|
||
def process_optic_axis(self): | ||
# Example method to process optic axis data | ||
# Implement the specific logic you need for optic axis data | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
'''Regularization functions that can use used in the optimization process.''' | ||
import torch | ||
|
||
def l1(data, weight=1.0): | ||
return weight * torch.abs(data).mean() | ||
|
||
def l2(data, weight=1.0): | ||
return weight * torch.pow(data, 2).mean() | ||
|
||
def linfinity(data, weight=1.0): | ||
return weight * torch.max(torch.abs(data)) | ||
|
||
def elastic_net(data, weight1=1.0, weight2=1.0): | ||
l1_term = torch.abs(data).sum() | ||
l2_term = torch.pow(data, 2).sum() | ||
return weight1 * l1_term + weight2 * l2_term | ||
|
||
def total_variation_3d_volumetric(data, weight=1.0): | ||
""" | ||
Computes the Total Variation regularization for a 4D tensor representing volumetric data. | ||
Args: | ||
data (torch.Tensor): The input 3D tensor with shape [depth, height, width]. | ||
weight (float): Weighting factor for the regularization term. | ||
Returns: | ||
torch.Tensor: The computed Total Variation regularization term. | ||
""" | ||
# Calculate the differences between adjacent elements along each spatial dimension | ||
diff_depth = torch.pow(data[1:, :, :] - data[:-1, :, :], 2).sum() | ||
diff_height = torch.pow(data[:, 1:, :] - data[:, :-1, :], 2).sum() | ||
diff_width = torch.pow(data[:, :, 1:] - data[:, :, :-1], 2).sum() | ||
|
||
# Sum up the differences and apply the weight | ||
tv_reg = weight * (diff_depth + diff_height + diff_width) | ||
return tv_reg |
Oops, something went wrong.