Skip to content

Commit

Permalink
Remove unused dtype and device parameters from Calculator and potenti…
Browse files Browse the repository at this point in the history
…al classes
  • Loading branch information
E-Rum committed Feb 4, 2025
1 parent c9111df commit 0cd2e80
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 30 deletions.
2 changes: 0 additions & 2 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class Calculator(torch.nn.Module):
will come from a full (True) or half (False, default) neighbor list.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
Expand Down
6 changes: 0 additions & 6 deletions src/torchpme/calculators/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ class EwaldCalculator(Calculator):
:obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
Expand All @@ -65,15 +63,11 @@ def __init__(
lr_wavelength: float,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
dtype=dtype,
device=device,
)
if potential.smearing is None:
raise ValueError(
Expand Down
6 changes: 0 additions & 6 deletions src/torchpme/calculators/p3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ class P3MCalculator(PMECalculator):
set to :py:obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
For an **example** on the usage for any calculator refer to :ref:`userdoc-how-to`.
"""
Expand All @@ -55,8 +53,6 @@ def __init__(
interpolation_nodes: int = 4,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
self.mesh_spacing: float = mesh_spacing

Expand All @@ -68,8 +64,6 @@ def __init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
dtype=dtype,
device=device,
)

self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(
Expand Down
17 changes: 7 additions & 10 deletions src/torchpme/calculators/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class PMECalculator(Calculator):
set to :obj:`False`, a "half" neighbor list is expected.
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
common values.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
Expand All @@ -58,15 +56,11 @@ def __init__(
interpolation_nodes: int = 4,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
dtype: Optional[torch.dtype] = None,
device: Union[None, str, torch.device] = None,
):
super().__init__(
potential=potential,
full_neighbor_list=full_neighbor_list,
prefactor=prefactor,
dtype=dtype,
device=device,
)

if potential.smearing is None:
Expand All @@ -76,9 +70,12 @@ def __init__(

self.mesh_spacing: float = mesh_spacing

self.register_buffer("cell", torch.eye(3))
ns_mesh = torch.ones(3, dtype=int, device=self.cell.device)

self.kspace_filter: KSpaceFilter = KSpaceFilter(
cell=torch.eye(3, dtype=self.dtype, device=self.device),
ns_mesh=torch.ones(3, dtype=int, device=self.device),
cell=self.cell,
ns_mesh=ns_mesh,
kernel=self.potential,
fft_norm="backward",
ifft_norm="forward",
Expand All @@ -87,8 +84,8 @@ def __init__(
self.interpolation_nodes: int = interpolation_nodes

self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
cell=torch.eye(3, dtype=self.dtype, device=self.device),
ns_mesh=torch.ones(3, dtype=int, device=self.device),
cell=self.cell,
ns_mesh=ns_mesh,
interpolation_nodes=self.interpolation_nodes,
method="Lagrange", # convention for classic PME
)
Expand Down
2 changes: 0 additions & 2 deletions src/torchpme/potentials/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class CombinedPotential(Potential):
:param exclusion_radius: A length scale that defines a *local environment* within
which the potential should be smoothly zeroed out, as it will be described by a
separate model.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
Expand Down
2 changes: 0 additions & 2 deletions src/torchpme/potentials/coulomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ class CoulombPotential(Potential):
:param exclusion_radius: A length scale that defines a *local environment* within
which the potential should be smoothly zeroed out, as it will be described by a
separate model.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
Expand Down
2 changes: 0 additions & 2 deletions src/torchpme/potentials/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ class SplinePotential(Potential):
:param exclusion_radius: A length scale that defines a *local environment* within
which the potential should be smoothly zeroed out, as it will be described by a
separate model.
:param dtype: type used for the internal buffers and parameters
:param device: device used for the internal buffers and parameters
"""

def __init__(
Expand Down

0 comments on commit 0cd2e80

Please sign in to comment.