diff --git a/docs/extensions/versions_list.py b/docs/extensions/versions_list.py
index 11f5432b..df03352e 100644
--- a/docs/extensions/versions_list.py
+++ b/docs/extensions/versions_list.py
@@ -43,7 +43,6 @@ def run(self):
:margin: 0 0 0 0\n"""
for group_i, (version_short, group) in enumerate(grouped_versions.items()):
-
if group_i < 3:
generated_content += f"""
.. grid-item::
diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst
index d5b7960b..d1e86621 100644
--- a/docs/src/references/changelog.rst
+++ b/docs/src/references/changelog.rst
@@ -24,9 +24,16 @@ changelog `_ format. This project follows
`Unreleased `_
-------------------------------------------------------
+Added
+#####
+
+* Added ``dtype`` and ``device`` for ``Calculator`` classses
+
Fixed
#####
+* Ensured consistency of ``dtype`` and ``device`` in the ``Potential`` and
+ ``Calculator`` classses
* Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class
* Fix inconsistent ``cutoff`` in neighbor list example
* All calculators now check if the cell is zero if the potential is range-separated
diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py
index 9a35802a..627283e5 100644
--- a/src/torchpme/calculators/calculator.py
+++ b/src/torchpme/calculators/calculator.py
@@ -26,6 +26,8 @@ class Calculator(torch.nn.Module):
will come from a full (True) or half (False, default) neighbor list.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
+ :param dtype: type used for the internal buffers and parameters
+ :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -33,14 +35,28 @@ def __init__(
potential: Potential,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
):
super().__init__()
- # TorchScript requires to initialize all attributes in __init__
- self._device = torch.device("cpu")
- self._dtype = torch.float32
+ assert isinstance(potential, Potential), (
+ f"Potential must be an instance of Potential, got {type(potential)}"
+ )
+
+ self.device = "cpu" if device is None else device
+ self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.potential = potential
+ assert self.dtype == self.potential.dtype, (
+ f"Potential and Calculator must have the same dtype, got {self.dtype} and "
+ f"{self.potential.dtype}"
+ )
+ assert self.device == self.potential.device, (
+ f"Potential and Calculator must have the same device, got {self.device} and "
+ f"{self.potential.device}"
+ )
+
self.full_neighbor_list = full_neighbor_list
self.prefactor = prefactor
diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py
index e8bffb5c..d009c4d1 100644
--- a/src/torchpme/calculators/ewald.py
+++ b/src/torchpme/calculators/ewald.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
import torch
from ..lib import generate_kvectors_for_ewald
@@ -53,6 +55,8 @@ class EwaldCalculator(Calculator):
:obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
+ :param dtype: type used for the internal buffers and parameters
+ :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -61,11 +65,15 @@ def __init__(
lr_wavelength: float,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
):
super().__init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
+ dtype=dtype,
+ device=device,
)
if potential.smearing is None:
raise ValueError(
diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py
index 58826b08..76a63874 100644
--- a/src/torchpme/calculators/p3m.py
+++ b/src/torchpme/calculators/p3m.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
import torch
from ..lib.kspace_filter import P3MKSpaceFilter
@@ -40,6 +42,8 @@ class P3MCalculator(PMECalculator):
set to :py:obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
+ :param dtype: type used for the internal buffers and parameters
+ :param device: device used for the internal buffers and parameters
For an **example** on the usage for any calculator refer to :ref:`userdoc-how-to`.
"""
@@ -51,6 +55,8 @@ def __init__(
interpolation_nodes: int = 4,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
):
self.mesh_spacing: float = mesh_spacing
@@ -62,6 +68,8 @@ def __init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
+ dtype=dtype,
+ device=device,
)
self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(
diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py
index 93c207a8..0d6742cd 100644
--- a/src/torchpme/calculators/pme.py
+++ b/src/torchpme/calculators/pme.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
import torch
from torch import profiler
@@ -45,6 +47,8 @@ class PMECalculator(Calculator):
set to :obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
+ :param dtype: type used for the internal buffers and parameters
+ :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -54,11 +58,15 @@ def __init__(
interpolation_nodes: int = 4,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
):
super().__init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
+ dtype=dtype,
+ device=device,
)
if potential.smearing is None:
@@ -69,8 +77,8 @@ def __init__(
self.mesh_spacing: float = mesh_spacing
self.kspace_filter: KSpaceFilter = KSpaceFilter(
- cell=torch.eye(3),
- ns_mesh=torch.ones(3, dtype=int),
+ cell=torch.eye(3, dtype=self.dtype, device=self.device),
+ ns_mesh=torch.ones(3, dtype=int, device=self.device),
kernel=self.potential,
fft_norm="backward",
ifft_norm="forward",
@@ -79,8 +87,8 @@ def __init__(
self.interpolation_nodes: int = interpolation_nodes
self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
- cell=torch.eye(3),
- ns_mesh=torch.ones(3, dtype=int),
+ cell=torch.eye(3, dtype=self.dtype, device=self.device),
+ ns_mesh=torch.ones(3, dtype=int, device=self.device),
interpolation_nodes=self.interpolation_nodes,
method="Lagrange", # convention for classic PME
)
diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py
index 2c4e2612..d76a20c0 100644
--- a/src/torchpme/potentials/combined.py
+++ b/src/torchpme/potentials/combined.py
@@ -47,10 +47,7 @@ def __init__(
dtype=dtype,
device=device,
)
- if dtype is None:
- dtype = torch.get_default_dtype()
- if device is None:
- device = torch.device("cpu")
+
smearings = [pot.smearing for pot in potentials]
if not all(smearings) and any(smearings):
raise ValueError(
@@ -76,7 +73,9 @@ def __init__(
"The number of initial weights must match the number of potentials being combined"
)
else:
- initial_weights = torch.ones(len(potentials), dtype=dtype, device=device)
+ initial_weights = torch.ones(
+ len(potentials), dtype=self.dtype, device=self.device
+ )
# for torchscript
self.potentials = torch.nn.ModuleList(potentials)
if learnable_weights:
diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py
index f121e38e..4cde5611 100644
--- a/src/torchpme/potentials/coulomb.py
+++ b/src/torchpme/potentials/coulomb.py
@@ -38,19 +38,17 @@ def __init__(
device: Optional[torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
- if dtype is None:
- dtype = torch.get_default_dtype()
- if device is None:
- device = torch.device("cpu")
# constants used in the forwward
self.register_buffer(
"_rsqrt2",
- torch.rsqrt(torch.tensor(2.0, dtype=dtype, device=device)),
+ torch.rsqrt(torch.tensor(2.0, dtype=self.dtype, device=self.device)),
)
self.register_buffer(
"_sqrt_2_on_pi",
- torch.sqrt(torch.tensor(2.0 / torch.pi, dtype=dtype, device=device)),
+ torch.sqrt(
+ torch.tensor(2.0 / torch.pi, dtype=self.dtype, device=self.device)
+ ),
)
def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py
index bd44236e..8b2449ca 100644
--- a/src/torchpme/potentials/inversepowerlaw.py
+++ b/src/torchpme/potentials/inversepowerlaw.py
@@ -53,15 +53,11 @@ def __init__(
device: Optional[torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
- if dtype is None:
- dtype = torch.get_default_dtype()
- if device is None:
- device = torch.device("cpu")
if exponent <= 0 or exponent > 3:
raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3")
self.register_buffer(
- "exponent", torch.tensor(exponent, dtype=dtype, device=device)
+ "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
)
@torch.jit.export
diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py
index fa587896..b5e896ec 100644
--- a/src/torchpme/potentials/potential.py
+++ b/src/torchpme/potentials/potential.py
@@ -42,20 +42,18 @@ def __init__(
device: Optional[torch.device] = None,
):
super().__init__()
- if dtype is None:
- dtype = torch.get_default_dtype()
- if device is None:
- device = torch.device("cpu")
+ self.dtype = torch.get_default_dtype() if dtype is None else dtype
+ self.device = "cpu" if device is None else device
if smearing is not None:
self.register_buffer(
- "smearing", torch.tensor(smearing, device=device, dtype=dtype)
+ "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
)
else:
self.smearing = None
if exclusion_radius is not None:
self.register_buffer(
"exclusion_radius",
- torch.tensor(exclusion_radius, device=device, dtype=dtype),
+ torch.tensor(exclusion_radius, device=self.device, dtype=self.dtype),
)
else:
self.exclusion_radius = None
diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py
index a89120f4..f7a3ccd2 100644
--- a/src/torchpme/potentials/spline.py
+++ b/src/torchpme/potentials/spline.py
@@ -66,16 +66,12 @@ def __init__(
dtype=dtype,
device=device,
)
- if dtype is None:
- dtype = torch.get_default_dtype()
- if device is None:
- device = torch.device("cpu")
if len(y_grid) != len(r_grid):
raise ValueError("Length of radial grid and value array mismatch.")
- r_grid = r_grid.to(dtype=dtype, device=device)
- y_grid = y_grid.to(dtype=dtype, device=device)
+ r_grid = r_grid.to(dtype=self.dtype, device=self.device)
+ y_grid = y_grid.to(dtype=self.dtype, device=self.device)
if reciprocal:
if torch.min(r_grid) <= 0.0:
@@ -93,7 +89,7 @@ def __init__(
else:
k_grid = r_grid.clone()
else:
- k_grid = k_grid.to(dtype=dtype, device=device)
+ k_grid = k_grid.to(dtype=self.dtype, device=self.device)
if yhat_grid is None:
# computes automatically!
@@ -104,7 +100,7 @@ def __init__(
compute_second_derivatives(r_grid, y_grid),
)
else:
- yhat_grid = yhat_grid.to(dtype=dtype, device=device)
+ yhat_grid = yhat_grid.to(dtype=self.dtype, device=self.device)
# the function is defined for k**2, so we define the grid accordingly
if reciprocal:
@@ -115,13 +111,15 @@ def __init__(
self._krn_spline = CubicSpline(k_grid**2, yhat_grid)
if y_at_zero is None:
- self._y_at_zero = self._spline(torch.zeros(1, dtype=dtype, device=device))
+ self._y_at_zero = self._spline(
+ torch.zeros(1, dtype=self.dtype, device=self.device)
+ )
else:
self._y_at_zero = y_at_zero
if yhat_at_zero is None:
self._yhat_at_zero = self._krn_spline(
- torch.zeros(1, dtype=dtype, device=device)
+ torch.zeros(1, dtype=self.dtype, device=self.device)
)
else:
self._yhat_at_zero = yhat_at_zero
diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py
index 8d2fcee2..7858b395 100644
--- a/tests/calculators/test_workflow.py
+++ b/tests/calculators/test_workflow.py
@@ -17,9 +17,7 @@
PMECalculator,
)
-AVAILABLE_DEVICES = [torch.device("cpu")] + torch.cuda.is_available() * [
- torch.device("cuda")
-]
+AVAILABLE_DEVICES = ["cpu"] + torch.cuda.is_available() * ["cuda"]
MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3))
CHARGES_CSCL = torch.tensor([1.0, -1.0])
SMEARING = 0.1
@@ -27,6 +25,7 @@
MESH_SPACING = SMEARING / 4
+@pytest.mark.parametrize("device", AVAILABLE_DEVICES)
@pytest.mark.parametrize(
("CalculatorClass", "params"),
[
@@ -79,49 +78,47 @@ def cscl_system(self, device=None):
neighbor_distances.to(device=device),
)
- def test_smearing_non_positive(self, CalculatorClass, params):
+ def test_smearing_non_positive(self, CalculatorClass, params, device):
params = params.copy()
match = r"`smearing` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator, PMECalculator]:
params["smearing"] = 0
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params)
+ CalculatorClass(**params, device=device)
params["smearing"] = -0.1
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params)
+ CalculatorClass(**params, device=device)
- def test_interpolation_order_error(self, CalculatorClass, params):
+ def test_interpolation_order_error(self, CalculatorClass, params, device):
params = params.copy()
if type(CalculatorClass) in [PMECalculator]:
match = "Only `interpolation_nodes` from 1 to 5"
params["interpolation_nodes"] = 10
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params)
+ CalculatorClass(**params, device=device)
- def test_lr_wavelength_non_positive(self, CalculatorClass, params):
+ def test_lr_wavelength_non_positive(self, CalculatorClass, params, device):
params = params.copy()
match = r"`lr_wavelength` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator]:
params["lr_wavelength"] = 0
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params)
+ CalculatorClass(**params, device=device)
params["lr_wavelength"] = -0.1
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params)
+ CalculatorClass(**params, device=device)
- def test_dtype_device(self, CalculatorClass, params):
+ def test_dtype_device(self, CalculatorClass, params, device):
"""Test that the output dtype and device are the same as the input."""
- device = "cpu"
dtype = torch.float64
-
+ params = params.copy()
positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device)
charges = torch.ones((1, 2), dtype=dtype, device=device)
cell = torch.eye(3, dtype=dtype, device=device)
neighbor_indices = torch.tensor([[0, 0]], device=device)
neighbor_distances = torch.tensor([0.1], device=device)
-
- calculator = CalculatorClass(**params)
-
+ params["potential"].device = device
+ calculator = CalculatorClass(**params, device=device)
potential = calculator.forward(
charges=charges,
cell=cell,
@@ -138,41 +135,48 @@ def check_operation(self, calculator, device):
descriptor = calculator.forward(*self.cscl_system(device))
assert type(descriptor) is torch.Tensor
- @pytest.mark.parametrize("device", AVAILABLE_DEVICES)
def test_operation_as_python(self, CalculatorClass, params, device):
"""Run `check_operation` as a normal python script"""
- calculator = CalculatorClass(**params)
+ params = params.copy()
+ params["potential"].device = device
+ calculator = CalculatorClass(**params, device=device)
self.check_operation(calculator=calculator, device=device)
- @pytest.mark.parametrize("device", AVAILABLE_DEVICES)
def test_operation_as_torch_script(self, CalculatorClass, params, device):
"""Run `check_operation` as a compiled torch script module."""
- calculator = CalculatorClass(**params)
+ params = params.copy()
+ params["potential"].device = device
+ calculator = CalculatorClass(**params, device=device)
scripted = torch.jit.script(calculator)
self.check_operation(calculator=scripted, device=device)
- def test_save_load(self, CalculatorClass, params):
- calculator = CalculatorClass(**params)
+ def test_save_load(self, CalculatorClass, params, device):
+ params = params.copy()
+ params["potential"].device = device
+ calculator = CalculatorClass(**params, device=device)
scripted = torch.jit.script(calculator)
with io.BytesIO() as buffer:
torch.jit.save(scripted, buffer)
buffer.seek(0)
torch.jit.load(buffer)
- def test_prefactor(self, CalculatorClass, params):
+ def test_prefactor(self, CalculatorClass, params, device):
"""Test if the prefactor is applied correctly."""
+ params = params.copy()
+ params["potential"].device = device
prefactor = 2.0
- calculator1 = CalculatorClass(**params)
- calculator2 = CalculatorClass(**params, prefactor=prefactor)
+ calculator1 = CalculatorClass(**params, device=device)
+ calculator2 = CalculatorClass(**params, prefactor=prefactor, device=device)
potentials1 = calculator1.forward(*self.cscl_system())
potentials2 = calculator2.forward(*self.cscl_system())
assert torch.allclose(potentials1 * prefactor, potentials2)
- def test_not_nan(self, CalculatorClass, params):
+ def test_not_nan(self, CalculatorClass, params, device):
"""Make sure derivatives are not NaN."""
- device = "cpu"
+ params = params.copy()
+ params["potential"].device = device
- calculator = CalculatorClass(**params)
+ calculator = CalculatorClass(**params, device=device)
system = self.cscl_system(device)
system[0].requires_grad = True
system[1].requires_grad = True
@@ -198,3 +202,27 @@ def test_not_nan(self, CalculatorClass, params):
assert not torch.isnan(
torch.autograd.grad(energy, system[2], retain_graph=True)[0]
).any()
+
+ def test_dtype_and_device_incompatability(self, CalculatorClass, params, device):
+ """Test that the calculator raises an error if the dtype and device are incompatible."""
+ params = params.copy()
+ params["potential"].device = device
+ params["potential"].dtype = torch.float64
+ with pytest.raises(AssertionError, match=".*dtype.*"):
+ CalculatorClass(**params, dtype=torch.float32, device=device)
+ with pytest.raises(AssertionError, match=".*device.*"):
+ CalculatorClass(
+ **params, dtype=params["potential"].dtype, device=torch.device("meta")
+ )
+
+ def test_potential_and_calculator_incompatability(
+ self, CalculatorClass, params, device
+ ):
+ """Test that the calculator raises an error if the potential and calculator are incompatible."""
+ params = params.copy()
+ params["potential"].device = device
+ params["potential"] = torch.jit.script(params["potential"])
+ with pytest.raises(
+ AssertionError, match="Potential must be an instance of Potential, got.*"
+ ):
+ CalculatorClass(**params)
diff --git a/tests/test_potentials.py b/tests/test_potentials.py
index 87d590bd..87670b3b 100644
--- a/tests/test_potentials.py
+++ b/tests/test_potentials.py
@@ -526,8 +526,7 @@ def test_combined_potentials_jit(smearing):
)
combo = CombinedPotential(potentials=[spline, coulomb], dtype=dtype, smearing=1.0)
- jitcombo = torch.jit.script(combo)
- mypme = PMECalculator(jitcombo, mesh_spacing=1.0)
+ mypme = PMECalculator(combo, mesh_spacing=1.0, dtype=dtype)
_ = torch.jit.script(mypme)