Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Erase dtype and device #166

Merged
merged 10 commits into from
Feb 11, 2025
10 changes: 4 additions & 6 deletions docs/src/references/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
Added
#####

* Enhanced ``device`` and ``dtype`` consistency checks throughout the library
* Better documentation for for ``cell``, ``charges`` and ``positions`` parameters
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
``Calculator`` classes and tuning functions.

Fixed
#####
Removed
#######

* Fix ``device`` and ``dtype`` not being used in the init of the ``P3MCalculator``
* Remove ``device`` and ``dtype`` from init of ``Calculator``, ``Potential`` and
``Tuning`` classes

`Version 0.2.0 <https://github.com/lab-cosmo/torch-pme/releases/tag/v0.2.0>`_ - 2025-01-23
------------------------------------------------------------------------------------------
Expand Down
9 changes: 4 additions & 5 deletions examples/01-charges-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
dtype=dtype,
)

# %%
Expand Down Expand Up @@ -103,9 +102,9 @@
# will be used to *compute* the potential energy of the system.

calculator = torchpme.PMECalculator(
torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
torchpme.CoulombPotential(smearing=smearing), **pme_params
)

calculator.to(dtype=dtype)
# %%
#
# Single Charge Channel
Expand Down Expand Up @@ -207,9 +206,9 @@
# creating a new calculator with the metatensor interface.

calculator_metatensor = torchpme.metatensor.PMECalculator(
torchpme.CoulombPotential(smearing=smearing, dtype=dtype), dtype=dtype, **pme_params
torchpme.CoulombPotential(smearing=smearing), **pme_params
)

calculator_metatensor.to(dtype=dtype)
# %%
#
# Computation with metatensor involves using Metatensor's :class:`System
Expand Down
4 changes: 1 addition & 3 deletions examples/02-neighbor-lists-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
dtype=dtype,
)

# %%
Expand Down Expand Up @@ -195,8 +194,7 @@ def distances(
# compute the potential.

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing, dtype=dtype),
dtype=dtype,
potential=torchpme.CoulombPotential(smearing=smearing),
**pme_params,
)
potential = pme(
Expand Down
2 changes: 1 addition & 1 deletion examples/07-lode-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def __init__(self, potential: Potential, n_grid: int = 3):
)

# assumes a smooth exclusion region so sets the integration cutoff to half that
nodes, weights = get_full_grid(n_grid, potential.exclusion_radius.item() / 2)
nodes, weights = get_full_grid(n_grid, potential.exclusion_radius / 2)

# these are the "stencils" used to project the potential
# on an atom-centered basis. NB: weights might also be incorporated
Expand Down
14 changes: 8 additions & 6 deletions examples/08-combined-potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@
# evaluation, and so one has to set it also for the combined potential, even if it is
# not used explicitly in the evaluation of the combination.

pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing, dtype=dtype)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing, dtype=dtype)

potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing, dtype=dtype)
pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing)
pot_1 = pot_1.to(dtype=dtype)
pot_2 = pot_2.to(dtype=dtype)
Comment on lines +72 to +73
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment why you need this here.

potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing)
potential = potential.to(dtype=dtype)

# Note also that :class:`CombinedPotential` can be used with any combination of
# potentials, as long they are all either direct or range separated. For instance, one
Expand Down Expand Up @@ -156,9 +158,9 @@
# much bigger system.

calculator = EwaldCalculator(
potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A, dtype=dtype
potential=potential, lr_wavelength=lr_wavelength, prefactor=eV_A
)

calculator.to(dtype=dtype)

# %%
#
Expand Down
21 changes: 4 additions & 17 deletions examples/10-tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,10 @@
pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4}

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(smearing=smearing, device=device, dtype=dtype),
device=device,
dtype=dtype,
potential=torchpme.CoulombPotential(smearing=smearing),
**pme_params, # type: ignore[arg-type]
)

pme.to(device=device, dtype=dtype)
# %%
# Run the calculator
# ~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -170,8 +168,6 @@
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
run_backward=True,
device=device,
dtype=dtype,
)
estimated_timing = timings(pme)

Expand Down Expand Up @@ -220,14 +216,11 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
)

pme = torchpme.PMECalculator(
potential=torchpme.CoulombPotential(
smearing=smearing, device=device, dtype=dtype
),
potential=torchpme.CoulombPotential(smearing=smearing),
mesh_spacing=mesh_spacing,
interpolation_nodes=interpolation_nodes,
device=device,
dtype=dtype,
)
pme.to(device=device, dtype=dtype)
potential = pme(
charges=charges,
cell=cell,
Expand All @@ -247,8 +240,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
run_backward=True,
n_warmup=1,
n_repeat=4,
device=device,
dtype=dtype,
)
estimated_timing = timings(pme)
return madelung, estimated_timing
Expand Down Expand Up @@ -457,8 +448,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
cutoff=5.0,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
device=device,
dtype=dtype,
)

