diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst index ee17fdf4..4d312f14 100644 --- a/docs/src/references/changelog.rst +++ b/docs/src/references/changelog.rst @@ -27,15 +27,13 @@ changelog `_ format. This project follows Added ##### -* Enhanced ``device`` and ``dtype`` consistency checks throughout the library * Better documentation for for ``cell``, ``charges`` and ``positions`` parameters -* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in - ``Calculator`` classes and tuning functions. -Fixed -##### +Removed +####### -* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator`` +* Remove ``device`` and ``dtype`` from init of ``Calculator``, ``Potential`` and + ``Tuning`` classes `Version 0.2.0 `_ - 2025-01-23 ------------------------------------------------------------------------------------------ diff --git a/examples/01-charges-example.py b/examples/01-charges-example.py index 6465a507..c4038bda 100644 --- a/examples/01-charges-example.py +++ b/examples/01-charges-example.py @@ -73,7 +73,6 @@ cutoff=cutoff, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, - dtype=dtype, ) # %% @@ -103,9 +102,9 @@ # will be used to *compute* the potential energy of the system. calculator = torchpme.PMECalculator( - torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params + torchpme.CoulombPotential(smearing=smearing), **pme_params ) - +calculator.to(dtype=dtype) # %% # # Single Charge Channel @@ -207,9 +206,9 @@ # creating a new calculator with the metatensor interface. calculator_metatensor = torchpme.metatensor.PMECalculator( - torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params + torchpme.CoulombPotential(smearing=smearing), **pme_params ) - +calculator_metatensor.to(dtype=dtype) # %% # # Computation with metatensor involves using Metatensor's :class:`System diff --git a/examples/02-neighbor-lists-usage.py b/examples/02-neighbor-lists-usage.py index e72689f0..bdda04eb 100644 --- a/examples/02-neighbor-lists-usage.py +++ b/examples/02-neighbor-lists-usage.py @@ -110,7 +110,6 @@ cutoff=cutoff, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, - dtype=dtype, ) # %% @@ -195,8 +194,7 @@ def distances( # compute the potential. pme = torchpme.PMECalculator( - potential=torchpme.CoulombPotential(smearing=smearing, dtype=dtype), - dtype=dtype, + potential=torchpme.CoulombPotential(smearing=smearing), **pme_params, ) potential = pme( diff --git a/examples/07-lode-demo.py b/examples/07-lode-demo.py index 59feb341..6588f10d 100644 --- a/examples/07-lode-demo.py +++ b/examples/07-lode-demo.py @@ -420,7 +420,7 @@ def __init__(self, potential: Potential, n_grid: int = 3): ) # assumes a smooth exclusion region so sets the integration cutoff to half that - nodes, weights = get_full_grid(n_grid, potential.exclusion_radius.item() / 2) + nodes, weights = get_full_grid(n_grid, potential.exclusion_radius / 2) # these are the "stencils" used to project the potential # on an atom-centered basis. NB: weights might also be incorporated diff --git a/examples/08-combined-potential.py b/examples/08-combined-potential.py index 8d0667bb..73e1c077 100644 --- a/examples/08-combined-potential.py +++ b/examples/08-combined-potential.py @@ -67,10 +67,12 @@ # evaluation, and so one has to set it also for the combined potential, even if it is # not used explicitly in the evaluation of the combination. -pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype) -pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype) - -potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing, dtype=dtype) +pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing) +pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing) +pot_1 = pot_1.to(dtype=dtype) +pot_2 = pot_2.to(dtype=dtype) +potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing) +potential = potential.to(dtype=dtype) # Note also that :class:`CombinedPotential` can be used with any combination of # potentials, as long they are all either direct or range separated. For instance, one @@ -156,9 +158,9 @@ # much bigger system. calculator = EwaldCalculator( - potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A, dtype=dtype + potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A ) - +calculator.to(dtype=dtype) # %% # diff --git a/examples/10-tuning.py b/examples/10-tuning.py index 183b0f45..0978cb98 100644 --- a/examples/10-tuning.py +++ b/examples/10-tuning.py @@ -120,12 +120,10 @@ pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4} pme = torchpme.PMECalculator( - potential=torchpme.CoulombPotential(smearing=smearing, device=device, dtype=dtype), - device=device, - dtype=dtype, + potential=torchpme.CoulombPotential(smearing=smearing), **pme_params, # type: ignore[arg-type] ) - +pme.to(device=device, dtype=dtype) # %% # Run the calculator # ~~~~~~~~~~~~~~~~~~ @@ -170,8 +168,6 @@ neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, run_backward=True, - device=device, - dtype=dtype, ) estimated_timing = timings(pme) @@ -220,14 +216,11 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device, ) pme = torchpme.PMECalculator( - potential=torchpme.CoulombPotential( - smearing=smearing, device=device, dtype=dtype - ), + potential=torchpme.CoulombPotential(smearing=smearing), mesh_spacing=mesh_spacing, interpolation_nodes=interpolation_nodes, - device=device, - dtype=dtype, ) + pme.to(device=device, dtype=dtype) potential = pme( charges=charges, cell=cell, @@ -247,8 +240,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device, run_backward=True, n_warmup=1, n_repeat=4, - device=device, - dtype=dtype, ) estimated_timing = timings(pme) return madelung, estimated_timing @@ -457,8 +448,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device, cutoff=5.0, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, - device=device, - dtype=dtype, ) print( @@ -492,8 +481,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device, cutoff=cutoff, neighbor_indices=filter_indices, neighbor_distances=filter_distances, - device=device, - dtype=dtype, ) timings_grid.append(timing) diff --git a/examples/basic-usage.py b/examples/basic-usage.py index 47bc1506..3be525b2 100644 --- a/examples/basic-usage.py +++ b/examples/basic-usage.py @@ -146,7 +146,8 @@ # contains all the necessary functions (such as those defining the short-range and # long-range splits) for this potential and makes them useable in the rest of the code. -potential = CoulombPotential(smearing=smearing, device=device, dtype=dtype) +potential = CoulombPotential(smearing=smearing) +potential.to(device=device, dtype=dtype) # %% # @@ -193,10 +194,8 @@ # Since our structure is relatively small, we use the :class:`EwaldCalculator`. # We start by the initialization of the class. -calculator = EwaldCalculator( - potential=potential, lr_wavelength=lr_wavelength, device=device, dtype=dtype -) - +calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength) +calculator.to(device=device, dtype=dtype) # %% # # Compute Energy diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index b7b4063c..91568ecf 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -1,23 +1,8 @@ -from typing import Optional, Union +from typing import Union import torch -def _get_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: - return torch.get_default_dtype() if dtype is None else dtype - - -def _get_device(device: Union[None, str, torch.device]) -> torch.device: - new_device = torch.get_default_device() if device is None else torch.device(device) - - # Add default index of 0 to a cuda device to avoid errors when comparing with - # devices from tensors - if new_device.type == "cuda" and new_device.index is None: - new_device = torch.device("cuda:0") - - return new_device - - def _validate_parameters( charges: torch.Tensor, cell: torch.Tensor, @@ -25,20 +10,9 @@ def _validate_parameters( neighbor_indices: torch.Tensor, neighbor_distances: torch.Tensor, smearing: Union[float, None], - dtype: torch.dtype, - device: torch.device, ) -> None: - if positions.dtype != dtype: - raise TypeError( - f"type of `positions` ({positions.dtype}) must be same as the class " - f"type ({dtype})" - ) - - if positions.device != device: - raise ValueError( - f"device of `positions` ({positions.device}) must be same as the class " - f"device ({device})" - ) + dtype = positions.dtype + device = positions.device # check shape, dtype and device of positions num_atoms = len(positions) @@ -55,14 +29,14 @@ def _validate_parameters( f"{list(cell.shape)}" ) - if cell.dtype != positions.dtype: + if cell.dtype != dtype: raise TypeError( - f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})" + f"type of `cell` ({cell.dtype}) must be same as that of the `positions` class ({dtype})" ) if cell.device != device: raise ValueError( - f"device of `cell` ({cell.device}) must be same as the class ({device})" + f"device of `cell` ({cell.device}) must be same as that of the `positions` class ({device})" ) if smearing is not None and torch.equal( @@ -89,14 +63,14 @@ def _validate_parameters( f"{len(positions)} atoms" ) - if charges.dtype != positions.dtype: + if charges.dtype != dtype: raise TypeError( - f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})" + f"type of `charges` ({charges.dtype}) must be same as that of the `positions` class ({dtype})" ) if charges.device != device: raise ValueError( - f"device of `charges` ({charges.device}) must be same as the class " + f"device of `charges` ({charges.device}) must be same as that of the `positions` class " f"({device})" ) @@ -111,7 +85,7 @@ def _validate_parameters( if neighbor_indices.device != device: raise ValueError( f"device of `neighbor_indices` ({neighbor_indices.device}) must be " - f"same as the class ({device})" + f"same as that of the `positions` class ({device})" ) if neighbor_distances.shape != neighbor_indices[:, 0].shape: @@ -124,11 +98,11 @@ def _validate_parameters( if neighbor_distances.device != device: raise ValueError( f"device of `neighbor_distances` ({neighbor_distances.device}) must be " - f"same as the class ({device})" + f"same as that of the `positions` class ({device})" ) - if neighbor_distances.dtype != positions.dtype: + if neighbor_distances.dtype != dtype: raise TypeError( f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same " - f"as the class ({dtype})" + f"as that of the `positions` class ({dtype})" ) diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index 3dfbeb38..0bdfc421 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -1,9 +1,7 @@ -from typing import Optional, Union - import torch from torch import profiler -from .._utils import _get_device, _get_dtype, _validate_parameters +from .._utils import _validate_parameters from ..potentials import Potential @@ -27,8 +25,6 @@ class Calculator(torch.nn.Module): will come from a full (True) or half (False, default) neighbor list. :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and common values. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters """ def __init__( @@ -36,8 +32,6 @@ def __init__( potential: Potential, full_neighbor_list: bool = False, prefactor: float = 1.0, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): super().__init__() @@ -46,24 +40,8 @@ def __init__( f"Potential must be an instance of Potential, got {type(potential)}" ) - self.device = _get_device(device) - self.dtype = _get_dtype(dtype) - - if self.dtype != potential.dtype: - raise TypeError( - f"dtype of `potential` ({potential.dtype}) must be same as of " - f"`calculator` ({self.dtype})" - ) - - if self.device != potential.device: - raise ValueError( - f"device of `potential` ({potential.device}) must be same as of " - f"`calculator` ({self.device})" - ) - self.potential = potential self.full_neighbor_list = full_neighbor_list - self.prefactor = prefactor def _compute_rspace( @@ -164,8 +142,6 @@ def forward( neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, smearing=self.potential.smearing, - dtype=self.dtype, - device=self.device, ) # Compute short-range (SR) part using a real space sum diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py index 83e2dc85..4c5d1906 100644 --- a/src/torchpme/calculators/ewald.py +++ b/src/torchpme/calculators/ewald.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import torch from ..lib import generate_kvectors_for_ewald @@ -55,8 +53,6 @@ class EwaldCalculator(Calculator): :obj:`False`, a "half" neighbor list is expected. :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and common values. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters """ def __init__( @@ -65,15 +61,11 @@ def __init__( lr_wavelength: float, full_neighbor_list: bool = False, prefactor: float = 1.0, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): super().__init__( potential=potential, full_neighbor_list=full_neighbor_list, prefactor=prefactor, - dtype=dtype, - device=device, ) if potential.smearing is None: raise ValueError( diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py index f85533db..5597ff9e 100644 --- a/src/torchpme/calculators/p3m.py +++ b/src/torchpme/calculators/p3m.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import torch from ..lib.kspace_filter import P3MKSpaceFilter @@ -42,8 +40,6 @@ class P3MCalculator(PMECalculator): set to :py:obj:`False`, a "half" neighbor list is expected. :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and common values. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters For an **example** on the usage for any calculator refer to :ref:`userdoc-how-to`. """ @@ -55,8 +51,6 @@ def __init__( interpolation_nodes: int = 4, full_neighbor_list: bool = False, prefactor: float = 1.0, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): self.mesh_spacing: float = mesh_spacing @@ -68,13 +62,23 @@ def __init__( potential=potential, full_neighbor_list=full_neighbor_list, prefactor=prefactor, - dtype=dtype, - device=device, ) + if potential.smearing is None: + raise ValueError( + "Must specify smearing to use a potential with P3MCalculator" + ) + + cell = torch.eye( + 3, + device=self.potential.smearing.device, + dtype=self.potential.smearing.dtype, + ) + ns_mesh = torch.ones(3, dtype=int, device=cell.device) + self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter( - cell=torch.eye(3, dtype=self.dtype, device=self.device), - ns_mesh=torch.ones(3, dtype=int, device=self.device), + cell=cell, + ns_mesh=ns_mesh, interpolation_nodes=self.interpolation_nodes, kernel=self.potential, mode=0, # Green's function for point-charge potentials @@ -84,8 +88,8 @@ def __init__( ) self.mesh_interpolator: MeshInterpolator = MeshInterpolator( - cell=torch.eye(3, dtype=self.dtype, device=self.device), - ns_mesh=torch.ones(3, dtype=int, device=self.device), + cell=cell, + ns_mesh=ns_mesh, interpolation_nodes=self.interpolation_nodes, method="P3M", ) diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py index dd389812..a79fb249 100644 --- a/src/torchpme/calculators/pme.py +++ b/src/torchpme/calculators/pme.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import torch from torch import profiler @@ -47,8 +45,6 @@ class PMECalculator(Calculator): set to :obj:`False`, a "half" neighbor list is expected. :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and common values. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters """ def __init__( @@ -58,27 +54,30 @@ def __init__( interpolation_nodes: int = 4, full_neighbor_list: bool = False, prefactor: float = 1.0, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): super().__init__( potential=potential, full_neighbor_list=full_neighbor_list, prefactor=prefactor, - dtype=dtype, - device=device, ) if potential.smearing is None: raise ValueError( - "Must specify smearing to use a potential with EwaldCalculator" + "Must specify smearing to use a potential with PMECalculator" ) self.mesh_spacing: float = mesh_spacing + cell = torch.eye( + 3, + device=self.potential.smearing.device, + dtype=self.potential.smearing.dtype, + ) + ns_mesh = torch.ones(3, dtype=int, device=cell.device) + self.kspace_filter: KSpaceFilter = KSpaceFilter( - cell=torch.eye(3, dtype=self.dtype, device=self.device), - ns_mesh=torch.ones(3, dtype=int, device=self.device), + cell=cell, + ns_mesh=ns_mesh, kernel=self.potential, fft_norm="backward", ifft_norm="forward", @@ -87,8 +86,8 @@ def __init__( self.interpolation_nodes: int = interpolation_nodes self.mesh_interpolator: MeshInterpolator = MeshInterpolator( - cell=torch.eye(3, dtype=self.dtype, device=self.device), - ns_mesh=torch.ones(3, dtype=int, device=self.device), + cell=cell, + ns_mesh=ns_mesh, interpolation_nodes=self.interpolation_nodes, method="Lagrange", # convention for classic PME ) diff --git a/src/torchpme/lib/splines.py b/src/torchpme/lib/splines.py index 036ded7c..973c7576 100644 --- a/src/torchpme/lib/splines.py +++ b/src/torchpme/lib/splines.py @@ -127,8 +127,8 @@ def _solve_tridiagonal(a, b, c, d): """ n = len(d) # Create copies to avoid modifying the original arrays - c_prime = torch.zeros(n) - d_prime = torch.zeros(n) + c_prime = torch.zeros_like(d) + d_prime = torch.zeros_like(d) # Initial coefficients c_prime[0] = c[0] / b[0] @@ -141,7 +141,7 @@ def _solve_tridiagonal(a, b, c, d): d_prime[i] = (d[i] - a[i] * d_prime[i - 1]) / denom # Backward substitution - x = torch.zeros(n) + x = torch.zeros_like(d) x[-1] = d_prime[-1] for i in reversed(range(n - 1)): x[i] = d_prime[i] - c_prime[i] * x[i + 1] @@ -151,36 +151,28 @@ def _solve_tridiagonal(a, b, c, d): def compute_second_derivatives( x_points: torch.Tensor, y_points: torch.Tensor, - high_precision: Optional[bool] = True, ): """ Computes second derivatives given the grid points of a cubic spline. :param x_points: Abscissas of the splining points for the real-space function :param y_points: Ordinates of the splining points for the real-space function - :param high_accuracy: bool, perform calculation in double precision :return: The second derivatives for the spline points """ # Do the calculation in float64 if required x = x_points y = y_points - if high_precision: - x = x.to(dtype=torch.float64) - y = y.to(dtype=torch.float64) # Calculate intervals intervals = x[1:] - x[:-1] dy = (y[1:] - y[:-1]) / intervals - # Create zero boundary conditions (natural spline) - d2y = torch.zeros_like(x, dtype=torch.float64) - n = len(x) - a = torch.zeros(n) # Sub-diagonal (a[1..n-1]) - b = torch.zeros(n) # Main diagonal (b[0..n-1]) - c = torch.zeros(n) # Super-diagonal (c[0..n-2]) - d = torch.zeros(n) # Right-hand side (d[0..n-1]) + a = torch.zeros_like(x) # Sub-diagonal (a[1..n-1]) + b = torch.zeros_like(x) # Main diagonal (b[0..n-1]) + c = torch.zeros_like(x) # Super-diagonal (c[0..n-2]) + d = torch.zeros_like(x) # Right-hand side (d[0..n-1]) # Natural spline boundary conditions b[0] = 1 @@ -195,10 +187,9 @@ def compute_second_derivatives( c[i] = intervals[i] / 6 d[i] = dy[i] - dy[i - 1] - d2y = _solve_tridiagonal(a, b, c, d) + return _solve_tridiagonal(a, b, c, d) # Converts back to the original dtype - return d2y.to(dtype=x_points.dtype, device=x_points.device) def compute_spline_ft( @@ -206,7 +197,6 @@ def compute_spline_ft( x_points: torch.Tensor, y_points: torch.Tensor, d2y_points: torch.Tensor, - high_precision: Optional[bool] = True, ): r""" Computes the Fourier transform of a splined radial function. @@ -228,7 +218,6 @@ def compute_spline_ft( :param x_points: Abscissas of the splining points for the real-space function :param y_points: Ordinates of the splining points for the real-space function :param d2y_points: Second derivatives for the spline points - :param high_accuracy: bool, perform calculation in double precision :return: The radial Fourier transform :math:`\hat{f}(k)` computed at the ``k_points`` provided. @@ -244,8 +233,6 @@ def compute_spline_ft( # chooses precision for the FT evaluation dtype = x_points.dtype - if high_precision: - dtype = torch.float64 # broadcast to compute at once on all k values. # all these are terms that enter the analytical integral. diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py index 212f4744..dc67cdc7 100644 --- a/src/torchpme/potentials/combined.py +++ b/src/torchpme/potentials/combined.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch @@ -27,8 +27,6 @@ class CombinedPotential(Potential): :param exclusion_radius: A length scale that defines a *local environment* within which the potential should be smoothly zeroed out, as it will be described by a separate model. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters """ def __init__( @@ -38,14 +36,10 @@ def __init__( learnable_weights: Optional[bool] = True, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): super().__init__( smearing=smearing, exclusion_radius=exclusion_radius, - dtype=dtype, - device=device, ) smearings = [pot.smearing for pot in potentials] @@ -73,9 +67,7 @@ def __init__( "The number of initial weights must match the number of potentials being combined" ) else: - initial_weights = torch.ones( - len(potentials), dtype=self.dtype, device=self.device - ) + initial_weights = torch.ones(len(potentials)) # for torchscript self.potentials = torch.nn.ModuleList(potentials) if learnable_weights: diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py index 1e35897c..76188839 100644 --- a/src/torchpme/potentials/coulomb.py +++ b/src/torchpme/potentials/coulomb.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch @@ -26,30 +26,14 @@ class CoulombPotential(Potential): :param exclusion_radius: A length scale that defines a *local environment* within which the potential should be smoothly zeroed out, as it will be described by a separate model. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters """ def __init__( self, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): - super().__init__(smearing, exclusion_radius, dtype, device) - - # constants used in the forwward - self.register_buffer( - "_rsqrt2", - torch.rsqrt(torch.tensor(2.0, dtype=self.dtype, device=self.device)), - ) - self.register_buffer( - "_sqrt_2_on_pi", - torch.sqrt( - torch.tensor(2.0 / torch.pi, dtype=self.dtype, device=self.device) - ), - ) + super().__init__(smearing, exclusion_radius) def from_dist(self, dist: torch.Tensor) -> torch.Tensor: """ @@ -75,7 +59,7 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor: "Cannot compute long-range contribution without specifying `smearing`." ) - return torch.erf(dist * (self._rsqrt2 / self.smearing)) / dist + return torch.erf(dist / self.smearing / 2.0**0.5) / dist def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: r""" @@ -105,7 +89,7 @@ def self_contribution(self) -> torch.Tensor: raise ValueError( "Cannot compute self contribution without specifying `smearing`." ) - return self._sqrt_2_on_pi / self.smearing + return (2 / torch.pi) ** 0.5 / self.smearing def background_correction(self) -> torch.Tensor: # "charge neutrality" correction for 1/r potential diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index 374ab56e..1e7f945f 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch from torch.special import gammainc @@ -31,8 +31,6 @@ class InversePowerLawPotential(Potential): :param: exclusion_radius: float or torch.Tensor containing the length scale corresponding to a local environment. See also :class:`Potential`. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters """ def __init__( @@ -40,16 +38,12 @@ def __init__( exponent: int, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): - super().__init__(smearing, exclusion_radius, dtype, device) + super().__init__(smearing, exclusion_radius) # function call to check the validity of the exponent - gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device)) - self.register_buffer( - "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device) - ) + gammaincc_over_powerlaw(exponent, torch.tensor(1.0)) + self.register_buffer("exponent", torch.tensor(exponent, dtype=torch.float64)) @torch.jit.export def from_dist(self, dist: torch.Tensor) -> torch.Tensor: @@ -84,12 +78,9 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor: "Cannot compute long-range contribution without specifying `smearing`." ) - exponent = self.exponent - smearing = self.smearing - - x = 0.5 * dist**2 / smearing**2 - peff = exponent / 2 - prefac = 1.0 / (2 * smearing**2) ** peff + x = 0.5 * dist**2 / self.smearing**2 + peff = self.exponent / 2 + prefac = 1.0 / (2 * self.smearing**2) ** peff return prefac * gammainc(peff, x) / x**peff @torch.jit.export @@ -105,12 +96,11 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: "Cannot compute long-range kernel without specifying `smearing`." ) - exponent = self.exponent - smearing = self.smearing - - peff = (3 - exponent) / 2 - prefac = torch.pi**1.5 / gamma(exponent / 2) * (2 * smearing**2) ** peff - x = 0.5 * smearing**2 * k_sq + peff = (3 - self.exponent) / 2 + prefac = ( + torch.pi**1.5 / gamma(self.exponent / 2) * (2 * self.smearing**2) ** peff + ) + x = 0.5 * self.smearing**2 * k_sq # The k=0 term often needs to be set separately since for exponents p<=3 # dimension, there is a divergence to +infinity. Setting this value manually @@ -121,7 +111,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: # for consistency reasons. masked = torch.where(x == 0, 1.0, x) # avoid NaNs in backwards, see Coulomb return torch.where( - k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked) + k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(self.exponent, masked) ) def self_contribution(self) -> torch.Tensor: @@ -142,7 +132,7 @@ def background_correction(self) -> torch.Tensor: "Cannot compute background correction without specifying `smearing`." ) if self.exponent >= 3: - return self.smearing * 0.0 + return torch.zero_like(self.smearing) prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2) prefac /= (3 - self.exponent) * gamma(self.exponent / 2) return prefac diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py index af715505..538acf04 100644 --- a/src/torchpme/potentials/potential.py +++ b/src/torchpme/potentials/potential.py @@ -1,9 +1,7 @@ -from typing import Optional, Union +from typing import Optional import torch -from .._utils import _get_device, _get_dtype - class Potential(torch.nn.Module): r""" @@ -32,35 +30,23 @@ class Potential(torch.nn.Module): :param exclusion_radius: A length scale that defines a *local environment* within which the potential should be smoothly zeroed out, as it will be described by a separate model. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters """ def __init__( self, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): super().__init__() - self.device = _get_device(device) - self.dtype = _get_dtype(dtype) - if smearing is not None: self.register_buffer( - "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype) + "smearing", torch.tensor(smearing, dtype=torch.float64) ) else: self.smearing = None - if exclusion_radius is not None: - self.register_buffer( - "exclusion_radius", - torch.tensor(exclusion_radius, device=self.device, dtype=self.dtype), - ) - else: - self.exclusion_radius = None + + self.exclusion_radius = exclusion_radius @torch.jit.export def f_cutoff(self, dist: torch.Tensor) -> torch.Tensor: diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py index b58d31eb..2419f641 100644 --- a/src/torchpme/potentials/spline.py +++ b/src/torchpme/potentials/spline.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch @@ -42,8 +42,6 @@ class SplinePotential(Potential): :param exclusion_radius: A length scale that defines a *local environment* within which the potential should be smoothly zeroed out, as it will be described by a separate model. - :param dtype: type used for the internal buffers and parameters - :param device: device used for the internal buffers and parameters """ def __init__( @@ -57,21 +55,17 @@ def __init__( yhat_at_zero: Optional[float] = None, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): super().__init__( smearing=smearing, exclusion_radius=exclusion_radius, - dtype=dtype, - device=device, ) if len(y_grid) != len(r_grid): raise ValueError("Length of radial grid and value array mismatch.") - r_grid = r_grid.to(dtype=self.dtype, device=self.device) - y_grid = y_grid.to(dtype=self.dtype, device=self.device) + self.register_buffer("r_grid", r_grid) + self.register_buffer("y_grid", y_grid) if reciprocal: if torch.min(r_grid) <= 0.0: @@ -87,9 +81,9 @@ def __init__( if reciprocal: k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0]) else: - k_grid = r_grid.clone() - else: - k_grid = k_grid.to(dtype=self.dtype, device=self.device) + k_grid = r_grid.clone().detach() + + self.register_buffer("k_grid", k_grid) if yhat_grid is None: # computes automatically! @@ -99,8 +93,8 @@ def __init__( y_grid, compute_second_derivatives(r_grid, y_grid), ) - else: - yhat_grid = yhat_grid.to(dtype=self.dtype, device=self.device) + + self.register_buffer("yhat_grid", yhat_grid) # the function is defined for k**2, so we define the grid accordingly if reciprocal: @@ -112,14 +106,14 @@ def __init__( if y_at_zero is None: self._y_at_zero = self._spline( - torch.zeros(1, dtype=self.dtype, device=self.device) + torch.zeros(1, dtype=self.r_grid.dtype, device=self.r_grid.device) ) else: self._y_at_zero = y_at_zero if yhat_at_zero is None: self._yhat_at_zero = self._krn_spline( - torch.zeros(1, dtype=self.dtype, device=self.device) + torch.zeros(1, dtype=self.k_grid.dtype, device=self.k_grid.device) ) else: self._yhat_at_zero = yhat_at_zero diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py index b5bb0ae5..b1116245 100644 --- a/src/torchpme/tuning/ewald.py +++ b/src/torchpme/tuning/ewald.py @@ -1,5 +1,5 @@ import math -from typing import Any, Optional, Union +from typing import Any from warnings import warn import torch @@ -19,8 +19,6 @@ def tune_ewald( ns_lo: int = 1, ns_hi: int = 14, accuracy: float = 1e-3, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.EwaldCalculator`. @@ -96,8 +94,6 @@ def tune_ewald( calculator=EwaldCalculator, error_bounds=EwaldErrorBounds(charges=charges, cell=cell, positions=positions), params=params, - dtype=dtype, - device=device, ) smearing = tuner.estimate_smearing(accuracy) errs, timings = tuner.tune(accuracy) diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py index 6a64230a..7ad6e4db 100644 --- a/src/torchpme/tuning/p3m.py +++ b/src/torchpme/tuning/p3m.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Any, Optional, Union +from typing import Any from warnings import warn import torch @@ -79,8 +79,6 @@ def tune_p3m( mesh_lo: int = 2, mesh_hi: int = 7, accuracy: float = 1e-3, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`. @@ -169,8 +167,6 @@ def tune_p3m( calculator=P3MCalculator, error_bounds=P3MErrorBounds(charges=charges, cell=cell, positions=positions), params=params, - dtype=dtype, - device=device, ) smearing = tuner.estimate_smearing(accuracy) errs, timings = tuner.tune(accuracy) diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py index 540f3d40..68d80f0b 100644 --- a/src/torchpme/tuning/pme.py +++ b/src/torchpme/tuning/pme.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Any, Optional, Union +from typing import Any from warnings import warn import torch @@ -22,8 +22,6 @@ def tune_pme( mesh_lo: int = 2, mesh_hi: int = 7, accuracy: float = 1e-3, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.PMECalculator`. @@ -112,8 +110,6 @@ def tune_pme( calculator=PMECalculator, error_bounds=PMEErrorBounds(charges=charges, cell=cell, positions=positions), params=params, - dtype=dtype, - device=device, ) smearing = tuner.estimate_smearing(accuracy) errs, timings = tuner.tune(accuracy) diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 546b3995..3887c188 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -1,10 +1,10 @@ import math import time -from typing import Optional, Union +from typing import Optional import torch -from .._utils import _get_device, _get_dtype, _validate_parameters +from .._utils import _validate_parameters from ..calculators import Calculator from ..potentials import InversePowerLawPotential @@ -83,17 +83,12 @@ def __init__( cutoff: float, calculator: type[Calculator], exponent: int = 1, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): if exponent != 1: raise NotImplementedError( f"Only exponent = 1 is supported but got {exponent}." ) - self.device = _get_device(device) - self.dtype = _get_dtype(dtype) - _validate_parameters( charges=charges, cell=cell, @@ -103,8 +98,6 @@ def __init__( [1.0], device=positions.device, dtype=positions.dtype ), smearing=1.0, # dummy value because; always have range-seperated potentials - dtype=self.dtype, - device=self.device, ) self.charges = charges self.cell = cell @@ -189,8 +182,6 @@ def __init__( neighbor_indices: torch.Tensor, neighbor_distances: torch.Tensor, exponent: int = 1, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): super().__init__( charges=charges, @@ -199,8 +190,6 @@ def __init__( cutoff=cutoff, calculator=calculator, exponent=exponent, - dtype=dtype, - device=device, ) self.error_bounds = error_bounds self.params = params @@ -211,8 +200,6 @@ def __init__( neighbor_indices, neighbor_distances, True, - dtype=dtype, - device=device, ) def tune(self, accuracy: float = 1e-3) -> tuple[list[float], list[float]]: @@ -244,14 +231,10 @@ def _timing(self, smearing: float, k_space_params: dict): potential=InversePowerLawPotential( exponent=self.exponent, # but only exponent = 1 is supported smearing=smearing, - device=self.device, - dtype=self.dtype, ), - device=self.device, - dtype=self.dtype, **k_space_params, ) - + calculator.to(device=self.positions.device, dtype=self.positions.dtype) return self.time_func(calculator) @@ -289,14 +272,9 @@ def __init__( n_repeat: int = 4, n_warmup: int = 4, run_backward: Optional[bool] = True, - dtype: Optional[torch.dtype] = None, - device: Union[None, str, torch.device] = None, ): super().__init__() - self.device = _get_device(device) - self.dtype = _get_dtype(dtype) - _validate_parameters( charges=charges, cell=cell, @@ -304,8 +282,6 @@ def __init__( neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, smearing=1.0, # dummy value because; always have range-seperated potentials - device=self.device, - dtype=self.dtype, ) self.charges = charges @@ -351,8 +327,6 @@ def forward(self, calculator: torch.nn.Module): if self.run_backward: value.backward(retain_graph=True) - if self.device is torch.device("cuda"): - torch.cuda.synchronize() execution_time += time.monotonic() return execution_time / self.n_repeat diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py index dff72455..b5e8bfe8 100644 --- a/tests/calculators/test_calculator.py +++ b/tests/calculators/test_calculator.py @@ -43,35 +43,6 @@ def test_compute_output_shapes(): assert result.shape == charges.shape -def test_wrong_device_positions(): - calculator = CalculatorTest() - match = r"device of `positions` \(meta\) must be same as the class device \(cpu\)" - with pytest.raises(ValueError, match=match): - calculator.forward( - positions=POSITIONS_1.to(device="meta"), - charges=CHARGES_1, - cell=CELL_1, - neighbor_indices=NEIGHBOR_INDICES, - neighbor_distances=NEIGHBOR_DISTANCES, - ) - - -def test_wrong_dtype_positions(): - calculator = CalculatorTest() - match = ( - r"type of `positions` \(torch.float64\) must be same as the class type " - r"\(torch.float32\)" - ) - with pytest.raises(TypeError, match=match): - calculator.forward( - positions=POSITIONS_1.to(dtype=torch.float64), - charges=CHARGES_1, - cell=CELL_1, - neighbor_indices=NEIGHBOR_INDICES, - neighbor_distances=NEIGHBOR_DISTANCES, - ) - - # Tests for invalid shape, dtype and device of positions def test_invalid_shape_positions(): calculator = CalculatorTest() @@ -107,9 +78,7 @@ def test_invalid_shape_cell(): def test_invalid_dtype_cell(): calculator = CalculatorTest() - match = ( - r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)" - ) + match = r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)" with pytest.raises(TypeError, match=match): calculator.forward( positions=POSITIONS_1, @@ -122,7 +91,7 @@ def test_invalid_dtype_cell(): def test_invalid_device_cell(): calculator = CalculatorTest() - match = r"device of `cell` \(meta\) must be same as the class \(cpu\)" + match = r"device of `cell` \(meta\) must be same as that of the `positions` class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -188,7 +157,7 @@ def test_invalid_shape_charges(): def test_invalid_dtype_charges(): calculator = CalculatorTest() match = ( - r"type of `charges` \(torch.float64\) must be same as the class " + r"type of `charges` \(torch.float64\) must be same as that of the `positions` class " r"\(torch.float32\)" ) with pytest.raises(TypeError, match=match): @@ -203,7 +172,7 @@ def test_invalid_dtype_charges(): def test_invalid_device_charges(): calculator = CalculatorTest() - match = r"device of `charges` \(meta\) must be same as the class \(cpu\)" + match = r"device of `charges` \(meta\) must be same as that of the `positions` class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -248,7 +217,7 @@ def test_invalid_shape_neighbor_indices_neighbor_distances(): def test_invalid_device_neighbor_indices(): calculator = CalculatorTest() - match = r"device of `neighbor_indices` \(meta\) must be same as the class \(cpu\)" + match = r"device of `neighbor_indices` \(meta\) must be same as that of the `positions` class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -261,7 +230,7 @@ def test_invalid_device_neighbor_indices(): def test_invalid_device_neighbor_distances(): calculator = CalculatorTest() - match = r"device of `neighbor_distances` \(meta\) must be same as the class \(cpu\)" + match = r"device of `neighbor_distances` \(meta\) must be same as that of the `positions` class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -276,7 +245,7 @@ def test_invalid_dtype_neighbor_distances(): calculator = CalculatorTest() match = ( r"type of `neighbor_distances` \(torch.float64\) must be same " - r"as the class \(torch.float32\)" + r"as that of the `positions` class \(torch.float32\)" ) with pytest.raises(TypeError, match=match): calculator.forward( diff --git a/tests/calculators/test_values_direct.py b/tests/calculators/test_values_direct.py index a5ace407..867137c0 100644 --- a/tests/calculators/test_values_direct.py +++ b/tests/calculators/test_values_direct.py @@ -18,10 +18,10 @@ class CalculatorTest(Calculator): def __init__(self, **kwargs): super().__init__( potential=CoulombPotential( - smearing=None, exclusion_radius=None, dtype=DTYPE + smearing=None, + exclusion_radius=None, ), **kwargs, - dtype=DTYPE, ) diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py index edcb886e..742ee2d9 100644 --- a/tests/calculators/test_values_ewald.py +++ b/tests/calculators/test_values_ewald.py @@ -101,9 +101,8 @@ def test_madelung(crystal_name, scaling_factor, calc_name): smearing = sr_cutoff / 5.0 lr_wavelength = 0.5 * smearing calc = EwaldCalculator( - InversePowerLawPotential(exponent=1, smearing=smearing, dtype=DTYPE), + InversePowerLawPotential(exponent=1, smearing=smearing), lr_wavelength=lr_wavelength, - dtype=DTYPE, ) rtol = 4e-6 elif calc_name == "pme": @@ -113,19 +112,16 @@ def test_madelung(crystal_name, scaling_factor, calc_name): InversePowerLawPotential( exponent=1, smearing=smearing, - dtype=DTYPE, ), mesh_spacing=smearing / 8, - dtype=DTYPE, ) rtol = 9e-4 elif calc_name == "p3m": sr_cutoff = 2 * scaling_factor smearing = sr_cutoff / 5.0 calc = P3MCalculator( - CoulombPotential(smearing=smearing, dtype=DTYPE), + CoulombPotential(smearing=smearing), mesh_spacing=smearing / 8, - dtype=DTYPE, ) rtol = 9e-4 @@ -133,7 +129,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): neighbor_indices, neighbor_distances = neighbor_list( positions=pos, periodic=True, box=cell, cutoff=sr_cutoff ) - + calc.to(DTYPE) # Compute potential and compare against target value using default hypers potentials = calc.forward( positions=pos, @@ -200,10 +196,10 @@ def test_wigner(crystal_name, scaling_factor): # Compute potential and compare against reference calc = EwaldCalculator( - InversePowerLawPotential(exponent=1, smearing=smeareff, dtype=DTYPE), + InversePowerLawPotential(exponent=1, smearing=smeareff), lr_wavelength=smeareff / 2, - dtype=DTYPE, ) + calc.to(DTYPE) potentials = calc.forward( positions=positions, charges=charges, @@ -251,28 +247,25 @@ def test_random_structure( if calc_name == "ewald": calc = EwaldCalculator( - CoulombPotential(smearing=smearing, dtype=DTYPE), + CoulombPotential(smearing=smearing), lr_wavelength=0.5 * smearing, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, - dtype=DTYPE, ) elif calc_name == "pme": calc = PMECalculator( - CoulombPotential(smearing=smearing, dtype=DTYPE), + CoulombPotential(smearing=smearing), mesh_spacing=smearing / 8.0, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, - dtype=DTYPE, ) elif calc_name == "p3m": calc = P3MCalculator( - CoulombPotential(smearing=smearing, dtype=DTYPE), + CoulombPotential(smearing=smearing), mesh_spacing=smearing / 8.0, full_neighbor_list=full_neighbor_list, prefactor=torchpme.prefactors.eV_A, - dtype=DTYPE, ) neighbor_indices, neighbor_shifts = neighbor_list( @@ -295,6 +288,7 @@ def test_random_structure( neighbor_shifts=neighbor_shifts, ) + calc.to(DTYPE) potentials = calc.forward( positions=positions, charges=charges, diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index 3b62ae3b..7d2e19ee 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -15,7 +15,6 @@ P3MCalculator, PMECalculator, ) -from torchpme._utils import _get_device, _get_dtype DEVICES = ["cpu", torch.device("cpu")] + torch.cuda.is_available() * ["cuda"] DTYPES = [torch.float32, torch.float64] @@ -32,35 +31,27 @@ ( Calculator, { - "potential": lambda dtype, device: CoulombPotential( - smearing=None, dtype=dtype, device=device - ), + "potential": CoulombPotential(smearing=None), }, ), ( EwaldCalculator, { - "potential": lambda dtype, device: CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": CoulombPotential(smearing=SMEARING), "lr_wavelength": LR_WAVELENGTH, }, ), ( PMECalculator, { - "potential": lambda dtype, device: CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": CoulombPotential(smearing=SMEARING), "mesh_spacing": MESH_SPACING, }, ), ( P3MCalculator, { - "potential": lambda dtype, device: CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": CoulombPotential(smearing=SMEARING), "mesh_spacing": MESH_SPACING, }, ), @@ -69,9 +60,6 @@ class TestWorkflow: def cscl_system(self, device=None, dtype=None): """CsCl crystal. Same as in the madelung test""" - device = _get_device(device) - dtype = _get_dtype(dtype) - positions = torch.tensor( [[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype, device=device ) @@ -83,45 +71,37 @@ def cscl_system(self, device=None, dtype=None): return charges, cell, positions, neighbor_indices, neighbor_distances def test_smearing_non_positive(self, CalculatorClass, params, device, dtype): - params = params.copy() - params["potential"] = params["potential"](dtype, device) match = r"`smearing` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator, PMECalculator]: params["smearing"] = 0 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device, dtype=dtype) + CalculatorClass(**params) params["smearing"] = -0.1 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device, dtype=dtype) + CalculatorClass(**params) def test_interpolation_order_error(self, CalculatorClass, params, device, dtype): - params = params.copy() - params["potential"] = params["potential"](dtype, device) if type(CalculatorClass) in [PMECalculator]: match = "Only `interpolation_nodes` from 1 to 5" params["interpolation_nodes"] = 10 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device, dtype=dtype) + CalculatorClass(**params) def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype): - params = params.copy() - params["potential"] = params["potential"](dtype, device) match = r"`lr_wavelength` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator]: params["lr_wavelength"] = 0 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device, dtype=dtype) + CalculatorClass(**params) params["lr_wavelength"] = -0.1 with pytest.raises(ValueError, match=match): - CalculatorClass(**params, device=device, dtype=dtype) + CalculatorClass(**params) def test_dtype_device(self, CalculatorClass, params, device, dtype): """Test that the output dtype and device are the same as the input.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - - calculator = CalculatorClass(**params, device=device, dtype=dtype) - potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype)) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) + potential = calculator(*self.cscl_system(device=device, dtype=dtype)) assert potential.dtype == dtype @@ -137,26 +117,21 @@ def check_operation(self, calculator, device, dtype): def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - - calculator = CalculatorClass(**params, device=device, dtype=dtype) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) self.check_operation(calculator=calculator, device=device, dtype=dtype) def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - - calculator = CalculatorClass(**params, device=device, dtype=dtype) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) scripted = torch.jit.script(calculator) self.check_operation(calculator=scripted, device=device, dtype=dtype) def test_save_load(self, CalculatorClass, params, device, dtype): - params = params.copy() - params["potential"] = params["potential"](dtype, device) - - calculator = CalculatorClass(**params, device=device, dtype=dtype) + """Test if the calculator can be saved and loaded.""" + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) scripted = torch.jit.script(calculator) with io.BytesIO() as buffer: torch.jit.save(scripted, buffer) @@ -165,15 +140,11 @@ def test_save_load(self, CalculatorClass, params, device, dtype): def test_prefactor(self, CalculatorClass, params, device, dtype): """Test if the prefactor is applied correctly.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - prefactor = 2.0 - calculator1 = CalculatorClass(**params, device=device, dtype=dtype) - calculator2 = CalculatorClass( - **params, prefactor=prefactor, device=device, dtype=dtype - ) - + calculator1 = CalculatorClass(**params) + calculator2 = CalculatorClass(**params, prefactor=prefactor) + calculator1.to(device=device, dtype=dtype) + calculator2.to(device=device, dtype=dtype) potentials1 = calculator1.forward(*self.cscl_system(device=device, dtype=dtype)) potentials2 = calculator2.forward(*self.cscl_system(device=device, dtype=dtype)) @@ -181,10 +152,8 @@ def test_prefactor(self, CalculatorClass, params, device, dtype): def test_not_nan(self, CalculatorClass, params, device, dtype): """Make sure derivatives are not NaN.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - - calculator = CalculatorClass(**params, device=device, dtype=dtype) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) system = self.cscl_system(device=device, dtype=dtype) system[0].requires_grad = True system[1].requires_grad = True @@ -211,29 +180,6 @@ def test_not_nan(self, CalculatorClass, params, device, dtype): torch.autograd.grad(energy, system[2], retain_graph=True)[0] ).any() - def test_dtype_and_device_incompatability( - self, CalculatorClass, params, device, dtype - ): - """Test that the calculator raises an error if the dtype and device are incompatible with potential.""" - params = params.copy() - - other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 - params["potential"] = params["potential"](dtype, device) - - match = ( - rf"dtype of `potential` \({params['potential'].dtype}\) must be same as " - rf"of `calculator` \({other_dtype}\)" - ) - with pytest.raises(TypeError, match=match): - CalculatorClass(**params, dtype=other_dtype, device=device) - - match = ( - rf"device of `potential` \({params['potential'].device}\) must be same as " - rf"of `calculator` \(meta\)" - ) - with pytest.raises(ValueError, match=match): - CalculatorClass(**params, dtype=dtype, device=torch.device("meta")) - def test_potential_and_calculator_incompatability( self, CalculatorClass, @@ -242,34 +188,17 @@ def test_potential_and_calculator_incompatability( dtype, ): """Test that the calculator raises an error if the potential and calculator are incompatible.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - params["potential"] = torch.jit.script(params["potential"]) with pytest.raises( TypeError, match="Potential must be an instance of Potential, got.*" ): - CalculatorClass(**params, device=device, dtype=dtype) + CalculatorClass(**params) - def test_device_string_compatability(self, CalculatorClass, params, dtype, device): - """Test that the calculator works with device strings.""" - params = params.copy() - params["potential"] = params["potential"](dtype, "cpu") - calculator = CalculatorClass( - **params, - device=torch.device("cpu"), - dtype=dtype, - ) - - assert calculator.device == params["potential"].device - - def test_device_index_compatability(self, CalculatorClass, params, dtype, device): - """Test that the calculator works with no index on the device.""" - if torch.cuda.is_available(): - params = params.copy() - params["potential"] = params["potential"](dtype, "cuda") - calculator = CalculatorClass( - **params, device=torch.device("cuda:0"), dtype=dtype - ) - - assert calculator.device == params["potential"].device + def test_smearing_incompatability(self, CalculatorClass, params, device, dtype): + """Test that the calculator raises an error if the potential and calculator are incompatible.""" + if type(CalculatorClass) in [EwaldCalculator, PMECalculator, P3MCalculator]: + params["smearing"] = None + with pytest.raises( + TypeError, match="Must specify smearing to use a potential with .*" + ): + CalculatorClass(**params) diff --git a/tests/helpers.py b/tests/helpers.py index 4682b906..f553bff6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,8 +7,6 @@ import torch from vesin import NeighborList -from torchpme._utils import _get_device, _get_dtype - SQRT3 = math.sqrt(3) DIR_PATH = Path(__file__).parent @@ -17,9 +15,6 @@ def define_crystal(crystal_name="CsCl", dtype=None, device=None): - device = _get_device(device) - dtype = _get_dtype(dtype) - # Define all relevant parameters (atom positions, charges, cell) of the reference # crystal structures for which the Madelung constants obtained from the Ewald sums # are compared with reference values. diff --git a/tests/lib/test_splines.py b/tests/lib/test_splines.py index e8576188..f5595e3e 100644 --- a/tests/lib/test_splines.py +++ b/tests/lib/test_splines.py @@ -59,8 +59,12 @@ def test_inverse_spline(function): @pytest.mark.parametrize("high_accuracy", [True, False]) def test_ft_accuracy(high_accuracy): - x_grid = torch.linspace(0, 20, 2000, dtype=torch.float32) - y_grid = torch.exp(-(x_grid**2) * 0.5) + if high_accuracy: + x_grid = torch.linspace(0, 20, 2000, dtype=torch.float64) + y_grid = torch.exp(-(x_grid**2) * 0.5) + else: + x_grid = torch.linspace(0, 20, 2000, dtype=torch.float32) + y_grid = torch.exp(-(x_grid**2) * 0.5) k_grid = torch.linspace(0, 20, 20, dtype=torch.float32) krn = compute_spline_ft( @@ -68,9 +72,9 @@ def test_ft_accuracy(high_accuracy): x_points=x_grid, y_points=y_grid, d2y_points=compute_second_derivatives( - x_points=x_grid, y_points=y_grid, high_precision=high_accuracy + x_points=x_grid, + y_points=y_grid, ), - high_precision=high_accuracy, ) krn_ref = torch.exp(-(k_grid**2) * 0.5) * (2 * torch.pi) ** (3 / 2) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 141977d2..c8f9681f 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -7,7 +7,6 @@ from packaging import version import torchpme -from torchpme._utils import _get_device, _get_dtype mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") @@ -27,35 +26,27 @@ ( torchpme.metatensor.Calculator, { - "potential": lambda dtype, device: torchpme.CoulombPotential( - smearing=None, dtype=dtype, device=device - ), + "potential": torchpme.CoulombPotential(smearing=None), }, ), ( torchpme.metatensor.EwaldCalculator, { - "potential": lambda dtype, device: torchpme.CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": torchpme.CoulombPotential(smearing=SMEARING), "lr_wavelength": LR_WAVELENGTH, }, ), ( torchpme.metatensor.PMECalculator, { - "potential": lambda dtype, device: torchpme.CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": torchpme.CoulombPotential(smearing=SMEARING), "mesh_spacing": MESH_SPACING, }, ), ( torchpme.metatensor.P3MCalculator, { - "potential": lambda dtype, device: torchpme.CoulombPotential( - smearing=SMEARING, dtype=dtype, device=device - ), + "potential": torchpme.CoulombPotential(smearing=SMEARING), "mesh_spacing": MESH_SPACING, }, ), @@ -63,9 +54,6 @@ ) class TestWorkflow: def system(self, device=None, dtype=None): - device = _get_device(device) - dtype = _get_dtype(dtype) - system = mts_atomistic.System( types=torch.tensor([1, 2, 2]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.2], [0.0, 0.0, 0.5]]), @@ -118,24 +106,21 @@ def check_operation(self, calculator, device, dtype): def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - calculator = CalculatorClass(**params, device=device, dtype=dtype) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) self.check_operation(calculator=calculator, device=device, dtype=dtype) def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" - params = params.copy() - params["potential"] = params["potential"](dtype, device) - calculator = CalculatorClass(**params, device=device, dtype=dtype) + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) scripted = torch.jit.script(calculator) self.check_operation(calculator=scripted, device=device, dtype=dtype) def test_save_load(self, CalculatorClass, params, device, dtype): - params = params.copy() - params["potential"] = params["potential"](dtype, device) - - calculator = CalculatorClass(**params, device=device, dtype=dtype) + """Save and load a compiled torch script module.""" + calculator = CalculatorClass(**params) + calculator.to(device=device, dtype=dtype) scripted = torch.jit.script(calculator) with io.BytesIO() as buffer: torch.jit.save(scripted, buffer) diff --git a/tests/test_potentials.py b/tests/test_potentials.py index 08793749..54398aa5 100644 --- a/tests/test_potentials.py +++ b/tests/test_potentials.py @@ -67,8 +67,8 @@ def test_sr_lr_split(exponent, smearing): potential. """ # Compute diverse potentials for this inverse power law - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) - + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) potential_from_dist = ipl.from_dist(dists) potential_sr_from_dist = ipl.sr_from_dist(dists) potential_lr_from_dist = ipl.lr_from_dist(dists) @@ -96,10 +96,9 @@ def test_exact_sr(exponent, smearing): """ # Compute SR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) - + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) potential_sr_from_dist = ipl.sr_from_dist(dists) - # Compute exact analytical expression obtained for relevant exponents potential_1 = erfc(dists / SQRT2 / smearing) / dists potential_2 = torch.exp(-0.5 * dists_sq / smearing**2) / dists_sq @@ -110,7 +109,6 @@ def test_exact_sr(exponent, smearing): elif exponent == 3: prefac = SQRT2 / torch.sqrt(PI) / smearing potential_exact = potential_1 / dists_sq + prefac * potential_2 - # Compare results. Large tolerance due to singular division rtol = 1e2 * machine_epsilon atol = 4e-15 @@ -130,7 +128,8 @@ def test_exact_lr(exponent, smearing): """ # Compute LR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) potential_lr_from_dist = ipl.lr_from_dist(dists) @@ -164,7 +163,8 @@ def test_exact_fourier(exponent, smearing): """ # Compute LR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) fourier_from_class = ipl.lr_from_k_sq(ks_sq) @@ -202,7 +202,8 @@ def test_lr_value_at_zero(exponent, smearing): """ # Get atomic density at tiny distance dist_small = torch.tensor(1e-8, dtype=dtype) - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) potential_close_to_zero = ipl.lr_from_dist(dist_small) @@ -279,7 +280,8 @@ class NoImplPotential(Potential): @pytest.mark.parametrize("exclusion_radius", [0.5, 1.0, 2.0]) def test_f_cutoff(exclusion_radius): - coul = CoulombPotential(exclusion_radius=exclusion_radius, dtype=dtype) + coul = CoulombPotential(exclusion_radius=exclusion_radius) + coul.to(dtype=dtype) dist = torch.tensor([0.3]) fcut = coul.f_cutoff(dist) @@ -294,8 +296,10 @@ def test_inverserp_coulomb(smearing): """ # Compute LR part of Coulomb potential using the potentials class working for any # exponent - ipl = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype) - coul = CoulombPotential(smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=1, smearing=smearing) + ipl.to(dtype=dtype) + coul = CoulombPotential(smearing=smearing) + coul.to(dtype=dtype) ipl_from_dist = ipl.from_dist(dists) ipl_sr_from_dist = ipl.sr_from_dist(dists) @@ -371,16 +375,17 @@ def test_spline_potential_vs_coulomb(): # the approximation is not super-accurate coulomb = CoulombPotential(smearing=1.0) - x_grid = torch.logspace(-3.0, 3.0, 1000) + coulomb.to(dtype=dtype) + x_grid = torch.logspace(-3.0, 3.0, 1000, dtype=dtype) y_grid = coulomb.lr_from_dist(x_grid) spline = SplinePotential(r_grid=x_grid, y_grid=y_grid, reciprocal=True) - t_grid = torch.logspace(-torch.pi / 2, torch.pi / 2, 100) + t_grid = torch.logspace(-torch.pi / 2, torch.pi / 2, 100, dtype=dtype) z_coul = coulomb.lr_from_dist(t_grid) z_spline = spline.lr_from_dist(t_grid) assert_close(z_coul, z_spline, atol=5e-5, rtol=0) - k_grid2 = torch.logspace(-2, 1, 40) + k_grid2 = torch.logspace(-2, 1, 40, dtype=dtype) krn_coul = coulomb.kernel_from_k_sq(k_grid2) krn_spline = spline.kernel_from_k_sq(k_grid2) @@ -439,8 +444,8 @@ def forward(self, x: torch.Tensor): @pytest.mark.parametrize("smearing", smearinges) def test_combined_potential(smearing): - ipl_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype) - ipl_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype) + ipl_1 = InversePowerLawPotential(exponent=1, smearing=smearing) + ipl_2 = InversePowerLawPotential(exponent=2, smearing=smearing) ipl_1_from_dist = ipl_1.from_dist(dists) ipl_1_sr_from_dist = ipl_1.sr_from_dist(dists) @@ -461,7 +466,6 @@ def test_combined_potential(smearing): potentials=[ipl_1, ipl_2], initial_weights=weights, learnable_weights=False, - dtype=dtype, smearing=1.0, ) combined_from_dist = combined.from_dist(dists) @@ -516,53 +520,53 @@ def test_combined_potential(smearing): def test_combined_potentials_jit(smearing): # make a separate test as pytest.mark.parametrize does not work with # torch.jit.script for combined potentials - coulomb = CoulombPotential(smearing=smearing, dtype=dtype) + coulomb = CoulombPotential(smearing=smearing) + coulomb.to(dtype=dtype) x_grid = torch.logspace(-2, 2, 100, dtype=dtype) y_grid = coulomb.lr_from_dist(x_grid) # create a spline potential spline = SplinePotential( - r_grid=x_grid, y_grid=y_grid, reciprocal=True, dtype=dtype, smearing=1.0 + r_grid=x_grid, y_grid=y_grid, reciprocal=True, smearing=1.0 ) - - combo = CombinedPotential(potentials=[spline, coulomb], dtype=dtype, smearing=1.0) - mypme = PMECalculator(combo, mesh_spacing=1.0, dtype=dtype) + spline.to(dtype=dtype) + combo = CombinedPotential(potentials=[spline, coulomb], smearing=1.0) + combo.to(dtype=dtype) + mypme = PMECalculator(combo, mesh_spacing=1.0) _ = torch.jit.script(mypme) def test_combined_potential_incompatability(): - coulomb1 = CoulombPotential(smearing=1.0, dtype=dtype) - coulomb2 = CoulombPotential(dtype=dtype) + coulomb1 = CoulombPotential(smearing=1.0) + coulomb2 = CoulombPotential() with pytest.raises( ValueError, match="Cannot combine direct \\(`smearing=None`\\) and range-separated \\(`smearing=float`\\) potentials.", ): - _ = CombinedPotential(potentials=[coulomb1, coulomb2], dtype=dtype) + _ = CombinedPotential(potentials=[coulomb1, coulomb2]) with pytest.raises( ValueError, match="You should specify a `smearing` when combining range-separated \\(`smearing=float`\\) potentials.", ): - _ = CombinedPotential(potentials=[coulomb1, coulomb1], dtype=dtype) + _ = CombinedPotential(potentials=[coulomb1, coulomb1]) with pytest.raises( ValueError, match="Cannot specify `smearing` when combining direct \\(`smearing=None`\\) potentials.", ): - _ = CombinedPotential( - potentials=[coulomb2, coulomb2], smearing=1.0, dtype=dtype - ) + _ = CombinedPotential(potentials=[coulomb2, coulomb2], smearing=1.0) def test_combined_potential_learnable_weights(): weights = torch.randn(2, dtype=dtype) - coulomb1 = CoulombPotential(smearing=2.0, dtype=dtype) - coulomb2 = CoulombPotential(smearing=1.0, dtype=dtype) + coulomb1 = CoulombPotential(smearing=2.0) + coulomb2 = CoulombPotential(smearing=1.0) combined = CombinedPotential( potentials=[coulomb1, coulomb2], smearing=1.0, - dtype=dtype, initial_weights=weights.clone(), learnable_weights=True, ) + combined.to(dtype=dtype) assert combined.weights.requires_grad # make a small optimization step @@ -587,17 +591,16 @@ def test_potential_device_dtype(potential_class, device, dtype): exponent = 2 if potential_class is InversePowerLawPotential: - potential = potential_class( - exponent=exponent, smearing=smearing, dtype=dtype, device=device - ) + potential = potential_class(exponent=exponent, smearing=smearing) + potential.to(device=device, dtype=dtype) elif potential_class is SplinePotential: x_grid = torch.linspace(0, 20, 100, device=device, dtype=dtype) y_grid = torch.exp(-(x_grid**2) * 0.5) - potential = potential_class( - r_grid=x_grid, y_grid=y_grid, reciprocal=False, dtype=dtype, device=device - ) + potential = potential_class(r_grid=x_grid, y_grid=y_grid, reciprocal=False) + potential.to(device=device, dtype=dtype) else: - potential = potential_class(smearing=smearing, dtype=dtype, device=device) + potential = potential_class(smearing=smearing) + potential.to(device=device, dtype=dtype) dists = torch.linspace(0.1, 10.0, 100, device=device, dtype=dtype) potential_lr = potential.lr_from_dist(dists) @@ -616,13 +619,15 @@ def test_inverserp_vs_spline(exponent, smearing): ks_sq_grad1 = ks_sq.clone().requires_grad_(True) ks_sq_grad2 = ks_sq.clone().requires_grad_(True) # Create InversePowerLawPotential - ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing, dtype=dtype) + ipl = InversePowerLawPotential(exponent=exponent, smearing=smearing) + ipl.to(dtype=dtype) ipl_fourier = ipl.lr_from_k_sq(ks_sq_grad1) # Create PotentialSpline r_grid = torch.logspace(-5, 2, 1000, dtype=dtype) y_grid = ipl.lr_from_dist(r_grid) - spline = SplinePotential(r_grid=r_grid, y_grid=y_grid, dtype=dtype) + spline = SplinePotential(r_grid=r_grid, y_grid=y_grid) + spline.to(dtype=dtype) spline_fourier = spline.lr_from_k_sq(ks_sq_grad2) # Test agreement between InversePowerLawPotential and SplinePotential diff --git a/tests/tuning/test_timer.py b/tests/tuning/test_timer.py index 4481d2fa..2100e877 100644 --- a/tests/tuning/test_timer.py +++ b/tests/tuning/test_timer.py @@ -41,7 +41,6 @@ def test_timer(): calculator = EwaldCalculator( potential=CoulombPotential(smearing=1.0), lr_wavelength=0.25, - dtype=DTYPE, ) timing_1 = TuningTimings( @@ -50,7 +49,6 @@ def test_timer(): positions=pos, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, - dtype=DTYPE, n_repeat=n_repeat_1, ) @@ -60,7 +58,6 @@ def test_timer(): positions=pos, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, - dtype=DTYPE, n_repeat=n_repeat_2, ) diff --git a/tests/tuning/test_tuning.py b/tests/tuning/test_tuning.py index cabea6d9..07cce237 100644 --- a/tests/tuning/test_tuning.py +++ b/tests/tuning/test_tuning.py @@ -10,7 +10,6 @@ P3MCalculator, PMECalculator, ) -from torchpme._utils import _get_device, _get_dtype from torchpme.tuning import tune_ewald, tune_p3m, tune_pme from torchpme.tuning.tuner import TunerBase @@ -23,9 +22,6 @@ def system(device=None, dtype=None): - device = _get_device(device) - dtype = _get_dtype(dtype) - charges = torch.ones((4, 1), dtype=dtype, device=device) cell = torch.eye(3, dtype=dtype, device=device) positions = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3)) @@ -50,8 +46,6 @@ def test_TunerBase_init(device, dtype): cutoff=DEFAULT_CUTOFF, calculator=1.0, exponent=1, - dtype=dtype, - device=device, ) @@ -89,19 +83,16 @@ def test_parameter_choose(device, dtype, calculator, tune, param_length, accurac neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, accuracy=accuracy, - dtype=dtype, - device=device, ) assert len(params) == param_length # Compute potential and compare against target value using default hypers calc = calculator( - potential=(CoulombPotential(smearing=smearing, dtype=dtype, device=device)), - dtype=dtype, - device=device, + potential=(CoulombPotential(smearing=smearing)), **params, ) + calc.to(device=device, dtype=dtype) potentials = calc.forward( positions=pos, charges=charges, @@ -212,9 +203,7 @@ def test_invalid_cell(tune): @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_invalid_dtype_cell(tune): charges, _, positions = system() - match = ( - r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)" - ) + match = r"type of `cell` \(torch.float64\) must be same as that of the `positions` class \(torch.float32\)" with pytest.raises(TypeError, match=match): tune( charges=charges,