diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst
index ee17fdf4..4d312f14 100644
--- a/docs/src/references/changelog.rst
+++ b/docs/src/references/changelog.rst
@@ -27,15 +27,13 @@ changelog `_ format. This project follows
Added
#####
-* Enhanced ``device`` and ``dtype`` consistency checks throughout the library
* Better documentation for for ``cell``, ``charges`` and ``positions`` parameters
-* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
- ``Calculator`` classes and tuning functions.
-Fixed
-#####
+Removed
+#######
-* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator``
+* Remove ``device`` and ``dtype`` from init of ``Calculator``, ``Potential`` and
+ ``Tuning`` classes
`Version 0.2.0 `_ - 2025-01-23
------------------------------------------------------------------------------------------
diff --git a/examples/01-charges-example.py b/examples/01-charges-example.py
index 6465a507..c4038bda 100644
--- a/examples/01-charges-example.py
+++ b/examples/01-charges-example.py
@@ -73,7 +73,6 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
- dtype=dtype,
)
# %%
@@ -103,9 +102,9 @@
# will be used to *compute* the potential energy of the system.
calculator = torchpme.PMECalculator(
- torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
+ torchpme.CoulombPotential(smearing=smearing), **pme_params
)
-
+calculator.to(dtype=dtype)
# %%
#
# Single Charge Channel
@@ -207,9 +206,9 @@
# creating a new calculator with the metatensor interface.
calculator_metatensor = torchpme.metatensor.PMECalculator(
- torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
+ torchpme.CoulombPotential(smearing=smearing), **pme_params
)
-
+calculator_metatensor.to(dtype=dtype)
# %%
#
# Computation with metatensor involves using Metatensor's :class:`System
diff --git a/examples/02-neighbor-lists-usage.py b/examples/02-neighbor-lists-usage.py
index e72689f0..bdda04eb 100644
--- a/examples/02-neighbor-lists-usage.py
+++ b/examples/02-neighbor-lists-usage.py
@@ -110,7 +110,6 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
- dtype=dtype,
)
# %%
@@ -195,8 +194,7 @@ def distances(
# compute the potential.
pme = torchpme.PMECalculator(
- potential=torchpme.CoulombPotential(smearing=smearing, dtype=dtype),
- dtype=dtype,
+ potential=torchpme.CoulombPotential(smearing=smearing),
**pme_params,
)
potential = pme(
diff --git a/examples/07-lode-demo.py b/examples/07-lode-demo.py
index 59feb341..6588f10d 100644
--- a/examples/07-lode-demo.py
+++ b/examples/07-lode-demo.py
@@ -420,7 +420,7 @@ def __init__(self, potential: Potential, n_grid: int = 3):
)
# assumes a smooth exclusion region so sets the integration cutoff to half that
- nodes, weights = get_full_grid(n_grid, potential.exclusion_radius.item() / 2)
+ nodes, weights = get_full_grid(n_grid, potential.exclusion_radius / 2)
# these are the "stencils" used to project the potential
# on an atom-centered basis. NB: weights might also be incorporated
diff --git a/examples/08-combined-potential.py b/examples/08-combined-potential.py
index 8d0667bb..73e1c077 100644
--- a/examples/08-combined-potential.py
+++ b/examples/08-combined-potential.py
@@ -67,10 +67,12 @@
# evaluation, and so one has to set it also for the combined potential, even if it is
# not used explicitly in the evaluation of the combination.
-pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype)
-pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype)
-
-potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing, dtype=dtype)
+pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
+pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing)
+pot_1 = pot_1.to(dtype=dtype)
+pot_2 = pot_2.to(dtype=dtype)
+potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing)
+potential = potential.to(dtype=dtype)
# Note also that :class:`CombinedPotential` can be used with any combination of
# potentials, as long they are all either direct or range separated. For instance, one
@@ -156,9 +158,9 @@
# much bigger system.
calculator = EwaldCalculator(
- potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A, dtype=dtype
+ potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A
)
-
+calculator.to(dtype=dtype)
# %%
#
diff --git a/examples/10-tuning.py b/examples/10-tuning.py
index 183b0f45..0978cb98 100644
--- a/examples/10-tuning.py
+++ b/examples/10-tuning.py
@@ -120,12 +120,10 @@
pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4}
pme = torchpme.PMECalculator(
- potential=torchpme.CoulombPotential(smearing=smearing, device=device, dtype=dtype),
- device=device,
- dtype=dtype,
+ potential=torchpme.CoulombPotential(smearing=smearing),
**pme_params, # type: ignore[arg-type]
)
-
+pme.to(device=device, dtype=dtype)
# %%
# Run the calculator
# ~~~~~~~~~~~~~~~~~~
@@ -170,8 +168,6 @@
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
run_backward=True,
- device=device,
- dtype=dtype,
)
estimated_timing = timings(pme)
@@ -220,14 +216,11 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
)
pme = torchpme.PMECalculator(
- potential=torchpme.CoulombPotential(
- smearing=smearing, device=device, dtype=dtype
- ),
+ potential=torchpme.CoulombPotential(smearing=smearing),
mesh_spacing=mesh_spacing,
interpolation_nodes=interpolation_nodes,
- device=device,
- dtype=dtype,
)
+ pme.to(device=device, dtype=dtype)
potential = pme(
charges=charges,
cell=cell,
@@ -247,8 +240,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
run_backward=True,
n_warmup=1,
n_repeat=4,
- device=device,
- dtype=dtype,
)
estimated_timing = timings(pme)
return madelung, estimated_timing
@@ -457,8 +448,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
cutoff=5.0,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
- device=device,
- dtype=dtype,
)
print(
@@ -492,8 +481,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
cutoff=cutoff,
neighbor_indices=filter_indices,
neighbor_distances=filter_distances,
- device=device,
- dtype=dtype,
)
timings_grid.append(timing)
diff --git a/examples/basic-usage.py b/examples/basic-usage.py
index 47bc1506..3be525b2 100644
--- a/examples/basic-usage.py
+++ b/examples/basic-usage.py
@@ -146,7 +146,8 @@
# contains all the necessary functions (such as those defining the short-range and
# long-range splits) for this potential and makes them useable in the rest of the code.
-potential = CoulombPotential(smearing=smearing, device=device, dtype=dtype)
+potential = CoulombPotential(smearing=smearing)
+potential.to(device=device, dtype=dtype)
# %%
#
@@ -193,10 +194,8 @@
# Since our structure is relatively small, we use the :class:`EwaldCalculator`.
# We start by the initialization of the class.
-calculator = EwaldCalculator(
- potential=potential, lr_wavelength=lr_wavelength, device=device, dtype=dtype
-)
-
+calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength)
+calculator.to(device=device, dtype=dtype)
# %%
#
# Compute Energy
diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py
index b7b4063c..91568ecf 100644
--- a/src/torchpme/_utils.py
+++ b/src/torchpme/_utils.py
@@ -1,23 +1,8 @@
-from typing import Optional, Union
+from typing import Union
import torch
-def _get_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
- return torch.get_default_dtype() if dtype is None else dtype
-
-
-def _get_device(device: Union[None, str, torch.device]) -> torch.device:
- new_device = torch.get_default_device() if device is None else torch.device(device)
-
- # Add default index of 0 to a cuda device to avoid errors when comparing with
- # devices from tensors
- if new_device.type == "cuda" and new_device.index is None:
- new_device = torch.device("cuda:0")
-
- return new_device
-
-
def _validate_parameters(
charges: torch.Tensor,
cell: torch.Tensor,
@@ -25,20 +10,9 @@ def _validate_parameters(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
- dtype: torch.dtype,
- device: torch.device,
) -> None:
- if positions.dtype != dtype:
- raise TypeError(
- f"type of `positions` ({positions.dtype}) must be same as the class "
- f"type ({dtype})"
- )
-
- if positions.device != device:
- raise ValueError(
- f"device of `positions` ({positions.device}) must be same as the class "
- f"device ({device})"
- )
+ dtype = positions.dtype
+ device = positions.device
# check shape, dtype and device of positions
num_atoms = len(positions)
@@ -55,14 +29,14 @@ def _validate_parameters(
f"{list(cell.shape)}"
)
- if cell.dtype != positions.dtype:
+ if cell.dtype != dtype:
raise TypeError(
- f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
+ f"type of `cell` ({cell.dtype}) must be same as that of the `positions` class ({dtype})"
)
if cell.device != device:
raise ValueError(
- f"device of `cell` ({cell.device}) must be same as the class ({device})"
+ f"device of `cell` ({cell.device}) must be same as that of the `positions` class ({device})"
)
if smearing is not None and torch.equal(
@@ -89,14 +63,14 @@ def _validate_parameters(
f"{len(positions)} atoms"
)
- if charges.dtype != positions.dtype:
+ if charges.dtype != dtype:
raise TypeError(
- f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
+ f"type of `charges` ({charges.dtype}) must be same as that of the `positions` class ({dtype})"
)
if charges.device != device:
raise ValueError(
- f"device of `charges` ({charges.device}) must be same as the class "
+ f"device of `charges` ({charges.device}) must be same as that of the `positions` class "
f"({device})"
)
@@ -111,7 +85,7 @@ def _validate_parameters(
if neighbor_indices.device != device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
- f"same as the class ({device})"
+ f"same as that of the `positions` class ({device})"
)
if neighbor_distances.shape != neighbor_indices[:, 0].shape:
@@ -124,11 +98,11 @@ def _validate_parameters(
if neighbor_distances.device != device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
- f"same as the class ({device})"
+ f"same as that of the `positions` class ({device})"
)
- if neighbor_distances.dtype != positions.dtype:
+ if neighbor_distances.dtype != dtype:
raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
- f"as the class ({dtype})"
+ f"as that of the `positions` class ({dtype})"
)
diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py
index 3dfbeb38..0bdfc421 100644
--- a/src/torchpme/calculators/calculator.py
+++ b/src/torchpme/calculators/calculator.py
@@ -1,9 +1,7 @@
-from typing import Optional, Union
-
import torch
from torch import profiler
-from .._utils import _get_device, _get_dtype, _validate_parameters
+from .._utils import _validate_parameters
from ..potentials import Potential
@@ -27,8 +25,6 @@ class Calculator(torch.nn.Module):
will come from a full (True) or half (False, default) neighbor list.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -36,8 +32,6 @@ def __init__(
potential: Potential,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
super().__init__()
@@ -46,24 +40,8 @@ def __init__(
f"Potential must be an instance of Potential, got {type(potential)}"
)
- self.device = _get_device(device)
- self.dtype = _get_dtype(dtype)
-
- if self.dtype != potential.dtype:
- raise TypeError(
- f"dtype of `potential` ({potential.dtype}) must be same as of "
- f"`calculator` ({self.dtype})"
- )
-
- if self.device != potential.device:
- raise ValueError(
- f"device of `potential` ({potential.device}) must be same as of "
- f"`calculator` ({self.device})"
- )
-
self.potential = potential
self.full_neighbor_list = full_neighbor_list
-
self.prefactor = prefactor
def _compute_rspace(
@@ -164,8 +142,6 @@ def forward(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
- dtype=self.dtype,
- device=self.device,
)
# Compute short-range (SR) part using a real space sum
diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py
index 83e2dc85..4c5d1906 100644
--- a/src/torchpme/calculators/ewald.py
+++ b/src/torchpme/calculators/ewald.py
@@ -1,5 +1,3 @@
-from typing import Optional, Union
-
import torch
from ..lib import generate_kvectors_for_ewald
@@ -55,8 +53,6 @@ class EwaldCalculator(Calculator):
:obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -65,15 +61,11 @@ def __init__(
lr_wavelength: float,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
- dtype=dtype,
- device=device,
)
if potential.smearing is None:
raise ValueError(
diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py
index f85533db..5597ff9e 100644
--- a/src/torchpme/calculators/p3m.py
+++ b/src/torchpme/calculators/p3m.py
@@ -1,5 +1,3 @@
-from typing import Optional, Union
-
import torch
from ..lib.kspace_filter import P3MKSpaceFilter
@@ -42,8 +40,6 @@ class P3MCalculator(PMECalculator):
set to :py:obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
For an **example** on the usage for any calculator refer to :ref:`userdoc-how-to`.
"""
@@ -55,8 +51,6 @@ def __init__(
interpolation_nodes: int = 4,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
self.mesh_spacing: float = mesh_spacing
@@ -68,13 +62,23 @@ def __init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
- dtype=dtype,
- device=device,
)
+ if potential.smearing is None:
+ raise ValueError(
+ "Must specify smearing to use a potential with P3MCalculator"
+ )
+
+ cell = torch.eye(
+ 3,
+ device=self.potential.smearing.device,
+ dtype=self.potential.smearing.dtype,
+ )
+ ns_mesh = torch.ones(3, dtype=int, device=cell.device)
+
self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(
- cell=torch.eye(3, dtype=self.dtype, device=self.device),
- ns_mesh=torch.ones(3, dtype=int, device=self.device),
+ cell=cell,
+ ns_mesh=ns_mesh,
interpolation_nodes=self.interpolation_nodes,
kernel=self.potential,
mode=0, # Green's function for point-charge potentials
@@ -84,8 +88,8 @@ def __init__(
)
self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
- cell=torch.eye(3, dtype=self.dtype, device=self.device),
- ns_mesh=torch.ones(3, dtype=int, device=self.device),
+ cell=cell,
+ ns_mesh=ns_mesh,
interpolation_nodes=self.interpolation_nodes,
method="P3M",
)
diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py
index dd389812..a79fb249 100644
--- a/src/torchpme/calculators/pme.py
+++ b/src/torchpme/calculators/pme.py
@@ -1,5 +1,3 @@
-from typing import Optional, Union
-
import torch
from torch import profiler
@@ -47,8 +45,6 @@ class PMECalculator(Calculator):
set to :obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -58,27 +54,30 @@ def __init__(
interpolation_nodes: int = 4,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
- dtype=dtype,
- device=device,
)
if potential.smearing is None:
raise ValueError(
- "Must specify smearing to use a potential with EwaldCalculator"
+ "Must specify smearing to use a potential with PMECalculator"
)
self.mesh_spacing: float = mesh_spacing
+ cell = torch.eye(
+ 3,
+ device=self.potential.smearing.device,
+ dtype=self.potential.smearing.dtype,
+ )
+ ns_mesh = torch.ones(3, dtype=int, device=cell.device)
+
self.kspace_filter: KSpaceFilter = KSpaceFilter(
- cell=torch.eye(3, dtype=self.dtype, device=self.device),
- ns_mesh=torch.ones(3, dtype=int, device=self.device),
+ cell=cell,
+ ns_mesh=ns_mesh,
kernel=self.potential,
fft_norm="backward",
ifft_norm="forward",
@@ -87,8 +86,8 @@ def __init__(
self.interpolation_nodes: int = interpolation_nodes
self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
- cell=torch.eye(3, dtype=self.dtype, device=self.device),
- ns_mesh=torch.ones(3, dtype=int, device=self.device),
+ cell=cell,
+ ns_mesh=ns_mesh,
interpolation_nodes=self.interpolation_nodes,
method="Lagrange", # convention for classic PME
)
diff --git a/src/torchpme/lib/splines.py b/src/torchpme/lib/splines.py
index 036ded7c..973c7576 100644
--- a/src/torchpme/lib/splines.py
+++ b/src/torchpme/lib/splines.py
@@ -127,8 +127,8 @@ def _solve_tridiagonal(a, b, c, d):
"""
n = len(d)
# Create copies to avoid modifying the original arrays
- c_prime = torch.zeros(n)
- d_prime = torch.zeros(n)
+ c_prime = torch.zeros_like(d)
+ d_prime = torch.zeros_like(d)
# Initial coefficients
c_prime[0] = c[0] / b[0]
@@ -141,7 +141,7 @@ def _solve_tridiagonal(a, b, c, d):
d_prime[i] = (d[i] - a[i] * d_prime[i - 1]) / denom
# Backward substitution
- x = torch.zeros(n)
+ x = torch.zeros_like(d)
x[-1] = d_prime[-1]
for i in reversed(range(n - 1)):
x[i] = d_prime[i] - c_prime[i] * x[i + 1]
@@ -151,36 +151,28 @@ def _solve_tridiagonal(a, b, c, d):
def compute_second_derivatives(
x_points: torch.Tensor,
y_points: torch.Tensor,
- high_precision: Optional[bool] = True,
):
"""
Computes second derivatives given the grid points of a cubic spline.
:param x_points: Abscissas of the splining points for the real-space function
:param y_points: Ordinates of the splining points for the real-space function
- :param high_accuracy: bool, perform calculation in double precision
:return: The second derivatives for the spline points
"""
# Do the calculation in float64 if required
x = x_points
y = y_points
- if high_precision:
- x = x.to(dtype=torch.float64)
- y = y.to(dtype=torch.float64)
# Calculate intervals
intervals = x[1:] - x[:-1]
dy = (y[1:] - y[:-1]) / intervals
- # Create zero boundary conditions (natural spline)
- d2y = torch.zeros_like(x, dtype=torch.float64)
-
n = len(x)
- a = torch.zeros(n) # Sub-diagonal (a[1..n-1])
- b = torch.zeros(n) # Main diagonal (b[0..n-1])
- c = torch.zeros(n) # Super-diagonal (c[0..n-2])
- d = torch.zeros(n) # Right-hand side (d[0..n-1])
+ a = torch.zeros_like(x) # Sub-diagonal (a[1..n-1])
+ b = torch.zeros_like(x) # Main diagonal (b[0..n-1])
+ c = torch.zeros_like(x) # Super-diagonal (c[0..n-2])
+ d = torch.zeros_like(x) # Right-hand side (d[0..n-1])
# Natural spline boundary conditions
b[0] = 1
@@ -195,10 +187,9 @@ def compute_second_derivatives(
c[i] = intervals[i] / 6
d[i] = dy[i] - dy[i - 1]
- d2y = _solve_tridiagonal(a, b, c, d)
+ return _solve_tridiagonal(a, b, c, d)
# Converts back to the original dtype
- return d2y.to(dtype=x_points.dtype, device=x_points.device)
def compute_spline_ft(
@@ -206,7 +197,6 @@ def compute_spline_ft(
x_points: torch.Tensor,
y_points: torch.Tensor,
d2y_points: torch.Tensor,
- high_precision: Optional[bool] = True,
):
r"""
Computes the Fourier transform of a splined radial function.
@@ -228,7 +218,6 @@ def compute_spline_ft(
:param x_points: Abscissas of the splining points for the real-space function
:param y_points: Ordinates of the splining points for the real-space function
:param d2y_points: Second derivatives for the spline points
- :param high_accuracy: bool, perform calculation in double precision
:return: The radial Fourier transform :math:`\hat{f}(k)` computed
at the ``k_points`` provided.
@@ -244,8 +233,6 @@ def compute_spline_ft(
# chooses precision for the FT evaluation
dtype = x_points.dtype
- if high_precision:
- dtype = torch.float64
# broadcast to compute at once on all k values.
# all these are terms that enter the analytical integral.
diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py
index 212f4744..dc67cdc7 100644
--- a/src/torchpme/potentials/combined.py
+++ b/src/torchpme/potentials/combined.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional
import torch
@@ -27,8 +27,6 @@ class CombinedPotential(Potential):
:param exclusion_radius: A length scale that defines a *local environment* within
which the potential should be smoothly zeroed out, as it will be described by a
separate model.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -38,14 +36,10 @@ def __init__(
learnable_weights: Optional[bool] = True,
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
super().__init__(
smearing=smearing,
exclusion_radius=exclusion_radius,
- dtype=dtype,
- device=device,
)
smearings = [pot.smearing for pot in potentials]
@@ -73,9 +67,7 @@ def __init__(
"The number of initial weights must match the number of potentials being combined"
)
else:
- initial_weights = torch.ones(
- len(potentials), dtype=self.dtype, device=self.device
- )
+ initial_weights = torch.ones(len(potentials))
# for torchscript
self.potentials = torch.nn.ModuleList(potentials)
if learnable_weights:
diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py
index 1e35897c..76188839 100644
--- a/src/torchpme/potentials/coulomb.py
+++ b/src/torchpme/potentials/coulomb.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional
import torch
@@ -26,30 +26,14 @@ class CoulombPotential(Potential):
:param exclusion_radius: A length scale that defines a *local environment* within
which the potential should be smoothly zeroed out, as it will be described by a
separate model.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
"""
def __init__(
self,
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
- super().__init__(smearing, exclusion_radius, dtype, device)
-
- # constants used in the forwward
- self.register_buffer(
- "_rsqrt2",
- torch.rsqrt(torch.tensor(2.0, dtype=self.dtype, device=self.device)),
- )
- self.register_buffer(
- "_sqrt_2_on_pi",
- torch.sqrt(
- torch.tensor(2.0 / torch.pi, dtype=self.dtype, device=self.device)
- ),
- )
+ super().__init__(smearing, exclusion_radius)
def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
"""
@@ -75,7 +59,7 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
"Cannot compute long-range contribution without specifying `smearing`."
)
- return torch.erf(dist * (self._rsqrt2 / self.smearing)) / dist
+ return torch.erf(dist / self.smearing / 2.0**0.5) / dist
def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
r"""
@@ -105,7 +89,7 @@ def self_contribution(self) -> torch.Tensor:
raise ValueError(
"Cannot compute self contribution without specifying `smearing`."
)
- return self._sqrt_2_on_pi / self.smearing
+ return (2 / torch.pi) ** 0.5 / self.smearing
def background_correction(self) -> torch.Tensor:
# "charge neutrality" correction for 1/r potential
diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py
index 374ab56e..1e7f945f 100644
--- a/src/torchpme/potentials/inversepowerlaw.py
+++ b/src/torchpme/potentials/inversepowerlaw.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional
import torch
from torch.special import gammainc
@@ -31,8 +31,6 @@ class InversePowerLawPotential(Potential):
:param: exclusion_radius: float or torch.Tensor containing the length scale
corresponding to a local environment. See also
:class:`Potential`.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -40,16 +38,12 @@ def __init__(
exponent: int,
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
- super().__init__(smearing, exclusion_radius, dtype, device)
+ super().__init__(smearing, exclusion_radius)
# function call to check the validity of the exponent
- gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device))
- self.register_buffer(
- "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
- )
+ gammaincc_over_powerlaw(exponent, torch.tensor(1.0))
+ self.register_buffer("exponent", torch.tensor(exponent, dtype=torch.float64))
@torch.jit.export
def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
@@ -84,12 +78,9 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
"Cannot compute long-range contribution without specifying `smearing`."
)
- exponent = self.exponent
- smearing = self.smearing
-
- x = 0.5 * dist**2 / smearing**2
- peff = exponent / 2
- prefac = 1.0 / (2 * smearing**2) ** peff
+ x = 0.5 * dist**2 / self.smearing**2
+ peff = self.exponent / 2
+ prefac = 1.0 / (2 * self.smearing**2) ** peff
return prefac * gammainc(peff, x) / x**peff
@torch.jit.export
@@ -105,12 +96,11 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
"Cannot compute long-range kernel without specifying `smearing`."
)
- exponent = self.exponent
- smearing = self.smearing
-
- peff = (3 - exponent) / 2
- prefac = torch.pi**1.5 / gamma(exponent / 2) * (2 * smearing**2) ** peff
- x = 0.5 * smearing**2 * k_sq
+ peff = (3 - self.exponent) / 2
+ prefac = (
+ torch.pi**1.5 / gamma(self.exponent / 2) * (2 * self.smearing**2) ** peff
+ )
+ x = 0.5 * self.smearing**2 * k_sq
# The k=0 term often needs to be set separately since for exponents p<=3
# dimension, there is a divergence to +infinity. Setting this value manually
@@ -121,7 +111,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
# for consistency reasons.
masked = torch.where(x == 0, 1.0, x) # avoid NaNs in backwards, see Coulomb
return torch.where(
- k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked)
+ k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(self.exponent, masked)
)
def self_contribution(self) -> torch.Tensor:
@@ -142,7 +132,7 @@ def background_correction(self) -> torch.Tensor:
"Cannot compute background correction without specifying `smearing`."
)
if self.exponent >= 3:
- return self.smearing * 0.0
+ return torch.zero_like(self.smearing)
prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2)
prefac /= (3 - self.exponent) * gamma(self.exponent / 2)
return prefac
diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py
index af715505..538acf04 100644
--- a/src/torchpme/potentials/potential.py
+++ b/src/torchpme/potentials/potential.py
@@ -1,9 +1,7 @@
-from typing import Optional, Union
+from typing import Optional
import torch
-from .._utils import _get_device, _get_dtype
-
class Potential(torch.nn.Module):
r"""
@@ -32,35 +30,23 @@ class Potential(torch.nn.Module):
:param exclusion_radius: A length scale that defines a *local environment* within
which the potential should be smoothly zeroed out, as it will be described by a
separate model.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
"""
def __init__(
self,
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
super().__init__()
- self.device = _get_device(device)
- self.dtype = _get_dtype(dtype)
-
if smearing is not None:
self.register_buffer(
- "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
+ "smearing", torch.tensor(smearing, dtype=torch.float64)
)
else:
self.smearing = None
- if exclusion_radius is not None:
- self.register_buffer(
- "exclusion_radius",
- torch.tensor(exclusion_radius, device=self.device, dtype=self.dtype),
- )
- else:
- self.exclusion_radius = None
+
+ self.exclusion_radius = exclusion_radius
@torch.jit.export
def f_cutoff(self, dist: torch.Tensor) -> torch.Tensor:
diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py
index b58d31eb..2419f641 100644
--- a/src/torchpme/potentials/spline.py
+++ b/src/torchpme/potentials/spline.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional
import torch
@@ -42,8 +42,6 @@ class SplinePotential(Potential):
:param exclusion_radius: A length scale that defines a *local environment* within
which the potential should be smoothly zeroed out, as it will be described by a
separate model.
- :param dtype: type used for the internal buffers and parameters
- :param device: device used for the internal buffers and parameters
"""
def __init__(
@@ -57,21 +55,17 @@ def __init__(
yhat_at_zero: Optional[float] = None,
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
super().__init__(
smearing=smearing,
exclusion_radius=exclusion_radius,
- dtype=dtype,
- device=device,
)
if len(y_grid) != len(r_grid):
raise ValueError("Length of radial grid and value array mismatch.")
- r_grid = r_grid.to(dtype=self.dtype, device=self.device)
- y_grid = y_grid.to(dtype=self.dtype, device=self.device)
+ self.register_buffer("r_grid", r_grid)
+ self.register_buffer("y_grid", y_grid)
if reciprocal:
if torch.min(r_grid) <= 0.0:
@@ -87,9 +81,9 @@ def __init__(
if reciprocal:
k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0])
else:
- k_grid = r_grid.clone()
- else:
- k_grid = k_grid.to(dtype=self.dtype, device=self.device)
+ k_grid = r_grid.clone().detach()
+
+ self.register_buffer("k_grid", k_grid)
if yhat_grid is None:
# computes automatically!
@@ -99,8 +93,8 @@ def __init__(
y_grid,
compute_second_derivatives(r_grid, y_grid),
)
- else:
- yhat_grid = yhat_grid.to(dtype=self.dtype, device=self.device)
+
+ self.register_buffer("yhat_grid", yhat_grid)
# the function is defined for k**2, so we define the grid accordingly
if reciprocal:
@@ -112,14 +106,14 @@ def __init__(
if y_at_zero is None:
self._y_at_zero = self._spline(
- torch.zeros(1, dtype=self.dtype, device=self.device)
+ torch.zeros(1, dtype=self.r_grid.dtype, device=self.r_grid.device)
)
else:
self._y_at_zero = y_at_zero
if yhat_at_zero is None:
self._yhat_at_zero = self._krn_spline(
- torch.zeros(1, dtype=self.dtype, device=self.device)
+ torch.zeros(1, dtype=self.k_grid.dtype, device=self.k_grid.device)
)
else:
self._yhat_at_zero = yhat_at_zero
diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py
index b5bb0ae5..b1116245 100644
--- a/src/torchpme/tuning/ewald.py
+++ b/src/torchpme/tuning/ewald.py
@@ -1,5 +1,5 @@
import math
-from typing import Any, Optional, Union
+from typing import Any
from warnings import warn
import torch
@@ -19,8 +19,6 @@ def tune_ewald(
ns_lo: int = 1,
ns_hi: int = 14,
accuracy: float = 1e-3,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.EwaldCalculator`.
@@ -96,8 +94,6 @@ def tune_ewald(
calculator=EwaldCalculator,
error_bounds=EwaldErrorBounds(charges=charges, cell=cell, positions=positions),
params=params,
- dtype=dtype,
- device=device,
)
smearing = tuner.estimate_smearing(accuracy)
errs, timings = tuner.tune(accuracy)
diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py
index 6a64230a..7ad6e4db 100644
--- a/src/torchpme/tuning/p3m.py
+++ b/src/torchpme/tuning/p3m.py
@@ -1,6 +1,6 @@
import math
from itertools import product
-from typing import Any, Optional, Union
+from typing import Any
from warnings import warn
import torch
@@ -79,8 +79,6 @@ def tune_p3m(
mesh_lo: int = 2,
mesh_hi: int = 7,
accuracy: float = 1e-3,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`.
@@ -169,8 +167,6 @@ def tune_p3m(
calculator=P3MCalculator,
error_bounds=P3MErrorBounds(charges=charges, cell=cell, positions=positions),
params=params,
- dtype=dtype,
- device=device,
)
smearing = tuner.estimate_smearing(accuracy)
errs, timings = tuner.tune(accuracy)
diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py
index 540f3d40..68d80f0b 100644
--- a/src/torchpme/tuning/pme.py
+++ b/src/torchpme/tuning/pme.py
@@ -1,6 +1,6 @@
import math
from itertools import product
-from typing import Any, Optional, Union
+from typing import Any
from warnings import warn
import torch
@@ -22,8 +22,6 @@ def tune_pme(
mesh_lo: int = 2,
mesh_hi: int = 7,
accuracy: float = 1e-3,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
) -> tuple[float, dict[str, Any], float]:
r"""
Find the optimal parameters for :class:`torchpme.PMECalculator`.
@@ -112,8 +110,6 @@ def tune_pme(
calculator=PMECalculator,
error_bounds=PMEErrorBounds(charges=charges, cell=cell, positions=positions),
params=params,
- dtype=dtype,
- device=device,
)
smearing = tuner.estimate_smearing(accuracy)
errs, timings = tuner.tune(accuracy)
diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py
index 546b3995..3887c188 100644
--- a/src/torchpme/tuning/tuner.py
+++ b/src/torchpme/tuning/tuner.py
@@ -1,10 +1,10 @@
import math
import time
-from typing import Optional, Union
+from typing import Optional
import torch
-from .._utils import _get_device, _get_dtype, _validate_parameters
+from .._utils import _validate_parameters
from ..calculators import Calculator
from ..potentials import InversePowerLawPotential
@@ -83,17 +83,12 @@ def __init__(
cutoff: float,
calculator: type[Calculator],
exponent: int = 1,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
if exponent != 1:
raise NotImplementedError(
f"Only exponent = 1 is supported but got {exponent}."
)
- self.device = _get_device(device)
- self.dtype = _get_dtype(dtype)
-
_validate_parameters(
charges=charges,
cell=cell,
@@ -103,8 +98,6 @@ def __init__(
[1.0], device=positions.device, dtype=positions.dtype
),
smearing=1.0, # dummy value because; always have range-seperated potentials
- dtype=self.dtype,
- device=self.device,
)
self.charges = charges
self.cell = cell
@@ -189,8 +182,6 @@ def __init__(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
exponent: int = 1,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
super().__init__(
charges=charges,
@@ -199,8 +190,6 @@ def __init__(
cutoff=cutoff,
calculator=calculator,
exponent=exponent,
- dtype=dtype,
- device=device,
)
self.error_bounds = error_bounds
self.params = params
@@ -211,8 +200,6 @@ def __init__(
neighbor_indices,
neighbor_distances,
True,
- dtype=dtype,
- device=device,
)
def tune(self, accuracy: float = 1e-3) -> tuple[list[float], list[float]]:
@@ -244,14 +231,10 @@ def _timing(self, smearing: float, k_space_params: dict):
potential=InversePowerLawPotential(
exponent=self.exponent, # but only exponent = 1 is supported
smearing=smearing,
- device=self.device,
- dtype=self.dtype,
),
- device=self.device,
- dtype=self.dtype,
**k_space_params,
)
-
+ calculator.to(device=self.positions.device, dtype=self.positions.dtype)
return self.time_func(calculator)
@@ -289,14 +272,9 @@ def __init__(
n_repeat: int = 4,
n_warmup: int = 4,
run_backward: Optional[bool] = True,
- dtype: Optional[torch.dtype] = None,
- device: Union[None, str, torch.device] = None,
):
super().__init__()
- self.device = _get_device(device)
- self.dtype = _get_dtype(dtype)
-
_validate_parameters(
charges=charges,
cell=cell,
@@ -304,8 +282,6 @@ def __init__(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=1.0, # dummy value because; always have range-seperated potentials
- device=self.device,
- dtype=self.dtype,
)
self.charges = charges
@@ -351,8 +327,6 @@ def forward(self, calculator: torch.nn.Module):
if self.run_backward:
value.backward(retain_graph=True)
- if self.device is torch.device("cuda"):
- torch.cuda.synchronize()
execution_time += time.monotonic()
return execution_time / self.n_repeat
diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py
index dff72455..b5e8bfe8 100644
--- a/tests/calculators/test_calculator.py
+++ b/tests/calculators/test_calculator.py
@@ -43,35 +43,6 @@ def test_compute_output_shapes():
assert result.shape == charges.shape
-def test_wrong_device_positions():
- calculator = CalculatorTest()
- match = r"device of `positions` \(meta\) must be same as the class device \(cpu\)"
- with pytest.raises(ValueError, match=match):
- calculator.forward(
- positions=POSITIONS_1.to(device="meta"),
- charges=CHARGES_1,
- cell=CELL_1,
- neighbor_indices=NEIGHBOR_INDICES,
- neighbor_distances=NEIGHBOR_DISTANCES,
- )
-
-
-def test_wrong_dtype_positions():
- calculator = CalculatorTest()
- match = (
- r"type of `positions` \(torch.float64\) must be same as the class type "
- r"\(torch.float32\)"
- )
- with pytest.raises(TypeError, match=match):
- calculator.forward(
- positions=POSITIONS_1.to(dtype=torch.float64),
- charges=CHARGES_1,
- cell=CELL_1,
- neighbor_indices=NEIGHBOR_INDICES,
- neighbor_distances=NEIGHBOR_DISTANCES,
- )
-
-
# Tests for invalid shape, dtype and device of positions
def test_invalid_shape_positions():
calculator = CalculatorTest()
@@ -107,9 +78,7 @@ def test_invalid_shape_cell():
def test_invalid_dtype_cell():
calculator = CalculatorTest()
- match = (
- r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)"
- )
+ match = r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)"
with pytest.raises(TypeError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -122,7 +91,7 @@ def test_invalid_dtype_cell():
def test_invalid_device_cell():
calculator = CalculatorTest()
- match = r"device of `cell` \(meta\) must be same as the class \(cpu\)"
+ match = r"device of `cell` \(meta\) must be same as that of the `positions` class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -188,7 +157,7 @@ def test_invalid_shape_charges():
def test_invalid_dtype_charges():
calculator = CalculatorTest()
match = (
- r"type of `charges` \(torch.float64\) must be same as the class "
+ r"type of `charges` \(torch.float64\) must be same as that of the `positions` class "
r"\(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
@@ -203,7 +172,7 @@ def test_invalid_dtype_charges():
def test_invalid_device_charges():
calculator = CalculatorTest()
- match = r"device of `charges` \(meta\) must be same as the class \(cpu\)"
+ match = r"device of `charges` \(meta\) must be same as that of the `positions` class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -248,7 +217,7 @@ def test_invalid_shape_neighbor_indices_neighbor_distances():
def test_invalid_device_neighbor_indices():
calculator = CalculatorTest()
- match = r"device of `neighbor_indices` \(meta\) must be same as the class \(cpu\)"
+ match = r"device of `neighbor_indices` \(meta\) must be same as that of the `positions` class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -261,7 +230,7 @@ def test_invalid_device_neighbor_indices():
def test_invalid_device_neighbor_distances():
calculator = CalculatorTest()
- match = r"device of `neighbor_distances` \(meta\) must be same as the class \(cpu\)"
+ match = r"device of `neighbor_distances` \(meta\) must be same as that of the `positions` class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
@@ -276,7 +245,7 @@ def test_invalid_dtype_neighbor_distances():
calculator = CalculatorTest()
match = (
r"type of `neighbor_distances` \(torch.float64\) must be same "
- r"as the class \(torch.float32\)"
+ r"as that of the `positions` class \(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
calculator.forward(
diff --git a/tests/calculators/test_values_direct.py b/tests/calculators/test_values_direct.py
index a5ace407..867137c0 100644
--- a/tests/calculators/test_values_direct.py
+++ b/tests/calculators/test_values_direct.py
@@ -18,10 +18,10 @@ class CalculatorTest(Calculator):
def __init__(self, **kwargs):
super().__init__(
potential=CoulombPotential(
- smearing=None, exclusion_radius=None, dtype=DTYPE
+ smearing=None,
+ exclusion_radius=None,
),
**kwargs,
- dtype=DTYPE,
)
diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py
index edcb886e..742ee2d9 100644
--- a/tests/calculators/test_values_ewald.py
+++ b/tests/calculators/test_values_ewald.py
@@ -101,9 +101,8 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
smearing = sr_cutoff / 5.0
lr_wavelength = 0.5 * smearing
calc = EwaldCalculator(
- InversePowerLawPotential(exponent=1, smearing=smearing, dtype=DTYPE),
+ InversePowerLawPotential(exponent=1, smearing=smearing),
lr_wavelength=lr_wavelength,
- dtype=DTYPE,
)
rtol = 4e-6
elif calc_name == "pme":
@@ -113,19 +112,16 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
InversePowerLawPotential(
exponent=1,
smearing=smearing,
- dtype=DTYPE,
),
mesh_spacing=smearing / 8,
- dtype=DTYPE,
)
rtol = 9e-4
elif calc_name == "p3m":
sr_cutoff = 2 * scaling_factor
smearing = sr_cutoff / 5.0
calc = P3MCalculator(
- CoulombPotential(smearing=smearing, dtype=DTYPE),
+ CoulombPotential(smearing=smearing),
mesh_spacing=smearing / 8,
- dtype=DTYPE,
)
rtol = 9e-4
@@ -133,7 +129,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
neighbor_indices, neighbor_distances = neighbor_list(
positions=pos, periodic=True, box=cell, cutoff=sr_cutoff
)
-
+ calc.to(DTYPE)
# Compute potential and compare against target value using default hypers
potentials = calc.forward(
positions=pos,
@@ -200,10 +196,10 @@ def test_wigner(crystal_name, scaling_factor):
# Compute potential and compare against reference
calc = EwaldCalculator(
- InversePowerLawPotential(exponent=1, smearing=smeareff, dtype=DTYPE),
+ InversePowerLawPotential(exponent=1, smearing=smeareff),
lr_wavelength=smeareff / 2,
- dtype=DTYPE,
)
+ calc.to(DTYPE)
potentials = calc.forward(
positions=positions,
charges=charges,
@@ -251,28 +247,25 @@ def test_random_structure(
if calc_name == "ewald":
calc = EwaldCalculator(
- CoulombPotential(smearing=smearing, dtype=DTYPE),
+ CoulombPotential(smearing=smearing),
lr_wavelength=0.5 * smearing,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
- dtype=DTYPE,
)
elif calc_name == "pme":
calc = PMECalculator(
- CoulombPotential(smearing=smearing, dtype=DTYPE),
+ CoulombPotential(smearing=smearing),
mesh_spacing=smearing / 8.0,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
- dtype=DTYPE,
)
elif calc_name == "p3m":
calc = P3MCalculator(
- CoulombPotential(smearing=smearing, dtype=DTYPE),
+ CoulombPotential(smearing=smearing),
mesh_spacing=smearing / 8.0,
full_neighbor_list=full_neighbor_list,
prefactor=torchpme.prefactors.eV_A,
- dtype=DTYPE,
)
neighbor_indices, neighbor_shifts = neighbor_list(
@@ -295,6 +288,7 @@ def test_random_structure(
neighbor_shifts=neighbor_shifts,
)
+ calc.to(DTYPE)
potentials = calc.forward(
positions=positions,
charges=charges,
diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py
index 3b62ae3b..7d2e19ee 100644
--- a/tests/calculators/test_workflow.py
+++ b/tests/calculators/test_workflow.py
@@ -15,7 +15,6 @@
P3MCalculator,
PMECalculator,
)
-from torchpme._utils import _get_device, _get_dtype
DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"]
DTYPES = [torch.float32, torch.float64]
@@ -32,35 +31,27 @@
(
Calculator,
{
- "potential": lambda dtype, device: CoulombPotential(
- smearing=None, dtype=dtype, device=device
- ),
+ "potential": CoulombPotential(smearing=None),
},
),
(
EwaldCalculator,
{
- "potential": lambda dtype, device: CoulombPotential(
- smearing=SMEARING, dtype=dtype, device=device
- ),
+ "potential": CoulombPotential(smearing=SMEARING),
"lr_wavelength": LR_WAVELENGTH,
},
),
(
PMECalculator,
{
- "potential": lambda dtype, device: CoulombPotential(
- smearing=SMEARING, dtype=dtype, device=device
- ),
+ "potential": CoulombPotential(smearing=SMEARING),
"mesh_spacing": MESH_SPACING,
},
),
(
P3MCalculator,
{
- "potential": lambda dtype, device: CoulombPotential(
- smearing=SMEARING, dtype=dtype, device=device
- ),
+ "potential": CoulombPotential(smearing=SMEARING),
"mesh_spacing": MESH_SPACING,
},
),
@@ -69,9 +60,6 @@
class TestWorkflow:
def cscl_system(self, device=None, dtype=None):
"""CsCl crystal. Same as in the madelung test"""
- device = _get_device(device)
- dtype = _get_dtype(dtype)
-
positions = torch.tensor(
[[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device
)
@@ -83,45 +71,37 @@ def cscl_system(self, device=None, dtype=None):
return charges, cell, positions, neighbor_indices, neighbor_distances
def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
match = r"`smearing` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator, PMECalculator]:
params["smearing"] = 0
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device, dtype=dtype)
+ CalculatorClass(**params)
params["smearing"] = -0.1
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device, dtype=dtype)
+ CalculatorClass(**params)
def test_interpolation_order_error(self, CalculatorClass, params, device, dtype):
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
if type(CalculatorClass) in [PMECalculator]:
match = "Only `interpolation_nodes` from 1 to 5"
params["interpolation_nodes"] = 10
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device, dtype=dtype)
+ CalculatorClass(**params)
def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype):
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
match = r"`lr_wavelength` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator]:
params["lr_wavelength"] = 0
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device, dtype=dtype)
+ CalculatorClass(**params)
params["lr_wavelength"] = -0.1
with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, device=device, dtype=dtype)
+ CalculatorClass(**params)
def test_dtype_device(self, CalculatorClass, params, device, dtype):
"""Test that the output dtype and device are the same as the input."""
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
-
- calculator = CalculatorClass(**params, device=device, dtype=dtype)
- potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype))
+ calculator = CalculatorClass(**params)
+ calculator.to(device=device, dtype=dtype)
+ potential = calculator(*self.cscl_system(device=device, dtype=dtype))
assert potential.dtype == dtype
@@ -137,26 +117,21 @@ def check_operation(self, calculator, device, dtype):
def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
-
- calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ calculator = CalculatorClass(**params)
+ calculator.to(device=device, dtype=dtype)
self.check_operation(calculator=calculator, device=device, dtype=dtype)
def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
-
- calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ calculator = CalculatorClass(**params)
+ calculator.to(device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
self.check_operation(calculator=scripted, device=device, dtype=dtype)
def test_save_load(self, CalculatorClass, params, device, dtype):
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
-
- calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ """Test if the calculator can be saved and loaded."""
+ calculator = CalculatorClass(**params)
+ calculator.to(device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
with io.BytesIO() as buffer:
torch.jit.save(scripted, buffer)
@@ -165,15 +140,11 @@ def test_save_load(self, CalculatorClass, params, device, dtype):
def test_prefactor(self, CalculatorClass, params, device, dtype):
"""Test if the prefactor is applied correctly."""
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
-
prefactor = 2.0
- calculator1 = CalculatorClass(**params, device=device, dtype=dtype)
- calculator2 = CalculatorClass(
- **params, prefactor=prefactor, device=device, dtype=dtype
- )
-
+ calculator1 = CalculatorClass(**params)
+ calculator2 = CalculatorClass(**params, prefactor=prefactor)
+ calculator1.to(device=device, dtype=dtype)
+ calculator2.to(device=device, dtype=dtype)
potentials1 = calculator1.forward(*self.cscl_system(device=device, dtype=dtype))
potentials2 = calculator2.forward(*self.cscl_system(device=device, dtype=dtype))
@@ -181,10 +152,8 @@ def test_prefactor(self, CalculatorClass, params, device, dtype):
def test_not_nan(self, CalculatorClass, params, device, dtype):
"""Make sure derivatives are not NaN."""
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
-
- calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ calculator = CalculatorClass(**params)
+ calculator.to(device=device, dtype=dtype)
system = self.cscl_system(device=device, dtype=dtype)
system[0].requires_grad = True
system[1].requires_grad = True
@@ -211,29 +180,6 @@ def test_not_nan(self, CalculatorClass, params, device, dtype):
torch.autograd.grad(energy, system[2], retain_graph=True)[0]
).any()
- def test_dtype_and_device_incompatability(
- self, CalculatorClass, params, device, dtype
- ):
- """Test that the calculator raises an error if the dtype and device are incompatible with potential."""
- params = params.copy()
-
- other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
- params["potential"] = params["potential"](dtype, device)
-
- match = (
- rf"dtype of `potential` \({params['potential'].dtype}\) must be same as "
- rf"of `calculator` \({other_dtype}\)"
- )
- with pytest.raises(TypeError, match=match):
- CalculatorClass(**params, dtype=other_dtype, device=device)
-
- match = (
- rf"device of `potential` \({params['potential'].device}\) must be same as "
- rf"of `calculator` \(meta\)"
- )
- with pytest.raises(ValueError, match=match):
- CalculatorClass(**params, dtype=dtype, device=torch.device("meta"))
-
def test_potential_and_calculator_incompatability(
self,
CalculatorClass,
@@ -242,34 +188,17 @@ def test_potential_and_calculator_incompatability(
dtype,
):
"""Test that the calculator raises an error if the potential and calculator are incompatible."""
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
-
params["potential"] = torch.jit.script(params["potential"])
with pytest.raises(
TypeError, match="Potential must be an instance of Potential, got.*"
):
- CalculatorClass(**params, device=device, dtype=dtype)
+ CalculatorClass(**params)
- def test_device_string_compatability(self, CalculatorClass, params, dtype, device):
- """Test that the calculator works with device strings."""
- params = params.copy()
- params["potential"] = params["potential"](dtype, "cpu")
- calculator = CalculatorClass(
- **params,
- device=torch.device("cpu"),
- dtype=dtype,
- )
-
- assert calculator.device == params["potential"].device
-
- def test_device_index_compatability(self, CalculatorClass, params, dtype, device):
- """Test that the calculator works with no index on the device."""
- if torch.cuda.is_available():
- params = params.copy()
- params["potential"] = params["potential"](dtype, "cuda")
- calculator = CalculatorClass(
- **params, device=torch.device("cuda:0"), dtype=dtype
- )
-
- assert calculator.device == params["potential"].device
+ def test_smearing_incompatability(self, CalculatorClass, params, device, dtype):
+ """Test that the calculator raises an error if the potential and calculator are incompatible."""
+ if type(CalculatorClass) in [EwaldCalculator, PMECalculator, P3MCalculator]:
+ params["smearing"] = None
+ with pytest.raises(
+ TypeError, match="Must specify smearing to use a potential with .*"
+ ):
+ CalculatorClass(**params)
diff --git a/tests/helpers.py b/tests/helpers.py
index 4682b906..f553bff6 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -7,8 +7,6 @@
import torch
from vesin import NeighborList
-from torchpme._utils import _get_device, _get_dtype
-
SQRT3 = math.sqrt(3)
DIR_PATH = Path(__file__).parent
@@ -17,9 +15,6 @@
def define_crystal(crystal_name="CsCl", dtype=None, device=None):
- device = _get_device(device)
- dtype = _get_dtype(dtype)
-
# Define all relevant parameters (atom positions, charges, cell) of the reference
# crystal structures for which the Madelung constants obtained from the Ewald sums
# are compared with reference values.
diff --git a/tests/lib/test_splines.py b/tests/lib/test_splines.py
index e8576188..f5595e3e 100644
--- a/tests/lib/test_splines.py
+++ b/tests/lib/test_splines.py
@@ -59,8 +59,12 @@ def test_inverse_spline(function):
@pytest.mark.parametrize("high_accuracy", [True, False])
def test_ft_accuracy(high_accuracy):
- x_grid = torch.linspace(0, 20, 2000, dtype=torch.float32)
- y_grid = torch.exp(-(x_grid**2) * 0.5)
+ if high_accuracy:
+ x_grid = torch.linspace(0, 20, 2000, dtype=torch.float64)
+ y_grid = torch.exp(-(x_grid**2) * 0.5)
+ else:
+ x_grid = torch.linspace(0, 20, 2000, dtype=torch.float32)
+ y_grid = torch.exp(-(x_grid**2) * 0.5)
k_grid = torch.linspace(0, 20, 20, dtype=torch.float32)
krn = compute_spline_ft(
@@ -68,9 +72,9 @@ def test_ft_accuracy(high_accuracy):
x_points=x_grid,
y_points=y_grid,
d2y_points=compute_second_derivatives(
- x_points=x_grid, y_points=y_grid, high_precision=high_accuracy
+ x_points=x_grid,
+ y_points=y_grid,
),
- high_precision=high_accuracy,
)
krn_ref = torch.exp(-(k_grid**2) * 0.5) * (2 * torch.pi) ** (3 / 2)
diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py
index 141977d2..c8f9681f 100644
--- a/tests/metatensor/test_workflow_metatensor.py
+++ b/tests/metatensor/test_workflow_metatensor.py
@@ -7,7 +7,6 @@
from packaging import version
import torchpme
-from torchpme._utils import _get_device, _get_dtype
mts_torch = pytest.importorskip("metatensor.torch")
mts_atomistic = pytest.importorskip("metatensor.torch.atomistic")
@@ -27,35 +26,27 @@
(
torchpme.metatensor.Calculator,
{
- "potential": lambda dtype, device: torchpme.CoulombPotential(
- smearing=None, dtype=dtype, device=device
- ),
+ "potential": torchpme.CoulombPotential(smearing=None),
},
),
(
torchpme.metatensor.EwaldCalculator,
{
- "potential": lambda dtype, device: torchpme.CoulombPotential(
- smearing=SMEARING, dtype=dtype, device=device
- ),
+ "potential": torchpme.CoulombPotential(smearing=SMEARING),
"lr_wavelength": LR_WAVELENGTH,
},
),
(
torchpme.metatensor.PMECalculator,
{
- "potential": lambda dtype, device: torchpme.CoulombPotential(
- smearing=SMEARING, dtype=dtype, device=device
- ),
+ "potential": torchpme.CoulombPotential(smearing=SMEARING),
"mesh_spacing": MESH_SPACING,
},
),
(
torchpme.metatensor.P3MCalculator,
{
- "potential": lambda dtype, device: torchpme.CoulombPotential(
- smearing=SMEARING, dtype=dtype, device=device
- ),
+ "potential": torchpme.CoulombPotential(smearing=SMEARING),
"mesh_spacing": MESH_SPACING,
},
),
@@ -63,9 +54,6 @@
)
class TestWorkflow:
def system(self, device=None, dtype=None):
- device = _get_device(device)
- dtype = _get_dtype(dtype)
-
system = mts_atomistic.System(
types=torch.tensor([1, 2, 2]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.2], [0.0, 0.0, 0.5]]),
@@ -118,24 +106,21 @@ def check_operation(self, calculator, device, dtype):
def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
- calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ calculator = CalculatorClass(**params)
+ calculator.to(device=device, dtype=dtype)
self.check_operation(calculator=calculator, device=device, dtype=dtype)
def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
- calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ calculator = CalculatorClass(**params)
+ calculator.to(device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
self.check_operation(calculator=scripted, device=device, dtype=dtype)
def test_save_load(self, CalculatorClass, params, device, dtype):
- params = params.copy()
- params["potential"] = params["potential"](dtype, device)
-
- calculator = CalculatorClass(**params, device=device, dtype=dtype)
+ """Save and load a compiled torch script module."""
+ calculator = CalculatorClass(**params)
+ calculator.to(device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
with io.BytesIO() as buffer:
torch.jit.save(scripted, buffer)
diff --git a/tests/test_potentials.py b/tests/test_potentials.py
index 08793749..54398aa5 100644
--- a/tests/test_potentials.py
+++ b/tests/test_potentials.py
@@ -67,8 +67,8 @@ def test_sr_lr_split(exponent, smearing):
potential.
"""
# Compute diverse potentials for this inverse power law
- ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype)
-
+ ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing)
+ ipl.to(dtype=dtype)
potential_from_dist = ipl.from_dist(dists)
potential_sr_from_dist = ipl.sr_from_dist(dists)
potential_lr_from_dist = ipl.lr_from_dist(dists)
@@ -96,10 +96,9 @@ def test_exact_sr(exponent, smearing):
"""
# Compute SR part of Coulomb potential using the potentials class working for any
# exponent
- ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype)
-
+ ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing)
+ ipl.to(dtype=dtype)
potential_sr_from_dist = ipl.sr_from_dist(dists)
-
# Compute exact analytical expression obtained for relevant exponents
potential_1 = erfc(dists / SQRT2 / smearing) / dists
potential_2 = torch.exp(-0.5 * dists_sq / smearing**2) / dists_sq
@@ -110,7 +109,6 @@ def test_exact_sr(exponent, smearing):
elif exponent == 3:
prefac = SQRT2 / torch.sqrt(PI) / smearing
potential_exact = potential_1 / dists_sq + prefac * potential_2
-
# Compare results. Large tolerance due to singular division
rtol = 1e2 * machine_epsilon
atol = 4e-15
@@ -130,7 +128,8 @@ def test_exact_lr(exponent, smearing):
"""
# Compute LR part of Coulomb potential using the potentials class working for any
# exponent
- ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype)
+ ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing)
+ ipl.to(dtype=dtype)
potential_lr_from_dist = ipl.lr_from_dist(dists)
@@ -164,7 +163,8 @@ def test_exact_fourier(exponent, smearing):
"""
# Compute LR part of Coulomb potential using the potentials class working for any
# exponent
- ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype)
+ ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing)
+ ipl.to(dtype=dtype)
fourier_from_class = ipl.lr_from_k_sq(ks_sq)
@@ -202,7 +202,8 @@ def test_lr_value_at_zero(exponent, smearing):
"""
# Get atomic density at tiny distance
dist_small = torch.tensor(1e-8, dtype=dtype)
- ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype)
+ ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing)
+ ipl.to(dtype=dtype)
potential_close_to_zero = ipl.lr_from_dist(dist_small)
@@ -279,7 +280,8 @@ class NoImplPotential(Potential):
@pytest.mark.parametrize("exclusion_radius", [0.5, 1.0, 2.0])
def test_f_cutoff(exclusion_radius):
- coul = CoulombPotential(exclusion_radius=exclusion_radius, dtype=dtype)
+ coul = CoulombPotential(exclusion_radius=exclusion_radius)
+ coul.to(dtype=dtype)
dist = torch.tensor([0.3])
fcut = coul.f_cutoff(dist)
@@ -294,8 +296,10 @@ def test_inverserp_coulomb(smearing):
"""
# Compute LR part of Coulomb potential using the potentials class working for any
# exponent
- ipl = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype)
- coul = CoulombPotential(smearing=smearing, dtype=dtype)
+ ipl = InversePowerLawPotential(exponent=1, smearing=smearing)
+ ipl.to(dtype=dtype)
+ coul = CoulombPotential(smearing=smearing)
+ coul.to(dtype=dtype)
ipl_from_dist = ipl.from_dist(dists)
ipl_sr_from_dist = ipl.sr_from_dist(dists)
@@ -371,16 +375,17 @@ def test_spline_potential_vs_coulomb():
# the approximation is not super-accurate
coulomb = CoulombPotential(smearing=1.0)
- x_grid = torch.logspace(-3.0, 3.0, 1000)
+ coulomb.to(dtype=dtype)
+ x_grid = torch.logspace(-3.0, 3.0, 1000, dtype=dtype)
y_grid = coulomb.lr_from_dist(x_grid)
spline = SplinePotential(r_grid=x_grid, y_grid=y_grid, reciprocal=True)
- t_grid = torch.logspace(-torch.pi / 2, torch.pi / 2, 100)
+ t_grid = torch.logspace(-torch.pi / 2, torch.pi / 2, 100, dtype=dtype)
z_coul = coulomb.lr_from_dist(t_grid)
z_spline = spline.lr_from_dist(t_grid)
assert_close(z_coul, z_spline, atol=5e-5, rtol=0)
- k_grid2 = torch.logspace(-2, 1, 40)
+ k_grid2 = torch.logspace(-2, 1, 40, dtype=dtype)
krn_coul = coulomb.kernel_from_k_sq(k_grid2)
krn_spline = spline.kernel_from_k_sq(k_grid2)
@@ -439,8 +444,8 @@ def forward(self, x: torch.Tensor):
@pytest.mark.parametrize("smearing", smearinges)
def test_combined_potential(smearing):
- ipl_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype)
- ipl_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype)
+ ipl_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
+ ipl_2 = InversePowerLawPotential(exponent=2, smearing=smearing)
ipl_1_from_dist = ipl_1.from_dist(dists)
ipl_1_sr_from_dist = ipl_1.sr_from_dist(dists)
@@ -461,7 +466,6 @@ def test_combined_potential(smearing):
potentials=[ipl_1, ipl_2],
initial_weights=weights,
learnable_weights=False,
- dtype=dtype,
smearing=1.0,
)
combined_from_dist = combined.from_dist(dists)
@@ -516,53 +520,53 @@ def test_combined_potential(smearing):
def test_combined_potentials_jit(smearing):
# make a separate test as pytest.mark.parametrize does not work with
# torch.jit.script for combined potentials
- coulomb = CoulombPotential(smearing=smearing, dtype=dtype)
+ coulomb = CoulombPotential(smearing=smearing)
+ coulomb.to(dtype=dtype)
x_grid = torch.logspace(-2, 2, 100, dtype=dtype)
y_grid = coulomb.lr_from_dist(x_grid)
# create a spline potential
spline = SplinePotential(
- r_grid=x_grid, y_grid=y_grid, reciprocal=True, dtype=dtype, smearing=1.0
+ r_grid=x_grid, y_grid=y_grid, reciprocal=True, smearing=1.0
)
-
- combo = CombinedPotential(potentials=[spline, coulomb], dtype=dtype, smearing=1.0)
- mypme = PMECalculator(combo, mesh_spacing=1.0, dtype=dtype)
+ spline.to(dtype=dtype)
+ combo = CombinedPotential(potentials=[spline, coulomb], smearing=1.0)
+ combo.to(dtype=dtype)
+ mypme = PMECalculator(combo, mesh_spacing=1.0)
_ = torch.jit.script(mypme)
def test_combined_potential_incompatability():
- coulomb1 = CoulombPotential(smearing=1.0, dtype=dtype)
- coulomb2 = CoulombPotential(dtype=dtype)
+ coulomb1 = CoulombPotential(smearing=1.0)
+ coulomb2 = CoulombPotential()
with pytest.raises(
ValueError,
match="Cannot combine direct \\(`smearing=None`\\) and range-separated \\(`smearing=float`\\) potentials.",
):
- _ = CombinedPotential(potentials=[coulomb1, coulomb2], dtype=dtype)
+ _ = CombinedPotential(potentials=[coulomb1, coulomb2])
with pytest.raises(
ValueError,
match="You should specify a `smearing` when combining range-separated \\(`smearing=float`\\) potentials.",
):
- _ = CombinedPotential(potentials=[coulomb1, coulomb1], dtype=dtype)
+ _ = CombinedPotential(potentials=[coulomb1, coulomb1])
with pytest.raises(
ValueError,
match="Cannot specify `smearing` when combining direct \\(`smearing=None`\\) potentials.",
):
- _ = CombinedPotential(
- potentials=[coulomb2, coulomb2], smearing=1.0, dtype=dtype
- )
+ _ = CombinedPotential(potentials=[coulomb2, coulomb2], smearing=1.0)
def test_combined_potential_learnable_weights():
weights = torch.randn(2, dtype=dtype)
- coulomb1 = CoulombPotential(smearing=2.0, dtype=dtype)
- coulomb2 = CoulombPotential(smearing=1.0, dtype=dtype)
+ coulomb1 = CoulombPotential(smearing=2.0)
+ coulomb2 = CoulombPotential(smearing=1.0)
combined = CombinedPotential(
potentials=[coulomb1, coulomb2],
smearing=1.0,
- dtype=dtype,
initial_weights=weights.clone(),
learnable_weights=True,
)
+ combined.to(dtype=dtype)
assert combined.weights.requires_grad
# make a small optimization step
@@ -587,17 +591,16 @@ def test_potential_device_dtype(potential_class, device, dtype):
exponent = 2
if potential_class is InversePowerLawPotential:
- potential = potential_class(
- exponent=exponent, smearing=smearing, dtype=dtype, device=device
- )
+ potential = potential_class(exponent=exponent, smearing=smearing)
+ potential.to(device=device, dtype=dtype)
elif potential_class is SplinePotential:
x_grid = torch.linspace(0, 20, 100, device=device, dtype=dtype)
y_grid = torch.exp(-(x_grid**2) * 0.5)
- potential = potential_class(
- r_grid=x_grid, y_grid=y_grid, reciprocal=False, dtype=dtype, device=device
- )
+ potential = potential_class(r_grid=x_grid, y_grid=y_grid, reciprocal=False)
+ potential.to(device=device, dtype=dtype)
else:
- potential = potential_class(smearing=smearing, dtype=dtype, device=device)
+ potential = potential_class(smearing=smearing)
+ potential.to(device=device, dtype=dtype)
dists = torch.linspace(0.1, 10.0, 100, device=device, dtype=dtype)
potential_lr = potential.lr_from_dist(dists)
@@ -616,13 +619,15 @@ def test_inverserp_vs_spline(exponent, smearing):
ks_sq_grad1 = ks_sq.clone().requires_grad_(True)
ks_sq_grad2 = ks_sq.clone().requires_grad_(True)
# Create InversePowerLawPotential
- ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype)
+ ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing)
+ ipl.to(dtype=dtype)
ipl_fourier = ipl.lr_from_k_sq(ks_sq_grad1)
# Create PotentialSpline
r_grid = torch.logspace(-5, 2, 1000, dtype=dtype)
y_grid = ipl.lr_from_dist(r_grid)
- spline = SplinePotential(r_grid=r_grid, y_grid=y_grid, dtype=dtype)
+ spline = SplinePotential(r_grid=r_grid, y_grid=y_grid)
+ spline.to(dtype=dtype)
spline_fourier = spline.lr_from_k_sq(ks_sq_grad2)
# Test agreement between InversePowerLawPotential and SplinePotential
diff --git a/tests/tuning/test_timer.py b/tests/tuning/test_timer.py
index 4481d2fa..2100e877 100644
--- a/tests/tuning/test_timer.py
+++ b/tests/tuning/test_timer.py
@@ -41,7 +41,6 @@ def test_timer():
calculator = EwaldCalculator(
potential=CoulombPotential(smearing=1.0),
lr_wavelength=0.25,
- dtype=DTYPE,
)
timing_1 = TuningTimings(
@@ -50,7 +49,6 @@ def test_timer():
positions=pos,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
- dtype=DTYPE,
n_repeat=n_repeat_1,
)
@@ -60,7 +58,6 @@ def test_timer():
positions=pos,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
- dtype=DTYPE,
n_repeat=n_repeat_2,
)
diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py
index cabea6d9..07cce237 100644
--- a/tests/tuning/test_tuning.py
+++ b/tests/tuning/test_tuning.py
@@ -10,7 +10,6 @@
P3MCalculator,
PMECalculator,
)
-from torchpme._utils import _get_device, _get_dtype
from torchpme.tuning import tune_ewald, tune_p3m, tune_pme
from torchpme.tuning.tuner import TunerBase
@@ -23,9 +22,6 @@
def system(device=None, dtype=None):
- device = _get_device(device)
- dtype = _get_dtype(dtype)
-
charges = torch.ones((4, 1), dtype=dtype, device=device)
cell = torch.eye(3, dtype=dtype, device=device)
positions = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3))
@@ -50,8 +46,6 @@ def test_TunerBase_init(device, dtype):
cutoff=DEFAULT_CUTOFF,
calculator=1.0,
exponent=1,
- dtype=dtype,
- device=device,
)
@@ -89,19 +83,16 @@ def test_parameter_choose(device, dtype, calculator, tune, param_length, accurac
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
accuracy=accuracy,
- dtype=dtype,
- device=device,
)
assert len(params) == param_length
# Compute potential and compare against target value using default hypers
calc = calculator(
- potential=(CoulombPotential(smearing=smearing, dtype=dtype, device=device)),
- dtype=dtype,
- device=device,
+ potential=(CoulombPotential(smearing=smearing)),
**params,
)
+ calc.to(device=device, dtype=dtype)
potentials = calc.forward(
positions=pos,
charges=charges,
@@ -212,9 +203,7 @@ def test_invalid_cell(tune):
@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m])
def test_invalid_dtype_cell(tune):
charges, _, positions = system()
- match = (
- r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)"
- )
+ match = r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)"
with pytest.raises(TypeError, match=match):
tune(
charges=charges,