print(
Expand Down Expand Up @@ -492,8 +481,6 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes, device,
cutoff=cutoff,
neighbor_indices=filter_indices,
neighbor_distances=filter_distances,
device=device,
dtype=dtype,
)
timings_grid.append(timing)

Expand Down
9 changes: 4 additions & 5 deletions examples/basic-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@
# contains all the necessary functions (such as those defining the short-range and
# long-range splits) for this potential and makes them useable in the rest of the code.

potential = CoulombPotential(smearing=smearing, device=device, dtype=dtype)
potential = CoulombPotential(smearing=smearing)
potential.to(device=device, dtype=dtype)

# %%
#
Expand Down Expand Up @@ -193,10 +194,8 @@
# Since our structure is relatively small, we use the :class:`EwaldCalculator`.
# We start by the initialization of the class.

calculator = EwaldCalculator(
potential=potential, lr_wavelength=lr_wavelength, device=device, dtype=dtype
)

calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength)
calculator.to(device=device, dtype=dtype)
# %%
#
# Compute Energy
Expand Down
52 changes: 13 additions & 39 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,18 @@
from typing import Optional, Union
from typing import Union

import torch


def _get_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
return torch.get_default_dtype() if dtype is None else dtype


def _get_device(device: Union[None, str, torch.device]) -> torch.device:
new_device = torch.get_default_device() if device is None else torch.device(device)

# Add default index of 0 to a cuda device to avoid errors when comparing with
# devices from tensors
if new_device.type == "cuda" and new_device.index is None:
new_device = torch.device("cuda:0")

return new_device


def _validate_parameters(
charges: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
dtype: torch.dtype,
device: torch.device,
) -> None:
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as the class "
f"type ({dtype})"
)

if positions.device != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as the class "
f"device ({device})"
)
dtype = positions.dtype
device = positions.device

# check shape, dtype and device of positions
num_atoms = len(positions)
Expand All @@ -55,14 +29,14 @@ def _validate_parameters(
f"{list(cell.shape)}"
)

if cell.dtype != positions.dtype:
if cell.dtype != dtype:
raise TypeError(
f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
f"type of `cell` ({cell.dtype}) must be same as that of the `positions` class ({dtype})"
)

if cell.device != device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as the class ({device})"
f"device of `cell` ({cell.device}) must be same as that of the `positions` class ({device})"
)

if smearing is not None and torch.equal(
Expand All @@ -89,14 +63,14 @@ def _validate_parameters(
f"{len(positions)} atoms"
)

if charges.dtype != positions.dtype:
if charges.dtype != dtype:
raise TypeError(
f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
f"type of `charges` ({charges.dtype}) must be same as that of the `positions` class ({dtype})"
)

if charges.device != device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as the class "
f"device of `charges` ({charges.device}) must be same as that of the `positions` class "
f"({device})"
)

Expand All @@ -111,7 +85,7 @@ def _validate_parameters(
if neighbor_indices.device != device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as the class ({device})"
f"same as that of the `positions` class ({device})"
)

if neighbor_distances.shape != neighbor_indices[:, 0].shape:
Expand All @@ -124,11 +98,11 @@ def _validate_parameters(
if neighbor_distances.device != device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as the class ({device})"
f"same as that of the `positions` class ({device})"
)

if neighbor_distances.dtype != positions.dtype:
if neighbor_distances.dtype != dtype:
raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
f"as the class ({dtype})"
f"as that of the `positions` class ({dtype})"
)
26 changes: 1 addition & 25 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional, Union

import torch
from torch import profiler

from .._utils import _get_device, _get_dtype, _validate_parameters
from .._utils import _validate_parameters
from ..potentials import Potential


Expand All @@ -27,17 +25,13 @@ class Calculator(torch.nn.Module):
will come from a full (True) or half (False, default) neighbor list.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
self,
potential: Potential,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__()

Expand All @@ -46,24 +40,8 @@ def __init__(
f"Potential must be an instance of Potential, got {type(potential)}"
)

self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

if self.dtype != potential.dtype:
raise TypeError(
f"dtype of `potential` ({potential.dtype}) must be same as of "
f"`calculator` ({self.dtype})"
)

if self.device != potential.device:
raise ValueError(
f"device of `potential` ({potential.device}) must be same as of "
f"`calculator` ({self.device})"
)

self.potential = potential
self.full_neighbor_list = full_neighbor_list

self.prefactor = prefactor

def _compute_rspace(
Expand Down Expand Up @@ -164,8 +142,6 @@ def forward(
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
dtype=self.dtype,
device=self.device,
)

# Compute short-range (SR) part using a real space sum
Expand Down
Loading