Skip to content

Commit

Permalink
Add Explicit dtype and device Support for Calculators and Ensure …
Browse files Browse the repository at this point in the history
…Compatibility with Potentials (#143)

* Refactor parameter handling in calculators and potentials for improved dtype and device management
* Updated docstrings and changelog, added an assertion to check for an instance of the potential, and resolved the TorchScript Potential/Calculator incompatibility.
* Update changelog and add test for potential and calculator compatibility
  • Loading branch information
E-Rum authored and GardevoirX committed Jan 16, 2025
1 parent 048e8be commit d8e675e
Show file tree
Hide file tree
Showing 13 changed files with 133 additions and 71 deletions.
1 change: 0 additions & 1 deletion docs/extensions/versions_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def run(self):
:margin: 0 0 0 0\n"""

for group_i, (version_short, group) in enumerate(grouped_versions.items()):

if group_i < 3:
generated_content += f"""
.. grid-item::
Expand Down
7 changes: 7 additions & 0 deletions docs/src/references/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,16 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
`Unreleased <https://github.com/lab-cosmo/torch-pme/>`_
-------------------------------------------------------

Added
#####

* Added ``dtype`` and ``device`` for ``Calculator`` classses

Fixed
#####

* Ensured consistency of ``dtype`` and ``device`` in the ``Potential`` and
``Calculator`` classses
* Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class
* Fix inconsistent ``cutoff`` in neighbor list example
* All calculators now check if the cell is zero if the potential is range-separated
Expand Down
22 changes: 19 additions & 3 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,37 @@ class Calculator(torch.nn.Module):
will come from a full (True) or half (False, default) neighbor list.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
self,
potential: Potential,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__()
# TorchScript requires to initialize all attributes in __init__
self._device = torch.device("cpu")
self._dtype = torch.float32

assert isinstance(potential, Potential), (
f"Potential must be an instance of Potential, got {type(potential)}"
)

self.device = "cpu" if device is None else device
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.potential = potential

assert self.dtype == self.potential.dtype, (
f"Potential and Calculator must have the same dtype, got {self.dtype} and "
f"{self.potential.dtype}"
)
assert self.device == self.potential.device, (
f"Potential and Calculator must have the same device, got {self.device} and "
f"{self.potential.device}"
)

self.full_neighbor_list = full_neighbor_list

self.prefactor = prefactor
Expand Down
8 changes: 8 additions & 0 deletions src/torchpme/calculators/ewald.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch

from ..lib import generate_kvectors_for_ewald
Expand Down Expand Up @@ -53,6 +55,8 @@ class EwaldCalculator(Calculator):
:obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
Expand All @@ -61,11 +65,15 @@ def __init__(
lr_wavelength: float,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
dtype=dtype,
device=device,
)
if potential.smearing is None:
raise ValueError(
Expand Down
8 changes: 8 additions & 0 deletions src/torchpme/calculators/p3m.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch

from ..lib.kspace_filter import P3MKSpaceFilter
Expand Down Expand Up @@ -40,6 +42,8 @@ class P3MCalculator(PMECalculator):
set to :py:obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
For an **example** on the usage for any calculator refer to :ref:`userdoc-how-to`.
"""
Expand All @@ -51,6 +55,8 @@ def __init__(
interpolation_nodes: int = 4,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
self.mesh_spacing: float = mesh_spacing

Expand All @@ -62,6 +68,8 @@ def __init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
dtype=dtype,
device=device,
)

self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(
Expand Down
16 changes: 12 additions & 4 deletions src/torchpme/calculators/pme.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
from torch import profiler

Expand Down Expand Up @@ -45,6 +47,8 @@ class PMECalculator(Calculator):
set to :obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
Expand All @@ -54,11 +58,15 @@ def __init__(
interpolation_nodes: int = 4,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
dtype=dtype,
device=device,
)

if potential.smearing is None:
Expand All @@ -69,8 +77,8 @@ def __init__(
self.mesh_spacing: float = mesh_spacing

self.kspace_filter: KSpaceFilter = KSpaceFilter(
cell=torch.eye(3),
ns_mesh=torch.ones(3, dtype=int),
cell=torch.eye(3, dtype=self.dtype, device=self.device),
ns_mesh=torch.ones(3, dtype=int, device=self.device),
kernel=self.potential,
fft_norm="backward",
ifft_norm="forward",
Expand All @@ -79,8 +87,8 @@ def __init__(
self.interpolation_nodes: int = interpolation_nodes

self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
cell=torch.eye(3),
ns_mesh=torch.ones(3, dtype=int),
cell=torch.eye(3, dtype=self.dtype, device=self.device),
ns_mesh=torch.ones(3, dtype=int, device=self.device),
interpolation_nodes=self.interpolation_nodes,
method="Lagrange", # convention for classic PME
)
Expand Down
9 changes: 4 additions & 5 deletions src/torchpme/potentials/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def __init__(
dtype=dtype,
device=device,
)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device("cpu")

smearings = [pot.smearing for pot in potentials]
if not all(smearings) and any(smearings):
raise ValueError(
Expand All @@ -76,7 +73,9 @@ def __init__(
"The number of initial weights must match the number of potentials being combined"
)
else:
initial_weights = torch.ones(len(potentials), dtype=dtype, device=device)
initial_weights = torch.ones(
len(potentials), dtype=self.dtype, device=self.device
)
# for torchscript
self.potentials = torch.nn.ModuleList(potentials)
if learnable_weights:
Expand Down
10 changes: 4 additions & 6 deletions src/torchpme/potentials/coulomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,17 @@ def __init__(
device: Optional[torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device("cpu")

# constants used in the forwward
self.register_buffer(
"_rsqrt2",
torch.rsqrt(torch.tensor(2.0, dtype=dtype, device=device)),
torch.rsqrt(torch.tensor(2.0, dtype=self.dtype, device=self.device)),
)
self.register_buffer(
"_sqrt_2_on_pi",
torch.sqrt(torch.tensor(2.0 / torch.pi, dtype=dtype, device=device)),
torch.sqrt(
torch.tensor(2.0 / torch.pi, dtype=self.dtype, device=self.device)
),
)

def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
Expand Down
6 changes: 1 addition & 5 deletions src/torchpme/potentials/inversepowerlaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,11 @@ def __init__(
device: Optional[torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device("cpu")

if exponent <= 0 or exponent > 3:
raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3")
self.register_buffer(
"exponent", torch.tensor(exponent, dtype=dtype, device=device)
"exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
)

@torch.jit.export
Expand Down
10 changes: 4 additions & 6 deletions src/torchpme/potentials/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,18 @@ def __init__(
device: Optional[torch.device] = None,
):
super().__init__()
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device("cpu")
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = "cpu" if device is None else device
if smearing is not None:
self.register_buffer(
"smearing", torch.tensor(smearing, device=device, dtype=dtype)
"smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
)
else:
self.smearing = None
if exclusion_radius is not None:
self.register_buffer(
"exclusion_radius",
torch.tensor(exclusion_radius, device=device, dtype=dtype),
torch.tensor(exclusion_radius, device=self.device, dtype=self.dtype),
)
else:
self.exclusion_radius = None
Expand Down
18 changes: 8 additions & 10 deletions src/torchpme/potentials/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,12 @@ def __init__(
dtype=dtype,
device=device,
)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device("cpu")

if len(y_grid) != len(r_grid):
raise ValueError("Length of radial grid and value array mismatch.")

r_grid = r_grid.to(dtype=dtype, device=device)
y_grid = y_grid.to(dtype=dtype, device=device)
r_grid = r_grid.to(dtype=self.dtype, device=self.device)
y_grid = y_grid.to(dtype=self.dtype, device=self.device)

if reciprocal:
if torch.min(r_grid) <= 0.0:
Expand All @@ -93,7 +89,7 @@ def __init__(
else:
k_grid = r_grid.clone()
else:
k_grid = k_grid.to(dtype=dtype, device=device)
k_grid = k_grid.to(dtype=self.dtype, device=self.device)

if yhat_grid is None:
# computes automatically!
Expand All @@ -104,7 +100,7 @@ def __init__(
compute_second_derivatives(r_grid, y_grid),
)
else:
yhat_grid = yhat_grid.to(dtype=dtype, device=device)
yhat_grid = yhat_grid.to(dtype=self.dtype, device=self.device)

# the function is defined for k**2, so we define the grid accordingly
if reciprocal:
Expand All @@ -115,13 +111,15 @@ def __init__(
self._krn_spline = CubicSpline(k_grid**2, yhat_grid)

if y_at_zero is None:
self._y_at_zero = self._spline(torch.zeros(1, dtype=dtype, device=device))
self._y_at_zero = self._spline(
torch.zeros(1, dtype=self.dtype, device=self.device)
)
else:
self._y_at_zero = y_at_zero

if yhat_at_zero is None:
self._yhat_at_zero = self._krn_spline(
torch.zeros(1, dtype=dtype, device=device)
torch.zeros(1, dtype=self.dtype, device=self.device)
)
else:
self._yhat_at_zero = yhat_at_zero
Expand Down
Loading

0 comments on commit d8e675e

Please sign in to comment.