From d8e675ef7f01c0e953cc024e631016f4b13bad02 Mon Sep 17 00:00:00 2001 From: Egor Rumiantsev <48020029+E-Rum@users.noreply.github.com> Date: Thu, 16 Jan 2025 16:05:03 +0100 Subject: [PATCH] Add Explicit `dtype` and `device` Support for Calculators and Ensure Compatibility with Potentials (#143) * Refactor parameter handling in calculators and potentials for improved dtype and device management * Updated docstrings and changelog, added an assertion to check for an instance of the potential, and resolved the TorchScript Potential/Calculator incompatibility. * Update changelog and add test for potential and calculator compatibility --- docs/extensions/versions_list.py | 1 - docs/src/references/changelog.rst | 7 ++ src/torchpme/calculators/calculator.py | 22 +++++- src/torchpme/calculators/ewald.py | 8 ++ src/torchpme/calculators/p3m.py | 8 ++ src/torchpme/calculators/pme.py | 16 +++- src/torchpme/potentials/combined.py | 9 +-- src/torchpme/potentials/coulomb.py | 10 +-- src/torchpme/potentials/inversepowerlaw.py | 6 +- src/torchpme/potentials/potential.py | 10 +-- src/torchpme/potentials/spline.py | 18 ++--- tests/calculators/test_workflow.py | 86 ++++++++++++++-------- tests/test_potentials.py | 3 +- 13 files changed, 133 insertions(+), 71 deletions(-) diff --git a/docs/extensions/versions_list.py b/docs/extensions/versions_list.py index 11f5432b..df03352e 100644 --- a/docs/extensions/versions_list.py +++ b/docs/extensions/versions_list.py @@ -43,7 +43,6 @@ def run(self): :margin: 0 0 0 0\n""" for group_i, (version_short, group) in enumerate(grouped_versions.items()): - if group_i < 3: generated_content += f""" .. grid-item:: diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst index d5b7960b..d1e86621 100644 --- a/docs/src/references/changelog.rst +++ b/docs/src/references/changelog.rst @@ -24,9 +24,16 @@ changelog `_ format. This project follows `Unreleased `_ ------------------------------------------------------- +Added +##### + +* Added ``dtype`` and ``device`` for ``Calculator`` classses + Fixed ##### +* Ensured consistency of ``dtype`` and ``device`` in the ``Potential`` and + ``Calculator`` classses * Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class * Fix inconsistent ``cutoff`` in neighbor list example * All calculators now check if the cell is zero if the potential is range-separated diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index 9a35802a..627283e5 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -26,6 +26,8 @@ class Calculator(torch.nn.Module): will come from a full (True) or half (False, default) neighbor list. :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and common values. + :param dtype: type used for the internal buffers and parameters + :param device: device used for the internal buffers and parameters """ def __init__( @@ -33,14 +35,28 @@ def __init__( potential: Potential, full_neighbor_list: bool = False, prefactor: float = 1.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ): super().__init__() - # TorchScript requires to initialize all attributes in __init__ - self._device = torch.device("cpu") - self._dtype = torch.float32 + assert isinstance(potential, Potential), ( + f"Potential must be an instance of Potential, got {type(potential)}" + ) + + self.device = "cpu" if device is None else device + self.dtype = torch.get_default_dtype() if dtype is None else dtype self.potential = potential + assert self.dtype == self.potential.dtype, ( + f"Potential and Calculator must have the same dtype, got {self.dtype} and " + f"{self.potential.dtype}" + ) + assert self.device == self.potential.device, ( + f"Potential and Calculator must have the same device, got {self.device} and " + f"{self.potential.device}" + ) + self.full_neighbor_list = full_neighbor_list self.prefactor = prefactor diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py index e8bffb5c..d009c4d1 100644 --- a/src/torchpme/calculators/ewald.py +++ b/src/torchpme/calculators/ewald.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from ..lib import generate_kvectors_for_ewald @@ -53,6 +55,8 @@ class EwaldCalculator(Calculator): :obj:`False`, a "half" neighbor list is expected. :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and common values. + :param dtype: type used for the internal buffers and parameters + :param device: device used for the internal buffers and parameters """ def __init__( @@ -61,11 +65,15 @@ def __init__( lr_wavelength: float, full_neighbor_list: bool = False, prefactor: float = 1.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ): super().__init__( potential=potential, full_neighbor_list=full_neighbor_list, prefactor=prefactor, + dtype=dtype, + device=device, ) if potential.smearing is None: raise ValueError( diff --git a/src/torchpme/calculators/p3m.py b/src/torchpme/calculators/p3m.py index 58826b08..76a63874 100644 --- a/src/torchpme/calculators/p3m.py +++ b/src/torchpme/calculators/p3m.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from ..lib.kspace_filter import P3MKSpaceFilter @@ -40,6 +42,8 @@ class P3MCalculator(PMECalculator): set to :py:obj:`False`, a "half" neighbor list is expected. :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and common values. + :param dtype: type used for the internal buffers and parameters + :param device: device used for the internal buffers and parameters For an **example** on the usage for any calculator refer to :ref:`userdoc-how-to`. """ @@ -51,6 +55,8 @@ def __init__( interpolation_nodes: int = 4, full_neighbor_list: bool = False, prefactor: float = 1.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ): self.mesh_spacing: float = mesh_spacing @@ -62,6 +68,8 @@ def __init__( potential=potential, full_neighbor_list=full_neighbor_list, prefactor=prefactor, + dtype=dtype, + device=device, ) self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter( diff --git a/src/torchpme/calculators/pme.py b/src/torchpme/calculators/pme.py index 93c207a8..0d6742cd 100644 --- a/src/torchpme/calculators/pme.py +++ b/src/torchpme/calculators/pme.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from torch import profiler @@ -45,6 +47,8 @@ class PMECalculator(Calculator): set to :obj:`False`, a "half" neighbor list is expected. :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and common values. + :param dtype: type used for the internal buffers and parameters + :param device: device used for the internal buffers and parameters """ def __init__( @@ -54,11 +58,15 @@ def __init__( interpolation_nodes: int = 4, full_neighbor_list: bool = False, prefactor: float = 1.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ): super().__init__( potential=potential, full_neighbor_list=full_neighbor_list, prefactor=prefactor, + dtype=dtype, + device=device, ) if potential.smearing is None: @@ -69,8 +77,8 @@ def __init__( self.mesh_spacing: float = mesh_spacing self.kspace_filter: KSpaceFilter = KSpaceFilter( - cell=torch.eye(3), - ns_mesh=torch.ones(3, dtype=int), + cell=torch.eye(3, dtype=self.dtype, device=self.device), + ns_mesh=torch.ones(3, dtype=int, device=self.device), kernel=self.potential, fft_norm="backward", ifft_norm="forward", @@ -79,8 +87,8 @@ def __init__( self.interpolation_nodes: int = interpolation_nodes self.mesh_interpolator: MeshInterpolator = MeshInterpolator( - cell=torch.eye(3), - ns_mesh=torch.ones(3, dtype=int), + cell=torch.eye(3, dtype=self.dtype, device=self.device), + ns_mesh=torch.ones(3, dtype=int, device=self.device), interpolation_nodes=self.interpolation_nodes, method="Lagrange", # convention for classic PME ) diff --git a/src/torchpme/potentials/combined.py b/src/torchpme/potentials/combined.py index 2c4e2612..d76a20c0 100644 --- a/src/torchpme/potentials/combined.py +++ b/src/torchpme/potentials/combined.py @@ -47,10 +47,7 @@ def __init__( dtype=dtype, device=device, ) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") + smearings = [pot.smearing for pot in potentials] if not all(smearings) and any(smearings): raise ValueError( @@ -76,7 +73,9 @@ def __init__( "The number of initial weights must match the number of potentials being combined" ) else: - initial_weights = torch.ones(len(potentials), dtype=dtype, device=device) + initial_weights = torch.ones( + len(potentials), dtype=self.dtype, device=self.device + ) # for torchscript self.potentials = torch.nn.ModuleList(potentials) if learnable_weights: diff --git a/src/torchpme/potentials/coulomb.py b/src/torchpme/potentials/coulomb.py index f121e38e..4cde5611 100644 --- a/src/torchpme/potentials/coulomb.py +++ b/src/torchpme/potentials/coulomb.py @@ -38,19 +38,17 @@ def __init__( device: Optional[torch.device] = None, ): super().__init__(smearing, exclusion_radius, dtype, device) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") # constants used in the forwward self.register_buffer( "_rsqrt2", - torch.rsqrt(torch.tensor(2.0, dtype=dtype, device=device)), + torch.rsqrt(torch.tensor(2.0, dtype=self.dtype, device=self.device)), ) self.register_buffer( "_sqrt_2_on_pi", - torch.sqrt(torch.tensor(2.0 / torch.pi, dtype=dtype, device=device)), + torch.sqrt( + torch.tensor(2.0 / torch.pi, dtype=self.dtype, device=self.device) + ), ) def from_dist(self, dist: torch.Tensor) -> torch.Tensor: diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index bd44236e..8b2449ca 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -53,15 +53,11 @@ def __init__( device: Optional[torch.device] = None, ): super().__init__(smearing, exclusion_radius, dtype, device) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") if exponent <= 0 or exponent > 3: raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3") self.register_buffer( - "exponent", torch.tensor(exponent, dtype=dtype, device=device) + "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device) ) @torch.jit.export diff --git a/src/torchpme/potentials/potential.py b/src/torchpme/potentials/potential.py index fa587896..b5e896ec 100644 --- a/src/torchpme/potentials/potential.py +++ b/src/torchpme/potentials/potential.py @@ -42,20 +42,18 @@ def __init__( device: Optional[torch.device] = None, ): super().__init__() - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") + self.dtype = torch.get_default_dtype() if dtype is None else dtype + self.device = "cpu" if device is None else device if smearing is not None: self.register_buffer( - "smearing", torch.tensor(smearing, device=device, dtype=dtype) + "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype) ) else: self.smearing = None if exclusion_radius is not None: self.register_buffer( "exclusion_radius", - torch.tensor(exclusion_radius, device=device, dtype=dtype), + torch.tensor(exclusion_radius, device=self.device, dtype=self.dtype), ) else: self.exclusion_radius = None diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py index a89120f4..f7a3ccd2 100644 --- a/src/torchpme/potentials/spline.py +++ b/src/torchpme/potentials/spline.py @@ -66,16 +66,12 @@ def __init__( dtype=dtype, device=device, ) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device("cpu") if len(y_grid) != len(r_grid): raise ValueError("Length of radial grid and value array mismatch.") - r_grid = r_grid.to(dtype=dtype, device=device) - y_grid = y_grid.to(dtype=dtype, device=device) + r_grid = r_grid.to(dtype=self.dtype, device=self.device) + y_grid = y_grid.to(dtype=self.dtype, device=self.device) if reciprocal: if torch.min(r_grid) <= 0.0: @@ -93,7 +89,7 @@ def __init__( else: k_grid = r_grid.clone() else: - k_grid = k_grid.to(dtype=dtype, device=device) + k_grid = k_grid.to(dtype=self.dtype, device=self.device) if yhat_grid is None: # computes automatically! @@ -104,7 +100,7 @@ def __init__( compute_second_derivatives(r_grid, y_grid), ) else: - yhat_grid = yhat_grid.to(dtype=dtype, device=device) + yhat_grid = yhat_grid.to(dtype=self.dtype, device=self.device) # the function is defined for k**2, so we define the grid accordingly if reciprocal: @@ -115,13 +111,15 @@ def __init__( self._krn_spline = CubicSpline(k_grid**2, yhat_grid) if y_at_zero is None: - self._y_at_zero = self._spline(torch.zeros(1, dtype=dtype, device=device)) + self._y_at_zero = self._spline( + torch.zeros(1, dtype=self.dtype, device=self.device) + ) else: self._y_at_zero = y_at_zero if yhat_at_zero is None: self._yhat_at_zero = self._krn_spline( - torch.zeros(1, dtype=dtype, device=device) + torch.zeros(1, dtype=self.dtype, device=self.device) ) else: self._yhat_at_zero = yhat_at_zero diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index 8d2fcee2..7858b395 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -17,9 +17,7 @@ PMECalculator, ) -AVAILABLE_DEVICES = [torch.device("cpu")] + torch.cuda.is_available() * [ - torch.device("cuda") -] +AVAILABLE_DEVICES = ["cpu"] + torch.cuda.is_available() * ["cuda"] MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) CHARGES_CSCL = torch.tensor([1.0, -1.0]) SMEARING = 0.1 @@ -27,6 +25,7 @@ MESH_SPACING = SMEARING / 4 +@pytest.mark.parametrize("device", AVAILABLE_DEVICES) @pytest.mark.parametrize( ("CalculatorClass", "params"), [ @@ -79,49 +78,47 @@ def cscl_system(self, device=None): neighbor_distances.to(device=device), ) - def test_smearing_non_positive(self, CalculatorClass, params): + def test_smearing_non_positive(self, CalculatorClass, params, device): params = params.copy() match = r"`smearing` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator, PMECalculator]: params["smearing"] = 0 with pytest.raises(ValueError, match=match): - CalculatorClass(**params) + CalculatorClass(**params, device=device) params["smearing"] = -0.1 with pytest.raises(ValueError, match=match): - CalculatorClass(**params) + CalculatorClass(**params, device=device) - def test_interpolation_order_error(self, CalculatorClass, params): + def test_interpolation_order_error(self, CalculatorClass, params, device): params = params.copy() if type(CalculatorClass) in [PMECalculator]: match = "Only `interpolation_nodes` from 1 to 5" params["interpolation_nodes"] = 10 with pytest.raises(ValueError, match=match): - CalculatorClass(**params) + CalculatorClass(**params, device=device) - def test_lr_wavelength_non_positive(self, CalculatorClass, params): + def test_lr_wavelength_non_positive(self, CalculatorClass, params, device): params = params.copy() match = r"`lr_wavelength` .* has to be positive" if type(CalculatorClass) in [EwaldCalculator]: params["lr_wavelength"] = 0 with pytest.raises(ValueError, match=match): - CalculatorClass(**params) + CalculatorClass(**params, device=device) params["lr_wavelength"] = -0.1 with pytest.raises(ValueError, match=match): - CalculatorClass(**params) + CalculatorClass(**params, device=device) - def test_dtype_device(self, CalculatorClass, params): + def test_dtype_device(self, CalculatorClass, params, device): """Test that the output dtype and device are the same as the input.""" - device = "cpu" dtype = torch.float64 - + params = params.copy() positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) charges = torch.ones((1, 2), dtype=dtype, device=device) cell = torch.eye(3, dtype=dtype, device=device) neighbor_indices = torch.tensor([[0, 0]], device=device) neighbor_distances = torch.tensor([0.1], device=device) - - calculator = CalculatorClass(**params) - + params["potential"].device = device + calculator = CalculatorClass(**params, device=device) potential = calculator.forward( charges=charges, cell=cell, @@ -138,41 +135,48 @@ def check_operation(self, calculator, device): descriptor = calculator.forward(*self.cscl_system(device)) assert type(descriptor) is torch.Tensor - @pytest.mark.parametrize("device", AVAILABLE_DEVICES) def test_operation_as_python(self, CalculatorClass, params, device): """Run `check_operation` as a normal python script""" - calculator = CalculatorClass(**params) + params = params.copy() + params["potential"].device = device + calculator = CalculatorClass(**params, device=device) self.check_operation(calculator=calculator, device=device) - @pytest.mark.parametrize("device", AVAILABLE_DEVICES) def test_operation_as_torch_script(self, CalculatorClass, params, device): """Run `check_operation` as a compiled torch script module.""" - calculator = CalculatorClass(**params) + params = params.copy() + params["potential"].device = device + calculator = CalculatorClass(**params, device=device) scripted = torch.jit.script(calculator) self.check_operation(calculator=scripted, device=device) - def test_save_load(self, CalculatorClass, params): - calculator = CalculatorClass(**params) + def test_save_load(self, CalculatorClass, params, device): + params = params.copy() + params["potential"].device = device + calculator = CalculatorClass(**params, device=device) scripted = torch.jit.script(calculator) with io.BytesIO() as buffer: torch.jit.save(scripted, buffer) buffer.seek(0) torch.jit.load(buffer) - def test_prefactor(self, CalculatorClass, params): + def test_prefactor(self, CalculatorClass, params, device): """Test if the prefactor is applied correctly.""" + params = params.copy() + params["potential"].device = device prefactor = 2.0 - calculator1 = CalculatorClass(**params) - calculator2 = CalculatorClass(**params, prefactor=prefactor) + calculator1 = CalculatorClass(**params, device=device) + calculator2 = CalculatorClass(**params, prefactor=prefactor, device=device) potentials1 = calculator1.forward(*self.cscl_system()) potentials2 = calculator2.forward(*self.cscl_system()) assert torch.allclose(potentials1 * prefactor, potentials2) - def test_not_nan(self, CalculatorClass, params): + def test_not_nan(self, CalculatorClass, params, device): """Make sure derivatives are not NaN.""" - device = "cpu" + params = params.copy() + params["potential"].device = device - calculator = CalculatorClass(**params) + calculator = CalculatorClass(**params, device=device) system = self.cscl_system(device) system[0].requires_grad = True system[1].requires_grad = True @@ -198,3 +202,27 @@ def test_not_nan(self, CalculatorClass, params): assert not torch.isnan( torch.autograd.grad(energy, system[2], retain_graph=True)[0] ).any() + + def test_dtype_and_device_incompatability(self, CalculatorClass, params, device): + """Test that the calculator raises an error if the dtype and device are incompatible.""" + params = params.copy() + params["potential"].device = device + params["potential"].dtype = torch.float64 + with pytest.raises(AssertionError, match=".*dtype.*"): + CalculatorClass(**params, dtype=torch.float32, device=device) + with pytest.raises(AssertionError, match=".*device.*"): + CalculatorClass( + **params, dtype=params["potential"].dtype, device=torch.device("meta") + ) + + def test_potential_and_calculator_incompatability( + self, CalculatorClass, params, device + ): + """Test that the calculator raises an error if the potential and calculator are incompatible.""" + params = params.copy() + params["potential"].device = device + params["potential"] = torch.jit.script(params["potential"]) + with pytest.raises( + AssertionError, match="Potential must be an instance of Potential, got.*" + ): + CalculatorClass(**params) diff --git a/tests/test_potentials.py b/tests/test_potentials.py index 87d590bd..87670b3b 100644 --- a/tests/test_potentials.py +++ b/tests/test_potentials.py @@ -526,8 +526,7 @@ def test_combined_potentials_jit(smearing): ) combo = CombinedPotential(potentials=[spline, coulomb], dtype=dtype, smearing=1.0) - jitcombo = torch.jit.script(combo) - mypme = PMECalculator(jitcombo, mesh_spacing=1.0) + mypme = PMECalculator(combo, mesh_spacing=1.0, dtype=dtype) _ = torch.jit.script(mypme)