
Commit

Refactor parameter validation in _validate_parameters and remove unused dtype and device parameters from Calculator
E-Rum committed Feb 4, 2025
1 parent a680e9b commit c9111df
Showing 7 changed files with 39 additions and 98 deletions.
36 changes: 13 additions & 23 deletions src/torchpme/_utils.py
@@ -10,20 +10,10 @@ def _validate_parameters(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
dtype: torch.dtype,
device: torch.device,
) -> None:
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as the class "
f"type ({dtype})"
)

if positions.device != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as the class "
f"device ({device})"
)

dtype = positions.dtype
device = positions.device

# check shape, dtype and device of positions
num_atoms = len(positions)
@@ -40,14 +30,14 @@
f"{list(cell.shape)}"
)

if cell.dtype != positions.dtype:
if cell.dtype != dtype:
raise TypeError(
f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
f"type of `cell` ({cell.dtype}) must be the same as that of `positions` ({dtype})"
)

if cell.device != device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as the class ({device})"
f"device of `cell` ({cell.device}) must be the same as that of `positions` ({device})"
)

if smearing is not None and torch.equal(
@@ -74,14 +64,14 @@
f"{len(positions)} atoms"
)

if charges.dtype != positions.dtype:
if charges.dtype != dtype:
raise TypeError(
f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
f"type of `charges` ({charges.dtype}) must be the same as that of `positions` ({dtype})"
)

if charges.device != device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as the class "
f"device of `charges` ({charges.device}) must be the same as that of `positions` "
f"({device})"
)

@@ -96,7 +86,7 @@
if neighbor_indices.device != device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as the class ({device})"
f"same as that of `positions` ({device})"
)

if neighbor_distances.shape != neighbor_indices[:, 0].shape:
@@ -109,11 +99,11 @@
if neighbor_distances.device != device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as the class ({device})"
f"same as that of `positions` ({device})"
)

if neighbor_distances.dtype != positions.dtype:
if neighbor_distances.dtype != dtype:
raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
f"as the class ({dtype})"
f"as that of `positions` ({dtype})"
)
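After this refactor, _validate_parameters takes its reference dtype and device from positions and checks every other tensor against them. A minimal sketch of that pattern (simplified; not the library's exact code):

```python
import torch

def check_against_positions(positions: torch.Tensor, cell: torch.Tensor) -> None:
    # the reference dtype/device now come from `positions`, not from the Calculator
    dtype = positions.dtype
    device = positions.device
    if cell.dtype != dtype:
        raise TypeError(f"type of `cell` ({cell.dtype}) must match `positions` ({dtype})")
    if cell.device != device:
        raise ValueError(f"device of `cell` ({cell.device}) must match `positions` ({device})")

positions = torch.zeros((4, 3), dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float32)
# check_against_positions(positions, cell)  # would raise TypeError: float32 vs float64
```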
22 changes: 1 addition & 21 deletions src/torchpme/calculators/calculator.py
@@ -3,7 +3,7 @@
import torch
from torch import profiler

from .._utils import _get_device, _get_dtype, _validate_parameters
from .._utils import _validate_parameters
from ..potentials import Potential


@@ -36,8 +36,6 @@ def __init__(
potential: Potential,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__()

@@ -46,24 +44,8 @@ def __init__(
f"Potential must be an instance of Potential, got {type(potential)}"
)

self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

if self.dtype != potential.dtype:
raise TypeError(
f"dtype of `potential` ({potential.dtype}) must be same as of "
f"`calculator` ({self.dtype})"
)

if self.device != potential.device:
raise ValueError(
f"device of `potential` ({potential.device}) must be same as of "
f"`calculator` ({self.device})"
)

self.potential = potential
self.full_neighbor_list = full_neighbor_list

self.prefactor = prefactor

def _compute_rspace(
Expand Down Expand Up @@ -164,8 +146,6 @@ def forward(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
dtype=self.dtype,
device=self.device,
)

# Compute short-range (SR) part using a real space sum
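With dtype and device dropped from the constructor, a calculator is built from a potential alone and the per-call tensors decide precision and placement. A hedged construction sketch (the top-level names Calculator and CoulombPotential are assumed here; the exact public API may differ):

```python
import torch
from torchpme import Calculator, CoulombPotential  # assumed public names

pot = CoulombPotential(smearing=1.0)
calc = Calculator(potential=pot)  # previously also accepted dtype=... and device=...

# dtype/device are now read from the tensors handed to forward();
# _validate_parameters only requires that positions, cell, charges and the
# neighbor lists agree with each other.
positions = torch.rand((8, 3), dtype=torch.float64)
```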
18 changes: 9 additions & 9 deletions src/torchpme/lib/splines.py
@@ -127,8 +127,8 @@ def _solve_tridiagonal(a, b, c, d):
"""
n = len(d)
# Create copies to avoid modifying the original arrays
c_prime = torch.zeros(n)
d_prime = torch.zeros(n)
c_prime = torch.zeros_like(d)
d_prime = torch.zeros_like(d)

# Initial coefficients
c_prime[0] = c[0] / b[0]
@@ -141,7 +141,7 @@ def _solve_tridiagonal(a, b, c, d):
d_prime[i] = (d[i] - a[i] * d_prime[i - 1]) / denom

# Backward substitution
x = torch.zeros(n)
x = torch.zeros_like(d)
x[-1] = d_prime[-1]
for i in reversed(range(n - 1)):
x[i] = d_prime[i] - c_prime[i] * x[i + 1]
@@ -174,13 +174,13 @@ def compute_second_derivatives(
dy = (y[1:] - y[:-1]) / intervals

# Create zero boundary conditions (natural spline)
d2y = torch.zeros_like(x, dtype=torch.float64)
d2y = torch.zeros_like(x)

n = len(x)
a = torch.zeros(n) # Sub-diagonal (a[1..n-1])
b = torch.zeros(n) # Main diagonal (b[0..n-1])
c = torch.zeros(n) # Super-diagonal (c[0..n-2])
d = torch.zeros(n) # Right-hand side (d[0..n-1])
a = torch.zeros_like(x) # Sub-diagonal (a[1..n-1])
b = torch.zeros_like(x) # Main diagonal (b[0..n-1])
c = torch.zeros_like(x) # Super-diagonal (c[0..n-2])
d = torch.zeros_like(x) # Right-hand side (d[0..n-1])

# Natural spline boundary conditions
b[0] = 1
@@ -198,7 +198,7 @@
d2y = _solve_tridiagonal(a, b, c, d)

# Converts back to the original dtype
return d2y.to(dtype=x_points.dtype, device=x_points.device)
return d2y


def compute_spline_ft(
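The spline edits swap torch.zeros(n) for torch.zeros_like(...), so the tridiagonal work arrays inherit the dtype and device of the input instead of defaulting to float32 on the CPU, and the trailing .to(...) cast becomes unnecessary. A small illustration of the difference (not library code):

```python
import torch

d = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
print(torch.zeros(len(d)).dtype)  # torch.float32 -- default dtype, silently downcasts
print(torch.zeros_like(d).dtype)  # torch.float64 -- follows the input tensor
```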
16 changes: 2 additions & 14 deletions src/torchpme/potentials/coulomb.py
@@ -37,18 +37,6 @@
):
super().__init__(smearing, exclusion_radius)

# constants used in the forwward
self.register_buffer(
"_rsqrt2",
torch.rsqrt(torch.tensor(2.0)),
)
self.register_buffer(
"_sqrt_2_on_pi",
torch.sqrt(
torch.tensor(2.0 / torch.pi)
),
)

def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
"""
Full :math:`1/r` potential as a function of :math:`r`.
@@ -73,7 +61,7 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
"Cannot compute long-range contribution without specifying `smearing`."
)

return torch.erf(dist * (self._rsqrt2 / self.smearing)) / dist
return torch.erf(dist / self.smearing / 2.0 ** 0.5) / dist

def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
r"""
@@ -103,7 +91,7 @@ def self_contribution(self) -> torch.Tensor:
raise ValueError(
"Cannot compute self contribution without specifying `smearing`."
)
return self._sqrt_2_on_pi / self.smearing
return (2 / torch.pi) ** 0.5 / self.smearing

def background_correction(self) -> torch.Tensor:
# "charge neutrality" correction for 1/r potential
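The two removed buffers were precomputed constants (1/sqrt(2) and sqrt(2/pi)); the inlined expressions are numerically equivalent. A quick equivalence check, written as a sketch rather than as part of the commit:

```python
import torch

smearing = 0.75
dist = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)

rsqrt2 = torch.rsqrt(torch.tensor(2.0, dtype=torch.float64))
old_lr = torch.erf(dist * (rsqrt2 / smearing)) / dist  # buffered-constant form
new_lr = torch.erf(dist / smearing / 2.0**0.5) / dist  # inlined form
assert torch.allclose(old_lr, new_lr)

old_self = torch.sqrt(torch.tensor(2.0 / torch.pi, dtype=torch.float64)) / smearing
new_self = (2 / torch.pi) ** 0.5 / smearing
assert abs(float(old_self) - new_self) < 1e-12
```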
22 changes: 8 additions & 14 deletions src/torchpme/potentials/inversepowerlaw.py
@@ -80,12 +80,9 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
"Cannot compute long-range contribution without specifying `smearing`."
)

exponent = self.exponent
smearing = self.smearing

x = 0.5 * dist**2 / smearing**2
peff = exponent / 2
prefac = 1.0 / (2 * smearing**2) ** peff
x = 0.5 * dist**2 / self.smearing**2
peff = self.exponent / 2
prefac = 1.0 / (2 * self.smearing**2) ** peff
return prefac * gammainc(peff, x) / x**peff

@torch.jit.export
@@ -101,12 +98,9 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
"Cannot compute long-range kernel without specifying `smearing`."
)

exponent = self.exponent
smearing = self.smearing

peff = (3 - exponent) / 2
prefac = torch.pi**1.5 / gamma(exponent / 2) * (2 * smearing**2) ** peff
x = 0.5 * smearing**2 * k_sq
peff = (3 - self.exponent) / 2
prefac = torch.pi**1.5 / gamma(self.exponent / 2) * (2 * self.smearing**2) ** peff
x = 0.5 * self.smearing**2 * k_sq

# The k=0 term often needs to be set separately since for exponents p<=3
# dimension, there is a divergence to +infinity. Setting this value manually
@@ -117,7 +111,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
# for consistency reasons.
masked = torch.where(x == 0, 1.0, x) # avoid NaNs in backwards, see Coulomb
return torch.where(
k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked)
k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(self.exponent, masked)
)

def self_contribution(self) -> torch.Tensor:
@@ -138,7 +132,7 @@ def background_correction(self) -> torch.Tensor:
"Cannot compute background correction without specifying `smearing`."
)
if self.exponent >= 3:
return self.smearing * 0.0
return torch.zeros_like(self.smearing)
prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2)
prefac /= (3 - self.exponent) * gamma(self.exponent / 2)
return prefac
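The inverse-power-law changes only inline the local aliases smearing and exponent onto self.*, so the returned values are unchanged. Read this way, lr_from_dist evaluates the regularized lower incomplete gamma P(p/2, r**2 / 2*sigma**2) divided by r**p, and therefore approaches 1/r**p far from the origin, where P saturates at 1. A quick numeric check of that reading, with torch.special.gammainc standing in for the library's gammainc helper:

```python
import torch

p, smearing = 2.0, 0.5
r = torch.tensor([5.0, 10.0], dtype=torch.float64)

x = 0.5 * r**2 / smearing**2
peff = torch.tensor(p / 2, dtype=torch.float64)
prefac = 1.0 / (2 * smearing**2) ** (p / 2)
lr = prefac * torch.special.gammainc(peff, x) / x ** (p / 2)

print(lr)        # ~[0.0400, 0.0100]
print(1 / r**p)  # the same values: the smeared potential tends to 1/r**p at large r
```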
15 changes: 2 additions & 13 deletions src/torchpme/potentials/potential.py
@@ -38,19 +38,8 @@
):
super().__init__()

if smearing is not None:
self.register_buffer(
"smearing", torch.tensor(smearing)
)
else:
self.smearing = None
if exclusion_radius is not None:
self.register_buffer(
"exclusion_radius",
torch.tensor(exclusion_radius),
)
else:
self.exclusion_radius = None
self.smearing = smearing
self.exclusion_radius = exclusion_radius

@torch.jit.export
def f_cutoff(self, dist: torch.Tensor) -> torch.Tensor:
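Storing smearing and exclusion_radius as plain attributes instead of registered buffers means they no longer appear in state_dict() and are untouched by .to(). A minimal illustration of that difference (generic torch.nn.Module sketch, not torchpme code):

```python
import torch

class WithBuffer(torch.nn.Module):
    def __init__(self, smearing: float):
        super().__init__()
        self.register_buffer("smearing", torch.tensor(smearing))

class WithFloat(torch.nn.Module):
    def __init__(self, smearing: float):
        super().__init__()
        self.smearing = smearing  # plain Python float, as in the new code

print(list(WithBuffer(1.0).state_dict()))  # ['smearing']
print(list(WithFloat(1.0).state_dict()))   # []
```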
8 changes: 4 additions & 4 deletions src/torchpme/potentials/spline.py
@@ -84,8 +84,8 @@ def __init__(
k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0])
else:
k_grid = r_grid.clone().detach()
else:
self.register_buffer("k_grid", k_grid)

self.register_buffer("k_grid", k_grid)

if yhat_grid is None:
# computes automatically!
@@ -95,8 +95,8 @@ def __init__(
y_grid,
compute_second_derivatives(r_grid, y_grid),
)
else:
self.register_buffer("yhat_grid", yhat_grid)

self.register_buffer("yhat_grid", yhat_grid)

# the function is defined for k**2, so we define the grid accordingly
if reciprocal:
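The spline.py change moves register_buffer out of the else branches, so k_grid and yhat_grid are registered whether they are supplied or computed automatically; previously only an explicitly supplied grid was registered. Registered buffers track the module through .to(), as in this brief sketch (not library code):

```python
import torch

class GridHolder(torch.nn.Module):
    def __init__(self, k_grid: torch.Tensor):
        super().__init__()
        self.register_buffer("k_grid", k_grid)  # always registered, as in the fix

m = GridHolder(torch.linspace(0.0, 1.0, 5))
m = m.to(torch.float64)
print(m.k_grid.dtype)  # torch.float64 -- the buffer is converted with the module
```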
