Commit fb49e22
refactoring of loss metrics, support for monitoring multiple metrics
pbenner committed Feb 11, 2025
1 parent b489776 commit fb49e22
Showing 8 changed files with 500 additions and 403 deletions.
10 changes: 8 additions & 2 deletions equitrain/argparser.py
@@ -266,9 +266,15 @@ def get_args_parser(script_type: str) -> argparse.ArgumentParser:
     )
     parser.add_argument(
         '--loss-type',
-        help='Type of loss function [l1 (default), smooth-l1, l2, huber]',
+        help='Type of loss function [mae, smooth-l1, mse, huber (default)]',
         type=str,
-        default='l1',
+        default='huber',
     )
+    parser.add_argument(
+        '--loss-monitor',
+        help='Comma separated list of loss types to monitor [default: mae,mse]',
+        type=str,
+        default='mae,mse',
+    )
     parser.add_argument(
         '--smooth-l1-beta',
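For orientation, here is a small self-contained sketch of how these options could be consumed downstream. The mini-parser is a hypothetical stand-in; only the flag names and defaults come from the diff above.

import argparse

# Hypothetical mini-parser mirroring the two options touched by this commit.
parser = argparse.ArgumentParser()
parser.add_argument('--loss-type', type=str, default='huber')
parser.add_argument('--loss-monitor', type=str, default='mae,mse')

args = parser.parse_args([])

# --loss-monitor is a comma-separated string; split it into the metric names
# that the new LossCollection in equitrain/loss.py (below) is presumably keyed by.
loss_monitor = [name.strip() for name in args.loss_monitor.split(',')]
print(args.loss_type, loss_monitor)  # huber ['mae', 'mse']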
231 changes: 18 additions & 213 deletions equitrain/loss.py
@@ -1,129 +1,5 @@
 import torch
 
 from equitrain.data.scatter import scatter_mean
 
 
-class GenericError(torch.nn.Module):
-    def __init__(
-        self,
-        loss_type: str = None,
-        smooth_l1_beta: float = None,
-        huber_delta: float = None,
-        **args,
-    ):
-        super().__init__()
-
-        loss_type = loss_type.lower()
-
-        if loss_type is None or loss_type == 'l1':
-            self.error_fn = lambda x, y: torch.nn.functional.l1_loss(
-                x, y, reduction='none'
-            )
-
-        elif loss_type == 'smooth-l1':
-            self.error_fn = lambda x, y: torch.nn.functional.smooth_l1_loss(
-                x, y, beta=smooth_l1_beta, reduction='none'
-            )
-
-        elif loss_type == 'l2' or loss_type == 'mse':
-            self.error_fn = lambda x, y: torch.nn.functional.mse_loss(
-                x, y, reduction='none'
-            )
-
-        elif loss_type == 'huber':
-            self.error_fn = lambda x, y: torch.nn.functional.huber_loss(
-                x, y, delta=huber_delta, reduction='none'
-            )
-
-        else:
-            raise ValueError(f'Invalid loss type: {loss_type}')
-
-    def forward(self, input, target):
-        return self.error_fn(input, target)
-
-
-class L1LossEnergy(torch.nn.Module):
-    def __init__(self, **args):
-        super().__init__()
-
-        self.error_fn = GenericError(**args)
-
-    def forward(self, input, target, weights):
-        error = self.error_fn(input, target)
-        error *= weights
-
-        loss = error.mean()
-        error = error.detach()
-
-        return loss, error
-
-
-class L1LossForces(torch.nn.Module):
-    def __init__(self, **args):
-        super().__init__()
-
-        self.error_fn = GenericError(**args)
-
-    def forward(self, input, target, batch):
-        error = self.error_fn(input, target)
-
-        loss = error.mean()
-
-        error = error.detach()
-        error = error.mean(dim=1)
-        error = scatter_mean(error, batch)
-
-        return loss, error
-
-
-class L1LossStress(torch.nn.Module):
-    def __init__(self, **args):
-        super().__init__()
-
-        self.error_fn = GenericError(**args)
-
-    def forward(self, input, target):
-        error = self.error_fn(input, target)
-
-        loss = error.mean()
-        error = error.detach()
-        error = error.mean(dim=(1, 2))
-
-        return loss, error
-
-
-class ForceAngleLoss(torch.nn.Module):
-    def __init__(self, angle_weight=1.0, epsilon=1e-8):
-        super().__init__()
-
-        self.angle_weight = angle_weight
-        self.epsilon = epsilon
-
-    def forward(self, input, target, weights=None):
-        # Compute lengths of force vectors
-        n1 = torch.norm(target, dim=1)
-        n2 = torch.norm(input, dim=1)
-        # Compute angle between force vectors
-        angle = self.compute_angle(target, input, n1=n1, n2=n2)
-
-        # Loss is the sum of normalized length mismatch and angle discrepancy
-        return torch.mean(
-            torch.abs(n1 - n2) / (0.5 * n1 + 0.5 * n2 + self.epsilon)
-            + self.angle_weight * angle
-        )
-
-    def compute_angle(self, s, t, n1=None, n2=None):
-        if n1 is None:
-            n1 = torch.norm(s, dim=1)
-        if n2 is None:
-            n2 = torch.norm(t, dim=1)
-        # Compute dot product between force vectors
-        dp = torch.einsum('ij,ij->i', s, t)
-        # Compute angle; the epsilon guards against division by zero
-        return torch.arccos(dp / (n1 * n2 + self.epsilon))
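The deleted GenericError above dispatched loss_type to one of four elementwise (reduction='none') torch losses, the same family now offered by --loss-type. A quick, equitrain-independent illustration of how these functions differ:

import torch
import torch.nn.functional as F

x = torch.tensor([0.0, 0.5, 3.0])
y = torch.zeros(3)

print(F.l1_loss(x, y, reduction='none'))    # |x - y|   -> [0.0, 0.5, 3.0]
print(F.mse_loss(x, y, reduction='none'))   # (x - y)^2 -> [0.0, 0.25, 9.0]
# Huber: quadratic for |x - y| <= delta, linear beyond -> [0.0, 0.125, 2.5]
print(F.huber_loss(x, y, delta=1.0, reduction='none'))
# Smooth L1 coincides with Huber when beta == delta == 1 -> [0.0, 0.125, 2.5]
print(F.smooth_l1_loss(x, y, beta=1.0, reduction='none'))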


 class LossComponent:
     def __init__(self, value: torch.Tensor = None, n: torch.Tensor = None, device=None):
@@ -157,7 +33,6 @@ def gather_for_metrics(self, accelerator):

         values = accelerator.gather_for_metrics(self.value.detach())
         ns = accelerator.gather_for_metrics(self.n.detach())
-        skip = (ns == 0.0).any().item()
 
         if len(values.shape) == 0:
             # Single processing context
@@ -170,7 +45,7 @@ def gather_for_metrics(self, accelerator):
         for i in range(len(values)):
             r += LossComponent(value=values[i], n=ns[i])
 
-        return r, skip
+        return r
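For intuition, this is what the gather step above accomplishes: every process contributes a (value, n) pair, and folding the pairs with += yields a global mean that weights each sample equally instead of averaging per-process means. Since LossComponent's body is folded out of this diff, the sketch below rests on the assumption that value accumulates an error sum and n a sample count.

import torch

# Hypothetical gathered tensors from two processes (cf. `values` and `ns` above).
values = torch.tensor([12.0, 10.0])  # assumed per-process error sums
ns = torch.tensor([4.0, 10.0])       # assumed per-process sample counts

# Folding pairwise, as in `r += LossComponent(value=values[i], n=ns[i])`:
print(values.sum() / ns.sum())  # global mean: 22/14 = 1.57...
print((values / ns).mean())     # mean of per-process means: (3.0 + 1.0)/2 = 2.0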


class Loss(dict):
@@ -199,102 +74,32 @@ def detach(self):

     def gather_for_metrics(self, accelerator):
         result = Loss(device=accelerator.device)
-        skip = {}
 
         for key, component in self.items():
-            result[key], skip[key] = component.gather_for_metrics(accelerator)
-
-        return result, skip
-
-
-class GenericLossFn(torch.nn.Module):
-    def __init__(
-        self,
-        energy_weight: float = 1.0,
-        forces_weight: float = 1.0,
-        stress_weight: float = 0.0,
-        # As opposed to forces, energy is predicted per material. By normalizing
-        # the energy by the number of atoms, forces and energy become comparable
-        loss_energy_per_atom: bool = True,
-        **args,
-    ):
-        super().__init__()
-
-        self.loss_energy = L1LossEnergy(**args)
-        self.loss_forces = L1LossForces(**args)
-        self.loss_stress = L1LossStress(**args)
-
-        self.energy_weight = energy_weight
-        self.forces_weight = forces_weight
-        self.stress_weight = stress_weight
-
-        self.loss_energy_per_atom = loss_energy_per_atom
-
-    def compute_weighted(self, energy_value, forces_value, stress_value):
-        result = 0.0
-        # handle initial values correctly when weights are zero, i.e. 0.0*Inf -> NaN
-        if energy_value is not None and (
-            not torch.isinf(energy_value).any() or self.energy_weight > 0.0
-        ):
-            result += self.energy_weight * energy_value
-        if forces_value is not None and (
-            not torch.isinf(forces_value).any() or self.forces_weight > 0.0
-        ):
-            result += self.forces_weight * forces_value
-        if stress_value is not None and (
-            not torch.isinf(stress_value).any() or self.stress_weight > 0.0
-        ):
-            result += self.stress_weight * stress_value
+            result[key] = component.gather_for_metrics(accelerator)
 
         return result
 
-    def forward(self, y_pred, y_true):
-        loss = Loss(device=y_true.batch.device)
-
-        energy_weights = None
-
-        if self.loss_energy_per_atom:
-            num_atoms = y_true.ptr[1:] - y_true.ptr[:-1]
-            energy_weights = 1.0 / num_atoms
-
-        e_true = y_true.y
-        f_true = y_true['force']
-        s_true = y_true['stress']
+
+class LossCollection(dict):
+    def __init__(self, loss_types, device=None):
+        self.main = Loss(device=device)
+        for loss_type in loss_types:
+            self[loss_type] = Loss(device=device)
 
-        e_pred = y_pred['energy']
-        f_pred = y_pred['forces']
-        s_pred = y_pred['stress']
+    def __iadd__(self, loss_collection: 'LossCollection'):
+        self.main += loss_collection.main
+        for loss_type, loss in loss_collection.items():
+            self[loss_type] += loss
 
-        loss_e = None
-        loss_f = None
-        loss_s = None
-
-        error_e = None
-        error_f = None
-        error_s = None
-
-        # Evaluate every loss component
-        if self.energy_weight > 0.0:
-            loss_e, error_e = self.loss_energy(e_pred, e_true, energy_weights)
-        if self.forces_weight > 0.0:
-            loss_f, error_f = self.loss_forces(f_pred, f_true, y_true.batch)
-        if self.stress_weight > 0.0:
-            loss_s, error_s = self.loss_stress(s_pred, s_true)
+        return self
 
-        # Move results to loss object
-        loss['total'].value += self.compute_weighted(loss_e, loss_f, loss_s)
-        loss['total'].n += y_true.batch.max() + 1
+    def gather_for_metrics(self, accelerator):
+        result = LossCollection(list(self.keys()), device=accelerator.device)
 
-        if self.energy_weight > 0.0:
-            loss['energy'].value = loss_e
-            loss['energy'].n += e_true.numel()
-        if self.forces_weight > 0.0:
-            loss['forces'].value += loss_f
-            loss['forces'].n += f_true.numel()
-        if self.stress_weight > 0.0:
-            loss['stress'].value += loss_s
-            loss['stress'].n += s_true.numel()
+        result.main = self.main.gather_for_metrics(accelerator)
 
-        error = self.compute_weighted(error_e, error_f, error_s)
+        for loss_type, loss in self.items():
+            result[loss_type] = loss.gather_for_metrics(accelerator)
 
-        return loss, error
+        return result
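Finally, a rough standalone sketch of the training-loop pattern the new LossCollection enables: one main loss drives optimization while additional metrics (here mae and mse, the new --loss-monitor defaults) are tracked alongside it. Component is a simplified stand-in for equitrain's LossComponent, not the actual API.

import torch
import torch.nn.functional as F

class Component:
    # Simplified stand-in: accumulates an error sum and a sample count.
    def __init__(self, value=0.0, n=0.0):
        self.value, self.n = torch.as_tensor(value), torch.as_tensor(n)

    def __iadd__(self, other):
        self.value = self.value + other.value
        self.n = self.n + other.n
        return self

    def mean(self):
        return (self.value / self.n).item()

main = Component()                                    # drives optimization (huber)
monitored = {'mae': Component(), 'mse': Component()}  # reported, not optimized

for pred, target in [(torch.tensor([1.0, 2.0]), torch.tensor([1.5, 1.0]))]:
    err = pred - target
    main += Component(F.huber_loss(pred, target, reduction='sum'), err.numel())
    monitored['mae'] += Component(err.abs().sum(), err.numel())
    monitored['mse'] += Component((err ** 2).sum(), err.numel())

print(main.mean(), monitored['mae'].mean(), monitored['mse'].mean())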
(Diff for the remaining 6 changed files not shown.)
