Skip to content

Commit

Permalink
Helpstring update and tuning function return update
Browse files Browse the repository at this point in the history
  • Loading branch information
GardevoirX committed Jan 22, 2025
1 parent 3d00e44 commit 1cbffed
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 25 deletions.
7 changes: 6 additions & 1 deletion src/torchpme/tuning/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def tune_ewald(
:return: Tuple containing a float of the optimal smearing for the :class:
`CoulombPotential`, and a dictionary with the parameters for
:class:`EwaldCalculator`.
:class:`EwaldCalculator`, and the timing of this set of parameters.
Example
-------
Expand Down Expand Up @@ -109,6 +109,11 @@ def tune_ewald(
if any(err < accuracy for err in errs):
return smearing, params[timings.index(min(timings))]
# No parameter meets the requirement, return the one with the smallest error
warn(
f"No parameter meets the accuracy requirement.\n"
f"Returning the parameter with the smallest error, which is {min(errs)}.\n",
stacklevel=1,
)
return smearing, params[errs.index(min(errs))]


Expand Down
43 changes: 25 additions & 18 deletions src/torchpme/tuning/p3m.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from itertools import product
from typing import Optional
from warnings import warn

import torch

Expand Down Expand Up @@ -85,24 +86,24 @@ def tune_p3m(
r"""
Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`.
For the error formulas are given `here <https://doi.org/10.1063/1.477415>`_.
Note the difference notation between the parameters in the reference and ours:
For the error formulas are given `here <https://doi.org/10.1063/1.477415>`_. Note
the difference notation between the parameters in the reference and ours:
.. math::
\alpha = \left(\sqrt{2}\,\mathrm{smearing} \right)^{-1}
:param charges: torch.Tensor, atomic (pseudo-)charges
:param cell: torch.Tensor, periodic supercell for the system
:param positions: torch.Tensor, Cartesian coordinates of the particles within
the supercell.
:param positions: torch.Tensor, Cartesian coordinates of the particles within the
supercell.
:param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for
which the potential should be computed in real space.
:param neighbor_distances: torch.Tensor with the pair distances of the neighbors
for which the potential should be computed in real space.
:param cutoff: float, cutoff distance for the neighborlist
supported
:param exponent: :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` is
:param neighbor_distances: torch.Tensor with the pair distances of the neighbors for
which the potential should be computed in real space.
:param cutoff: float, cutoff distance for the neighborlist supported
:param exponent: :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1`
is
:param nodes_lo: Minimum number of interpolation nodes
:param nodes_hi: Maximum number of interpolation nodes
:param mesh_lo: Minimum number of mesh points per axis
Expand All @@ -112,8 +113,8 @@ def tune_p3m(
:return: Tuple containing a float of the optimal smearing for the :py:class:
`CoulombPotential`, a dictionary with the parameters for
:py:class:`PMECalculator` and a float of the optimal cutoff value for the
neighborlist computation.
:py:class:`P3MCalculator` and a float of the optimal cutoff value for the
neighborlist computation, and the timing of this set of parameters.
Example
-------
Expand All @@ -134,7 +135,7 @@ def tune_p3m(
>>> neighbor_indices = torch.tensor(
... [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]
... )
>>> smearing, parameter = tune_p3m(
>>> smearing, parameter, timing = tune_p3m(
... charges,
... cell,
... positions,
Expand Down Expand Up @@ -174,13 +175,19 @@ def tune_p3m(
smearing = tuner.estimate_smearing(accuracy)
errs, timings = tuner.tune(accuracy)

# There are multiple errors below the accuracy, return the one with the shortest
# calculation time. The timing of those parameters leading to an higher error
# than the accuracy are set to infinity
if any(err < accuracy for err in errs):
# There are multiple errors below the accuracy, return the one with the shortest
# calculation time. The timing of those parameters leading to an higher error
# than the accuracy are set to infinity
return smearing, params[timings.index(min(timings))]
# No parameter meets the requirement, return the one with the smallest error
return smearing, params[errs.index(min(errs))]
return smearing, params[timings.index(min(timings))], min(timings)
# No parameter meets the requirement, return the one with the smallest error, and
# throw a warning
warn(
f"No parameter meets the accuracy requirement.\n"
f"Returning the parameter with the smallest error, which is {min(errs)}.\n",
stacklevel=1,
)
return smearing, params[errs.index(min(errs))], timings[errs.index(min(errs))]


class P3MErrorBounds(TuningErrorBounds):
Expand Down
20 changes: 14 additions & 6 deletions src/torchpme/tuning/pme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from itertools import product
from typing import Optional
from warnings import warn

import torch

Expand Down Expand Up @@ -55,7 +56,8 @@ def tune_pme(
:return: Tuple containing a float of the optimal smearing for the :class:
`CoulombPotential`, a dictionary with the parameters for :class:`PMECalculator`
and a float of the optimal cutoff value for the neighborlist computation.
and a float of the optimal cutoff value for the neighborlist computation, and
the timing of this set of parameters.
Example
-------
Expand All @@ -76,7 +78,7 @@ def tune_pme(
>>> neighbor_indices = torch.tensor(
... [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]
... )
>>> smearing, parameter = tune_pme(
>>> smearing, parameter, timing = tune_pme(
... charges,
... cell,
... positions,
Expand Down Expand Up @@ -116,12 +118,18 @@ def tune_pme(
smearing = tuner.estimate_smearing(accuracy)
errs, timings = tuner.tune(accuracy)

# There are multiple errors below the accuracy, return the one with the shortest
# calculation time. The timing of those parameters leading to an higher error
# than the accuracy are set to infinity
if any(err < accuracy for err in errs):
# There are multiple errors below the accuracy, return the one with the shortest
# calculation time. The timing of those parameters leading to an higher error
# than the accuracy are set to infinity
return smearing, params[timings.index(min(timings))], min(timings)
# No parameter meets the requirement, return the one with the smallest error
# No parameter meets the requirement, return the one with the smallest error, and
# throw a warning
warn(
f"No parameter meets the accuracy requirement.\n"
f"Returning the parameter with the smallest error, which is {min(errs)}.\n",
stacklevel=1,
)
return smearing, params[errs.index(min(errs))], timings[errs.index(min(errs))]


Expand Down

0 comments on commit 1cbffed

Please sign in to comment.