diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst index 3e405a23..45b37452 100644 --- a/docs/src/references/changelog.rst +++ b/docs/src/references/changelog.rst @@ -29,12 +29,16 @@ Added * Added a PyTorch implementation of the exponential integral function * Added ``dtype`` and ``device`` for ``Calculator`` classses +* Added an example on the tuning scheme and usage, and how to optimize the ``cutoff`` Changed ####### * Removed ``utils`` module. ``utils.tuning`` and ``utils.prefactor`` are now in the root of the package. ``utils.splines`` is now in the ``lib`` module. +* The tuning now uses a grid-search based scheme, instead of a gradient based scheme. +* The tuning functions no longer takes the ``cutoff`` parameter, and thus does not + support a built-in NL calculation. Fixed ##### diff --git a/docs/src/references/index.rst b/docs/src/references/index.rst index c8f2112d..01470eac 100644 --- a/docs/src/references/index.rst +++ b/docs/src/references/index.rst @@ -22,7 +22,7 @@ refer to the :ref:`userdoc-how-to` section. potentials/index calculators/index - tuning + tuning/index prefactors metatensor lib/index diff --git a/docs/src/references/tuning.rst b/docs/src/references/tuning.rst deleted file mode 100644 index fdb44d5a..00000000 --- a/docs/src/references/tuning.rst +++ /dev/null @@ -1,22 +0,0 @@ -Tuning -###### - -The choice of parameters like the neighborlist ``cutoff``, the ``smearing`` or the -``lr_wavelength``/``mesh_spacing`` has a large influence one the accuracy of the -calculation. To help find the parameters that meet the accuracy requirements, this -module offers tuning methods for the calculators. - -The scheme behind all tuning functions is a gradient-based optimization, which tries to -find the minimal of the error estimation formula and stops after the error is smaller -than the given accuracy. Because these methods are gradient-based, be sure to pay -attention to the ``learning_rate`` and ``max_steps`` parameter. A good choice of these -two parameters can enhance the optimization speed and performance. - -.. autoclass:: torchpme.tuning.tune_ewald - :members: - -.. autoclass:: torchpme.tuning.tune_pme - :members: - -.. autoclass:: torchpme.tuning.tune_p3m - :members: diff --git a/docs/src/references/tuning/base_classes.rst b/docs/src/references/tuning/base_classes.rst new file mode 100644 index 00000000..0d5f202a --- /dev/null +++ b/docs/src/references/tuning/base_classes.rst @@ -0,0 +1,52 @@ +Base Classes +############ + +Current scheme behind all tuning functions is grid-searching based, focusing on the Fourier +space parameters like ``lr_wavelength``, ``mesh_spacing`` and ``interpolation_nodes``. +For real space parameter ``cutoff``, it is treated as a hyperparameter here, which +should be manually specified by the user. The parameter ``smearing`` is determined by +the real space error formula and is set to achieve a real space error of +``desired_accuracy / 4``. + +The Fourier space parameters are all discrete, so it's convenient to do the grid-search. +Default searching-ranges are provided for those parameters. For ``lr_wavelength``, the +values are chosen to be with a minimum of 1 and a maximum of 13 mesh points in each +spatial direction ``(x, y, z)``. For ``mesh_spacing``, the values are set to have +minimally 2 and maximally 7 mesh points in each spatial direction, for both the P3M and +PME method. The values of ``interpolation_nodes`` are the same as those supported in +:class:`torchpme.lib.MeshInterpolator`. + +In the grid-searching, all possible parameter combinations are evaluated. The error +associated with the parameter is estimated by the error formulas implemented in the +subclasses of :class:`torchpme.tuning.tuner.TuningErrorBounds`. Parameter with +the error within the desired accuracy are benchmarked for computational time by +:class:`torchpme.tuning.tuner.TuningTimings` The timing of the other parameters are +not tested and set to infinity. + +The return of these tuning functions contains the ``smearing`` and a dictionary, in +which there is parameter for the Fourier space. The parameter is that of the desired +accuracy and the shortest timing. The parameter of the smallest error will be returned +in the case that no parameter can fulfill the accuracy requirement. + + +.. autoclass:: torchpme.tuning.tuner.TunerBase + :members: + +.. autoclass:: torchpme.tuning.tuner.GridSearchTuner + :members: + +.. autoclass:: torchpme.tuning.tuner.TuningTimings + :members: + +.. autoclass:: torchpme.tuning.tuner.TuningErrorBounds + :members: + +Examples using Tuning Classes +----------------------------- + +.. minigallery:: + + torchpme.tuning.tuner.TunerBase + torchpme.tuning.tuner.GridSearchTuner + torchpme.tuning.tuner.TuningTimings + torchpme.tuning.tuner.TuningErrorBounds diff --git a/docs/src/references/tuning/index.rst b/docs/src/references/tuning/index.rst new file mode 100644 index 00000000..40f29302 --- /dev/null +++ b/docs/src/references/tuning/index.rst @@ -0,0 +1,24 @@ +Tuning +###### + +The choice of parameters like the neighborlist ``cutoff``, the ``smearing`` or the +``lr_wavelength``/``mesh_spacing`` has a large influence one the accuracy of the +calculation. To help find the parameters that meet the accuracy requirements, this +module offers tuning methods for the calculators. + +For usual tuning procedures we provide simple functions like +:func:`torchpme.tuning.tune_ewald` that returns for a given system the optimal +parameters for the Ewald summation. For more complex tuning procedures, we provide +classes like :class:`torchpme.tuning.ewald.EwaldErrorBounds` that can be used to +implement custom tuning procedures. + +.. important:: + + Current tuning methods are only implemented for the :class:`Coulomb potential + `. + +.. toctree:: + :maxdepth: 1 + :glob: + + ./* diff --git a/docs/src/references/tuning/tune_ewald.rst b/docs/src/references/tuning/tune_ewald.rst new file mode 100644 index 00000000..120a0d25 --- /dev/null +++ b/docs/src/references/tuning/tune_ewald.rst @@ -0,0 +1,24 @@ +Tune Ewald +########## + +The tuning is based on the following error formulas: + +.. math:: + \Delta F_\mathrm{real} + \approx \frac{Q^2}{\sqrt{N}} + \frac{2}{\sqrt{r_{\text{cutoff}} V}} + e^{-r_{\text{cutoff}}^2 / 2 \sigma^2} + +.. math:: + \Delta F_\mathrm{Fourier}^\mathrm{Ewald} + \approx \frac{Q^2}{\sqrt{N}} + \frac{\sqrt{2} / \sigma}{\pi\sqrt{2 V / h}} e^{-2\pi^2 \sigma^2 / h ^ 2} + +where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared +charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the +simulation box and :math:`h^2` is the long range wavelength. + +.. autofunction:: torchpme.tuning.tune_ewald + +.. autoclass:: torchpme.tuning.ewald.EwaldErrorBounds + :members: diff --git a/docs/src/references/tuning/tune_p3m.rst b/docs/src/references/tuning/tune_p3m.rst new file mode 100644 index 00000000..93e7c05b --- /dev/null +++ b/docs/src/references/tuning/tune_p3m.rst @@ -0,0 +1,27 @@ +Tune P3M +######### + +The tuning is based on the following error formulas: + +.. math:: + \Delta F_\mathrm{real} + \approx \frac{Q^2}{\sqrt{N}} + \frac{2}{\sqrt{r_{\text{cutoff}} V}} + e^{-r_{\text{cutoff}}^2 / 2 \sigma^2} + +.. math:: + \Delta F_\mathrm{Fourier}^\mathrm{P3M} + \approx \frac{Q^2}{L^2}(\frac{\sqrt{2}H}{\sigma})^p + \sqrt{\frac{\sqrt{2}L}{N\sigma} + \sqrt{2\pi}\sum_{m=0}^{p-1}a_m^{(p)}(\frac{\sqrt{2}H}{\sigma})^{2m}} + +where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared +charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the +simulation box, :math:`p` is the order of the interpolation scheme, :math:`H` is the spacing of mesh +points and :math:`a_m^{(p)}` is an expansion coefficient. + + +.. autofunction:: torchpme.tuning.tune_p3m + +.. autoclass:: torchpme.tuning.p3m.P3MErrorBounds + :members: diff --git a/docs/src/references/tuning/tune_pme.rst b/docs/src/references/tuning/tune_pme.rst new file mode 100644 index 00000000..aa4a1ea2 --- /dev/null +++ b/docs/src/references/tuning/tune_pme.rst @@ -0,0 +1,27 @@ +Tune PME +######### + +The tuning is based on the following error formulas: + +.. math:: + \Delta F_\mathrm{real} + \approx \frac{Q^2}{\sqrt{N}} + \frac{2}{\sqrt{r_{\text{cutoff}} V}} + e^{-r_{\text{cutoff}}^2 / 2 \sigma^2} + +.. math:: + \Delta F_\mathrm{Fourier}^\mathrm{PME} + \approx 2\pi^{1/4}\sqrt{\frac{3\sqrt{2} / \sigma}{N(2p+3)}} + \frac{Q^2}{L^2}\frac{(\sqrt{2}H/\sigma)^{p+1}}{(p+1)!} \times + \exp{\frac{(p+1)[\log{(p+1)} - \log 2 - 1]}{2}} \left< \phi_p^2 \right> ^{1/2} + +where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared +charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the +simulation box, :math:`p` is the order of the interpolation scheme, :math:`H` is the spacing of mesh +points, and :math:`\phi_p^2 = H^{-(p+1)}\prod_{s\in S_H^{(p)}}(x - s)`, in which :math:`S_H^{(p)}` is +the :math:`p+1` mesh points closest to the point :math:`x`. + +.. autofunction:: torchpme.tuning.tune_pme + +.. autoclass:: torchpme.tuning.pme.PMEErrorBounds + :members: diff --git a/examples/1-charges-example.py b/examples/01-charges-example.py similarity index 94% rename from examples/1-charges-example.py rename to examples/01-charges-example.py index 1c112ce4..92d82949 100644 --- a/examples/1-charges-example.py +++ b/examples/01-charges-example.py @@ -37,6 +37,7 @@ from metatensor.torch.atomistic import NeighborListOptions, System import torchpme +from torchpme.tuning import tune_pme # %% # @@ -44,6 +45,7 @@ symbols = ("Cs", "Cl") types = torch.tensor([55, 17]) +charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64) cell = torch.eye(3, dtype=torch.float64) pbc = torch.tensor([True, True, True]) @@ -55,8 +57,21 @@ # The ``sum_squared_charges`` is equal to ``2.0`` becaue each atom either has a charge # of 1 or -1 in units of elementary charges. -smearing, pme_params, cutoff = torchpme.tuning.tune_pme( - sum_squared_charges=2.0, cell=cell, positions=positions +cutoff = 4.4 +nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False) +neighbor_indices, neighbor_distances = nl.compute( + points=positions.to(dtype=torch.float64, device="cpu"), + box=cell.to(dtype=torch.float64, device="cpu"), + periodic=True, + quantities="Pd", +) +smearing, pme_params, _ = tune_pme( + charges=charges, + cell=cell, + positions=positions, + cutoff=cutoff, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, ) # %% diff --git a/examples/2-neighbor-lists-usage.py b/examples/02-neighbor-lists-usage.py similarity index 93% rename from examples/2-neighbor-lists-usage.py rename to examples/02-neighbor-lists-usage.py index 68a7cf53..322e6a69 100644 --- a/examples/2-neighbor-lists-usage.py +++ b/examples/02-neighbor-lists-usage.py @@ -46,6 +46,7 @@ import vesin.torch import torchpme +from torchpme.tuning import tune_pme # %% # @@ -92,9 +93,22 @@ cell = torch.from_numpy(atoms.cell.array) sum_squared_charges = float(torch.sum(charges**2)) +cutoff = 4.4 +nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False) +neighbor_indices, neighbor_distances = nl.compute( + points=positions.to(dtype=torch.float64, device="cpu"), + box=cell.to(dtype=torch.float64, device="cpu"), + periodic=True, + quantities="Pd", +) -smearing, pme_params, cutoff = torchpme.tuning.tune_pme( - sum_squared_charges=sum_squared_charges, cell=cell, positions=positions +smearing, pme_params, _ = tune_pme( + charges=charges, + cell=cell, + positions=positions, + cutoff=cutoff, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, ) # %% diff --git a/examples/3-mesh-demo.py b/examples/03-mesh-demo.py similarity index 100% rename from examples/3-mesh-demo.py rename to examples/03-mesh-demo.py diff --git a/examples/4-kspace-demo.py b/examples/04-kspace-demo.py similarity index 100% rename from examples/4-kspace-demo.py rename to examples/04-kspace-demo.py diff --git a/examples/5-autograd-demo.py b/examples/05-autograd-demo.py similarity index 98% rename from examples/5-autograd-demo.py rename to examples/05-autograd-demo.py index 63ceff11..b44416be 100644 --- a/examples/5-autograd-demo.py +++ b/examples/05-autograd-demo.py @@ -17,6 +17,8 @@ exercise to the reader. """ +# %% + from time import time import ase @@ -477,10 +479,11 @@ def forward(self, positions, cell, charges): ) # %% -# We can also time the difference in execution +# We can also evaluate the difference in execution # time between the Pytorch and scripted versions of the # module (depending on the system, the relative efficiency -# of the two evaluations could go either way!) +# of the two evaluations could go either way, as this is +# a too small system to make a difference!) duration = 0.0 for _i in range(20): @@ -513,5 +516,3 @@ def forward(self, positions, cell, charges): # %% print(f"Evaluation time:\nPytorch: {time_python}ms\nJitted: {time_jit}ms") - -# %% diff --git a/examples/6-splined-potential.py b/examples/06-splined-potential.py similarity index 100% rename from examples/6-splined-potential.py rename to examples/06-splined-potential.py diff --git a/examples/7-lode-demo.py b/examples/07-lode-demo.py similarity index 100% rename from examples/7-lode-demo.py rename to examples/07-lode-demo.py diff --git a/examples/8-combined-potential.py b/examples/08-combined-potential.py similarity index 100% rename from examples/8-combined-potential.py rename to examples/08-combined-potential.py diff --git a/examples/9-atomistic-model.py b/examples/09-atomistic-model.py similarity index 100% rename from examples/9-atomistic-model.py rename to examples/09-atomistic-model.py diff --git a/examples/10-tuning.py b/examples/10-tuning.py new file mode 100644 index 00000000..c2d61881 --- /dev/null +++ b/examples/10-tuning.py @@ -0,0 +1,490 @@ +r""" +Parameter tuning for range-separated models +=========================================== + +.. currentmodule:: torchpme + +:Authors: Michele Ceriotti `@ceriottm `_ + +Metods to compute efficiently a long-range potential :math:`v(r)` +usually rely on partitioning it into a short-range part, evaluated +as a sum over neighbor pairs, and a long-range part evaluated +in reciprocal space + +.. math:: + + v(r)= v_{\mathrm{SR}}(r) + v_{\mathrm{LR}}(r) + +The overall cost depend on the balance of multiple factors, that +we summarize here briefly to explain how the cost of evaluating +:math:`v(r)` can be minimized, either manually or automatically. +""" + +# %% +# Import modules + +import ase +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import torch +import vesin.torch as vesin + +import torchpme +from torchpme.tuning.pme import PMEErrorBounds, tune_pme +from torchpme.tuning.tuner import TuningTimings + +device = "cpu" +dtype = torch.float64 +rng = torch.Generator() +rng.manual_seed(42) + +# get_ipython().run_line_magic("matplotlib", "inline") # type: ignore # noqa + +# %% +# Set up a test system, a supercell containing atoms with a NaCl structure + +madelung_ref = 1.7475645946 +structure = ase.Atoms( + positions=[ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ], + cell=[2, 2, 2], + symbols="NaClClNaClNaNaCl", +) +structure = structure.repeat([2, 2, 2]) +num_formula_units = len(structure) // 2 + +# Uncomment these to add a displacement (energy won't match the Madelung constant) +# displacement = torch.normal( +# mean=0.0, std=2.5e-1, size=(len(structure), 3), generator=rng +# ) +# structure.positions += displacement.numpy() + +positions = torch.from_numpy(structure.positions).to(device=device, dtype=dtype) +cell = torch.from_numpy(structure.cell.array).to(device=device, dtype=dtype) + +charges = torch.tensor( + [[1.0], [-1.0], [-1.0], [1.0], [-1.0], [1.0], [1.0], [-1.0]] + * (len(structure) // 8), + dtype=dtype, + device=device, +).reshape(-1, 1) + +# Uncomment these to randomize charges (energy won't match the Madelung constant) +# charges += torch.normal(mean=0.0, std=1e-1, size=(len(charges), 1), generator=rng) + +# %% +# +# We also need to evaluate the neighbor list; this is usually pre-computed +# by the code that calls `torch-pme`, and entails the first key parameter: +# the cutoff used to compute the real-space potential :math:`v_\mathrm{SR}(r)` + + +max_cutoff = 16.0 + +# use `vesin` +nl = vesin.NeighborList(cutoff=max_cutoff, full_list=False) +i, j, S, d = nl.compute(points=positions, box=cell, periodic=True, quantities="ijSd") +neighbor_indices = torch.stack([i, j], dim=1) +neighbor_shifts = S +neighbor_distances = d + + +# %% +# Demonstrate errors and timings for PME +# -------------------------------------- +# +# To set up a PME calculation, we need to define its basic parameters and +# setup a few preliminary quantities. +# + +# %% +# +# The PME calculator has a few further parameters: ``smearing``, that determines +# aggressive is the smoothing of the point charges. This makes the reciprocal-space +# part easier to compute, but makes :math:`v_\mathrm{SR}(r)` decay more slowly, +# and error that we shall investigate further later on. +# The mesh parameters involve both the spacing and the order of the interpolation +# used. Note that here we use :class:`CoulombPotential`, that computes a simple +# :math:`1/r` electrostatic interaction. + +smearing = 1.0 +pme_params = {"mesh_spacing": 1.0, "interpolation_nodes": 4} + +pme = torchpme.PMECalculator( + potential=torchpme.CoulombPotential(smearing=smearing), + **pme_params, # type: ignore[arg-type] +) + +# %% +# Run the calculator +# ~~~~~~~~~~~~~~~~~~ +# +# We combine the structure data and the neighbor list information to +# compute the potential at the particle positions, and then the +# energy + +potential = pme( + charges=charges, + cell=cell, + positions=positions, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, +) + +energy = charges.T @ potential +madelung = (-energy / num_formula_units).flatten().item() + +# %% +# Compute error bounds (and timings) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Here we calculate the potential energy of the system, and compare it with the +# madelung constant to calculate the error. This is the actual error. Then we use +# the :class:`torchpme.tuning.pme.PMEErrorBounds` to calculate the error bound for +# PME. +# Error bounds are computed explicitly for a target structure +error_bounds = PMEErrorBounds(charges, cell, positions) + +estimated_error = error_bounds( + cutoff=max_cutoff, smearing=smearing, **pme_params +).item() + +# %% +# ... and a similar class can be used to estimate the timings, that are assessed +# based on a calculator (that should be initialized with the same parameters) +timings = TuningTimings( + charges, + cell, + positions, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, + run_backward=True, +) +estimated_timing = timings(pme) + +# %% +# The error bound is estimated for the force acting on atoms, and is +# expressed in force units - hence, the comparison with the Madelung constant +# error can only be qualitative. + +print( + f""" +Computed madelung constant: {madelung} +Actual error: {madelung - madelung_ref} +Estimated error: {estimated_error} +Timing: {estimated_timing} seconds +""" +) + +# %% +# Optimizing the parameters of PME +# -------------------------------- +# +# There are many parameters that enter the implementation +# of a range-separated calculator like PME, and it is necessary +# to optimize them to obtain the best possible accuracy/cost tradeoff. +# In most practical use cases, the cutoff is dictated by the external +# calculator and is treated as a fixed parameter. In cases where +# performance is critical, one may want to optimize this separately, +# which can be achieved easily with a grid or binary search. +# +# We can set up easily a brute-force evaluation of the error as a +# function of these parameters, and use it to guide the design of +# a more sophisticated optimization protocol. + + +def filter_neighbors(cutoff, neighbor_indices, neighbor_distances): + assert cutoff <= max_cutoff + + filter_idx = torch.where(neighbor_distances <= cutoff) + + return neighbor_indices[filter_idx], neighbor_distances[filter_idx] + + +def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): + filter_indices, filter_distances = filter_neighbors( + cutoff, neighbor_indices, neighbor_distances + ) + + pme = torchpme.PMECalculator( + potential=torchpme.CoulombPotential(smearing=smearing), + mesh_spacing=mesh_spacing, + interpolation_nodes=interpolation_nodes, + ) + potential = pme( + charges=charges, + cell=cell, + positions=positions, + neighbor_indices=filter_indices, + neighbor_distances=filter_distances, + ) + energy = charges.T @ potential + madelung = (-energy / num_formula_units).flatten().item() + + timings = TuningTimings( + charges, + cell, + positions, + neighbor_indices=filter_indices, + neighbor_distances=filter_distances, + run_backward=True, + n_warmup=1, + n_repeat=4, + ) + estimated_timing = timings(pme) + return madelung, estimated_timing + + +smearing_grid = torch.logspace(-1, 0.5, 8) +spacing_grid = torch.logspace(-1, 0.5, 9) +results = np.zeros((len(smearing_grid), len(spacing_grid))) +timings = np.zeros((len(smearing_grid), len(spacing_grid))) +bounds = np.zeros((len(smearing_grid), len(spacing_grid))) +for ism, smearing in enumerate(smearing_grid): + for isp, spacing in enumerate(spacing_grid): + results[ism, isp], timings[ism, isp] = timed_madelung(8.0, smearing, spacing, 4) + bounds[ism, isp] = error_bounds(8.0, smearing, spacing, 4) + +# %% +# We now plot the error landscape. The estimated error can be seen as a upper bound of +# the actual error. Though the magnitude of the estimated error is higher than the +# actual error, the trend is the same. Also, from the timing results, we can see that +# the timing increases as the spacing decreases, while the smearing does not affect the +# timing, because the interactions are computed up to the fixed cutoff regardless of +# whether :math:`v_\mathrm{sr}(r)` is negligible or large. + +vmin = 1e-12 +vmax = 2 +levels = np.geomspace(vmin, vmax, 30) + +fig, ax = plt.subplots(1, 3, figsize=(9, 3), sharey=True, constrained_layout=True) +contour = ax[0].contourf( + spacing_grid, + smearing_grid, + bounds, + vmin=vmin, + vmax=vmax, + levels=levels, + norm=mpl.colors.LogNorm(), + extend="both", +) +ax[0].set_xscale("log") +ax[0].set_yscale("log") +ax[0].set_ylabel(r"$\sigma$ / Å") +ax[0].set_xlabel(r"spacing / Å") +ax[0].set_title("estimated error") +cbar = fig.colorbar(contour, ax=ax[1], label="error") +cbar.ax.set_yscale("log") + +contour = ax[1].contourf( + spacing_grid, + smearing_grid, + np.abs(results - madelung_ref), + vmin=vmin, + vmax=vmax, + levels=levels, + norm=mpl.colors.LogNorm(), + extend="both", +) +ax[1].set_xscale("log") +ax[1].set_yscale("log") +ax[1].set_xlabel(r"spacing / Å") +ax[1].set_title("actual error") + +contour = ax[2].contourf( + spacing_grid, + smearing_grid, + timings, + levels=np.geomspace(1e-2, 5e-1, 20), + norm=mpl.colors.LogNorm(), +) +ax[2].set_xscale("log") +ax[2].set_yscale("log") +ax[2].set_ylabel(r"$\sigma$ / Å") +ax[2].set_xlabel(r"spacing / Å") +ax[2].set_title("actual timing") +cbar = fig.colorbar(contour, ax=ax[2], label="time / s") +cbar.ax.set_yscale("log") + +# %% +# Optimizing the smearing +# ~~~~~~~~~~~~~~~~~~~~~~~ +# The error is a sum of an error on the real-space evaluation of the +# short-range potential, and of a long-range error. Considering the +# cutoff as given, the short-range error is determined easily by how +# quickly :math:`v_\mathrm{sr}(r)` decays to zero, which depends on +# the Gaussian smearing. + +smearing_grid = torch.logspace(-0.6, 1, 20) +err_vsr_grid = error_bounds.err_rspace(smearing_grid, torch.tensor([5.0])) +err_vlr_grid_4 = [ + error_bounds.err_kspace( + torch.tensor([s]), torch.tensor([1.0]), torch.tensor([4], dtype=int) + ) + for s in smearing_grid +] +err_vlr_grid_2 = [ + error_bounds.err_kspace( + torch.tensor([s]), torch.tensor([1.0]), torch.tensor([3], dtype=int) + ) + for s in smearing_grid +] + +fig, ax = plt.subplots(1, 1, figsize=(4, 3), constrained_layout=True) +ax.loglog(smearing_grid, err_vsr_grid, "r-", label="real-space") +ax.loglog(smearing_grid, err_vlr_grid_4, "b-", label="k-space (spacing: 1Å, n.int.: 4)") +ax.loglog(smearing_grid, err_vlr_grid_2, "c-", label="k-space (spacing: 1Å, n.int.: 2)") +ax.set_ylabel(r"estimated error / a.u.") +ax.set_xlabel(r"smearing / Å") +ax.set_title("cutoff = 5.0 Å") +ax.set_ylim(1e-20, 2) +ax.legend() + +# %% +# Given the simple, monotonic and fast-varying trend for the real-space error, +# it is easy to pick the optimal smearing as the value corresponding to roughly +# half of the target error -e.g. for a target accuracy of :math:`1e^{-5}`, +# one would pick a smearing of about 1Å. Given that usually there is a +# cost/accuracy tradeoff, and smaller smearings make the reciprocal-space evaluation +# more costly, the largest smearing is the best choice here. + +# %% +# Optimizing mesh and interpolation order +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Once the smearing value that gives an acceptable accuracy for the real-space +# component has been determined, there may be other parameters that need to be +# optimized. One way to do this is to perform a grid search, and pick, among the +# parameters that yield an error below the threshold, those that empirically lead +# to the fastest evaluation. + +spacing_grid = torch.logspace(-1, 1, 10) +nint_grid = [3, 4, 5, 6] +results = np.zeros((len(nint_grid), len(spacing_grid))) +timings = np.zeros((len(nint_grid), len(spacing_grid))) +bounds = np.zeros((len(nint_grid), len(spacing_grid))) +for inint, nint in enumerate(nint_grid): + for isp, spacing in enumerate(spacing_grid): + results[inint, isp], timings[inint, isp] = timed_madelung( + 5.0, 1.0, spacing, nint + ) + bounds[inint, isp] = error_bounds(5.0, 1.0, spacing, nint) + + +fig, ax = plt.subplots(1, 2, figsize=(8, 3), constrained_layout=True) +colors = ["r", "#AA0066", "#6600AA", "b"] +labels = [ + "smearing 1Å, n.int: 3", + "smearing 1Å, n.int: 4", + "smearing 1Å, n.int: 5", + "smearing 1Å, n.int: 6", +] + +# Plot original lines on ax[0] +for i in range(4): + ax[0].loglog(spacing_grid, bounds[i], "-", color=colors[i], label=labels[i]) + ax[1].loglog(spacing_grid, timings[i], "-", color=colors[i], label=labels[i]) + # Find where condition is met + condition = bounds[i] < 1e-5 + # Overlay thicker markers at the points below threshold + ax[0].loglog( + spacing_grid[condition], + bounds[i][condition], + "-o", + linewidth=3, + markersize=4, + color=colors[i], + ) + ax[1].loglog( + spacing_grid[condition], + timings[i][condition], + "-o", + linewidth=3, + markersize=4, + color=colors[i], + ) + +ax[0].set_ylabel(r"estimated error / a.u.") +ax[0].set_xlabel(r"mesh spacing / Å") +ax[1].set_ylabel(r"timing / s") +ax[1].set_xlabel(r"mesh spacing / Å") +ax[0].set_title("cutoff = 5.0 Å") +ax[0].set_ylim(1e-6, 2) +ax[0].legend() + +# %% +# The overall errors saturate to the value of the real-space error, +# which is why we can pretty much fix the value of the smearing for a +# given cutoff. Higher interpolation orders allow to push the accuracy +# to higher values even with a large mesh spacing, resulting in large +# computational savings. However, depending on the specific setup, +# the overhead associated with the more complex interpolation (that is +# seen in the coarse-mesh limit) could favor intermediate values +# of ``interpolation_order``. + +# %% +# Automatic tuning +# ---------------- +# Even though these detailed examples are useful to understand the +# numerics of PME, and the logic one could follow to pick the best +# values, in practice one may want to automate the procedure. + +smearing, parameters, timing = tune_pme( + accuracy=1e-5, + charges=charges, + cell=cell, + positions=positions, + cutoff=5.0, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, +) + +print(f""" +Estimated PME parameters (cutoff={5.0} Å): +Smearing: {smearing} Å +Mesh spacing: {parameters["mesh_spacing"]} Å +Interpolation order: {parameters["interpolation_nodes"]} +Estimated time per step: {timing} s +""") + +# %% +# What is the best cutoff? +# ------------------------ +# Determining the most efficient cutoff value can be achieved by +# running a simple search over a few "reasonable" values. + +cutoff_grid = torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + +timings_grid = [] +for cutoff in cutoff_grid: + filter_indices, filter_distances = filter_neighbors( + cutoff, neighbor_indices, neighbor_distances + ) + smearing, parameters, timing = tune_pme( + accuracy=1e-5, + charges=charges, + cell=cell, + positions=positions, + cutoff=cutoff, + neighbor_indices=filter_indices, + neighbor_distances=filter_distances, + ) + timings_grid.append(timing) + +# %% +# Even though the trend is smooth, there is substantial variability, +# indicating it may be worth to perform this additional tuning whenever +# the long-range model is the bottleneck of a calculation + +fig, ax = plt.subplots(1, 1, figsize=(4, 3), constrained_layout=True) +ax.plot(cutoff_grid, timings_grid, "r-*") +ax.set_ylabel(r"avg. timings / s") +ax.set_xlabel(r"cutoff / Å") diff --git a/examples/11-4-site-water.py b/examples/11-4-site-water.py new file mode 100644 index 00000000..58e19b3e --- /dev/null +++ b/examples/11-4-site-water.py @@ -0,0 +1,84 @@ +""" +.. _example-tip4p-water: + +4-site water models +=================== + +.. currentmodule:: torchpme + +# Several water models (starting from the venerable TIP4P model of +# `Abascal and C. Vega, JCP (2005) `_) +# use a center of negative charge that is displaced from the O position. +# This is easily implemented, yielding the forces on the O and H positions +# generated by the displaced charge. +""" + +import ase +import torch + +import torchpme + +structure = ase.Atoms( + positions=[ + [0, 0, 0], + [0, 1, 0], + [1, -0.2, 0], + ], + cell=[6, 6, 6], + symbols="OHH", +) + +cell = torch.from_numpy(structure.cell.array) +positions = torch.from_numpy(structure.positions) + +# %% +# The key step is to create a "fourth site" based on the oxygen positions and use it in +# the ``interpolate`` step. + +charges = torch.tensor([[-1.0], [0.5], [0.5]]) + +positions.requires_grad_(True) +charges.requires_grad_(True) +cell.requires_grad_(True) + +positions_4site = torch.vstack( + [ + ((positions[1::3] + positions[2::3]) * 0.5 + positions[0::3] * 3) / 4, + positions[1::3], + positions[2::3], + ] +) + +# %% +# .. important:: +# +# For the automatic differentiation to work it is important to make a new tensor as +# ``positions_4site`` and do not "overwrite" the original tensor. + +ns = torch.tensor([5, 5, 5]) +interpolator = torchpme.lib.MeshInterpolator( + cell=cell, ns_mesh=ns, interpolation_nodes=3, method="Lagrange" +) +interpolator.compute_weights(positions_4site) +mesh = interpolator.points_to_mesh(charges) + +value = (mesh**2).sum() + +# %% +# The gradients can be computed by just running `backward` on the +# end result. Gradients are computed on the H and O positions. + +value.backward() + +print( + f""" +Position gradients: +{positions.grad.T} + +Cell gradients: +{cell.grad} + +Charges gradients: +{charges.grad.T} +""" +) diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index 627283e5..bb9e22d1 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -40,9 +40,10 @@ def __init__( ): super().__init__() - assert isinstance(potential, Potential), ( - f"Potential must be an instance of Potential, got {type(potential)}" - ) + if not isinstance(potential, Potential): + raise TypeError( + f"Potential must be an instance of Potential, got {type(potential)}" + ) self.device = "cpu" if device is None else device self.dtype = torch.get_default_dtype() if dtype is None else dtype diff --git a/src/torchpme/tuning/_utils.py b/src/torchpme/tuning/_utils.py index 33d7108c..c08c1a5b 100644 --- a/src/torchpme/tuning/_utils.py +++ b/src/torchpme/tuning/_utils.py @@ -1,102 +1,12 @@ -import math -import warnings -from typing import Callable, Optional - import torch -def _optimize_parameters( - params: list[torch.Tensor], - loss: Callable, - max_steps: int, - accuracy: float, - learning_rate: float, -) -> None: - optimizer = torch.optim.Adam(params, lr=learning_rate) - - for _ in range(max_steps): - loss_value = loss(*params) - if torch.isnan(loss_value) or torch.isinf(loss_value): - raise ValueError( - "The value of the estimated error is now nan, consider using a " - "smaller learning rate." - ) - loss_value.backward() - optimizer.step() - optimizer.zero_grad() - - if loss_value <= accuracy: - break - - if loss_value > accuracy: - warnings.warn( - "The searching for the parameters is ended, but the error is " - f"{float(loss_value):.3e}, larger than the given accuracy {accuracy}. " - "Consider increase max_step and", - stacklevel=2, - ) - - -def _estimate_smearing_cutoff( - cell: torch.Tensor, - smearing: Optional[float], - cutoff: Optional[float], - accuracy: float, -) -> tuple[torch.tensor, torch.tensor]: - dtype = cell.dtype - device = cell.device - - cell_dimensions = torch.linalg.norm(cell, dim=1) - min_dimension = float(torch.min(cell_dimensions)) - half_cell = min_dimension / 2.0 - - smearing_init = torch.tensor( - half_cell / 5 if smearing is None else smearing, - dtype=dtype, - device=device, - requires_grad=(smearing is None), - ) - - if cutoff is None: - # solve V_SR(cutoff) == accuracy for cutoff - def loss(cutoff): - return ( - torch.erfc(cutoff / math.sqrt(2) / smearing_init) / cutoff - accuracy - ) ** 2 - - cutoff_init = torch.tensor( - half_cell, dtype=dtype, device=device, requires_grad=True - ) - _optimize_parameters( - params=[cutoff_init], - loss=loss, - accuracy=accuracy, - max_steps=1000, - learning_rate=0.1, - ) - - cutoff_init = torch.tensor( - float(cutoff_init) if cutoff is None else cutoff, - dtype=dtype, - device=device, - requires_grad=(cutoff is None), - ) - - return smearing_init, cutoff_init - - def _validate_parameters( - sum_squared_charges: float, + charges: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor, exponent: int, - accuracy: float, ) -> None: - if sum_squared_charges <= 0: - raise ValueError( - f"sum of squared charges must be positive, got {sum_squared_charges}" - ) - if exponent != 1: raise NotImplementedError("Only exponent = 1 is supported") @@ -135,5 +45,31 @@ def _validate_parameters( "periodic calculation" ) - if not isinstance(accuracy, float): - raise ValueError(f"'{accuracy}' is not a float.") + if charges.dtype != dtype: + raise ValueError( + f"each `charges` must have the same type {dtype} as `positions`, got at least " + "one tensor of type " + f"{charges.dtype}" + ) + + if charges.device != device: + raise ValueError( + f"each `charges` must be on the same device {device} as `positions`, got at " + "least one tensor with device " + f"{charges.device}" + ) + + if charges.dim() != 2: + raise ValueError( + "`charges` must be a 2-dimensional tensor, got " + f"tensor with {charges.dim()} dimension(s) and shape " + f"{list(charges.shape)}" + ) + + if list(charges.shape) != [len(positions), charges.shape[1]]: + raise ValueError( + "`charges` must be a tensor with shape [n_atoms, n_channels], with " + "`n_atoms` being the same as the variable `positions`. Got tensor with " + f"shape {list(charges.shape)} where positions contains " + f"{len(positions)} atoms" + ) diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py index 7c653403..1e40824b 100644 --- a/src/torchpme/tuning/ewald.py +++ b/src/torchpme/tuning/ewald.py @@ -1,167 +1,207 @@ import math -from typing import Optional +from typing import Any, Optional +from warnings import warn import torch -from ._utils import ( - _estimate_smearing_cutoff, - _optimize_parameters, - _validate_parameters, -) - -TWO_PI = 2 * math.pi +from ..calculators import EwaldCalculator +from ._utils import _validate_parameters +from .tuner import GridSearchTuner, TuningErrorBounds def tune_ewald( - sum_squared_charges: float, + charges: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor, - smearing: Optional[float] = None, - lr_wavelength: Optional[float] = None, - cutoff: Optional[float] = None, + cutoff: float, + neighbor_indices: torch.Tensor, + neighbor_distances: torch.Tensor, exponent: int = 1, + ns_lo: int = 1, + ns_hi: int = 14, accuracy: float = 1e-3, - max_steps: int = 50000, - learning_rate: float = 0.1, -) -> tuple[float, dict[str, float], float]: + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, +) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.EwaldCalculator`. - The error formulas are given `online - `_ - (now not available, need to be updated later). Note the difference notation between - the parameters in the reference and ours: - - .. math:: - - \alpha &= \left( \sqrt{2}\,\mathrm{smearing} \right)^{-1} - - K &= \frac{2 \pi}{\mathrm{lr\_wavelength}} - - r_c &= \mathrm{cutoff} - - For the optimization we use the :class:`torch.optim.Adam` optimizer. By default this - function optimize the ``smearing``, ``lr_wavelength`` and ``cutoff`` based on the - error formula given `online`_. You can limit the optimization by giving one or more - parameters to the function. For example in usual ML workflows the cutoff is fixed - and one wants to optimize only the ``smearing`` and the ``lr_wavelength`` with - respect to the minimal error and fixed cutoff. - - :param sum_squared_charges: accumulated squared charges, must be positive - :param cell: single tensor of shape (3, 3), describing the bounding - :param positions: single tensor of shape (``len(charges), 3``) containing the - Cartesian positions of all point charges in the system. - :param smearing: if its value is given, it will not be tuned, see - :class:`torchpme.EwaldCalculator` for details - :param lr_wavelength: if its value is given, it will not be tuned, see - :class:`torchpme.EwaldCalculator` for details - :param cutoff: if its value is given, it will not be tuned, see - :class:`torchpme.EwaldCalculator` for details - :param exponent: exponent :math:`p` in :math:`1/r^p` potentials + .. note:: + + The :func:`torchpme.tuning.ewald.EwaldErrorBounds.forward` method takes floats + as the input, in order to be in consistency with the rest of the package -- + these parameters are always ``float`` but not ``torch.Tensor``. This design, + however, prevents the utilization of ``torch.autograd`` and other ``torch`` + features. To take advantage of these features, one can use the + :func:`torchpme.tuning.ewald.EwaldErrorBounds.err_rspace` and + :func:`torchpme.tuning.ewald.EwaldErrorBounds.err_kspace`, which takes + ``torch.Tensor`` as parameters. + + :param charges: torch.Tensor, atomic (pseudo-)charges + :param cell: torch.Tensor, periodic supercell for the system + :param positions: torch.Tensor, Cartesian coordinates of the particles within the + supercell. + :param cutoff: float, cutoff distance for the neighborlist + :param exponent: :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` + is supported + :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors for + which the potential should be computed in real space. + :param ns_lo: Minimum number of spatial resolution along each axis + :param ns_hi: Maximum number of spatial resolution along each axis :param accuracy: Recomended values for a balance between the accuracy and speed is :math:`10^{-3}`. For more accurate results, use :math:`10^{-6}`. - :param max_steps: maximum number of gradient descent steps - :param learning_rate: learning rate for gradient descent :return: Tuple containing a float of the optimal smearing for the :class: - `CoulombPotential`, a dictionary with the parameters for - :class:`EwaldCalculator` and a float of the optimal cutoff value for the - neighborlist computation. + `CoulombPotential`, and a dictionary with the parameters for + :class:`EwaldCalculator`, and the timing of this set of parameters. Example ------- >>> import torch >>> positions = torch.tensor( - ... [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], dtype=torch.float64 + ... [[0.0, 0.0, 0.0], [0.4, 0.4, 0.4]], dtype=torch.float64 ... ) >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) >>> cell = torch.eye(3, dtype=torch.float64) - >>> smearing, parameter, cutoff = tune_ewald( - ... torch.sum(charges**2, dim=0), cell, positions, accuracy=1e-1 + >>> neighbor_distances = torch.tensor( + ... [0.9381, 0.9381, 0.8246, 0.9381, 0.8246, 0.8246, 0.6928], + ... dtype=torch.float64, ... ) - - You can check the values of the parameters - - >>> print(smearing) - 0.7527865828476816 - - >>> print(parameter) - {'lr_wavelength': 11.138556788117427} - - >>> print(cutoff) - 2.207855328192979 - - You can give one parameter to the function to tune only other parameters, for - example, fixing the cutoff to 0.1 - - >>> smearing, parameter, cutoff = tune_ewald( - ... torch.sum(charges**2, dim=0), cell, positions, cutoff=0.4, accuracy=1e-1 + >>> neighbor_indices = torch.tensor( + ... [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]] + ... ) + >>> smearing, parameter, timing = tune_ewald( + ... charges, + ... cell, + ... positions, + ... cutoff=1.0, + ... neighbor_distances=neighbor_distances, + ... neighbor_indices=neighbor_indices, + ... accuracy=1e-1, ... ) - You can check the values of the parameters, now the cutoff is fixed + """ + _validate_parameters(charges, cell, positions, exponent) + min_dimension = float(torch.min(torch.linalg.norm(cell, dim=1))) + params = [{"lr_wavelength": min_dimension / ns} for ns in range(ns_lo, ns_hi + 1)] + + tuner = GridSearchTuner( + charges=charges, + cell=cell, + positions=positions, + cutoff=cutoff, + exponent=exponent, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, + calculator=EwaldCalculator, + error_bounds=EwaldErrorBounds(charges=charges, cell=cell, positions=positions), + params=params, + dtype=dtype, + device=device, + ) + smearing = tuner.estimate_smearing(accuracy) + errs, timings = tuner.tune(accuracy) + + # There are multiple errors below the accuracy, return the one with the shortest + # calculation time. The timing of those parameters leading to an higher error than + # the accuracy are set to infinity + if any(err < accuracy for err in errs): + return smearing, params[timings.index(min(timings))], min(timings) + # No parameter meets the requirement, return the one with the smallest error + warn( + f"No parameter meets the accuracy requirement.\n" + f"Returning the parameter with the smallest error, which is {min(errs)}.\n", + stacklevel=1, + ) + return smearing, params[errs.index(min(errs))], timings[errs.index(min(errs))] - >>> print(round(smearing, 4)) - 0.1402 - We can also check the value of the other parameter like the ``lr_wavelength`` +class EwaldErrorBounds(TuningErrorBounds): + r""" + Error bounds for :class:`torchpme.calculators.ewald.EwaldCalculator`. - >>> print(round(parameter["lr_wavelength"], 3)) - 0.255 + .. math:: + \text{Error}_{\text{total}} = \sqrt{\text{Error}_{\text{real space}}^2 + + \text{Error}_{\text{Fourier space}}^2 - and finally as requested the value of the cutoff is fixed + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. - >>> print(cutoff) - 0.4 + Example + ------- + >>> import torch + >>> positions = torch.tensor( + ... [[0.0, 0.0, 0.0], [0.4, 0.4, 0.4]], dtype=torch.float64 + ... ) + >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) + >>> cell = torch.eye(3, dtype=torch.float64) + >>> error_bounds = EwaldErrorBounds(charges, cell, positions) + >>> print(error_bounds(smearing=1.0, lr_wavelength=0.5, cutoff=4.4)) + tensor(8.4304e-05, dtype=torch.float64) """ - _validate_parameters(sum_squared_charges, cell, positions, exponent, accuracy) - smearing_opt, cutoff_opt = _estimate_smearing_cutoff( - cell=cell, smearing=smearing, cutoff=cutoff, accuracy=accuracy - ) - - # We choose a very small initial fourier wavelength, hardcoded for now - k_cutoff_opt = torch.tensor( - 1e-3 if lr_wavelength is None else TWO_PI / lr_wavelength, - dtype=cell.dtype, - device=cell.device, - requires_grad=(lr_wavelength is None), - ) - - volume = torch.abs(cell.det()) - prefac = 2 * sum_squared_charges / math.sqrt(len(positions)) - - def err_Fourier(smearing, k_cutoff): + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + ): + super().__init__(charges, cell, positions) + + self.volume = torch.abs(torch.det(cell)) + self.sum_squared_charges = (charges**2).sum() + self.prefac = 2 * self.sum_squared_charges / math.sqrt(len(positions)) + self.cell = cell + self.positions = positions + + def err_kspace( + self, smearing: torch.Tensor, lr_wavelength: torch.Tensor + ) -> torch.Tensor: + """ + The Fourier space error of Ewald. + + :param smearing: see :class:`torchpme.EwaldCalculator` for details + :param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details + """ return ( - prefac**0.5 + self.prefac**0.5 / smearing - / torch.sqrt(TWO_PI**2 * volume / (TWO_PI / k_cutoff) ** 0.5) - * torch.exp(-(TWO_PI**2) * smearing**2 / (TWO_PI / k_cutoff)) + / torch.sqrt((2 * torch.pi) ** 2 * self.volume / (lr_wavelength) ** 0.5) + * torch.exp(-((2 * torch.pi) ** 2) * smearing**2 / (lr_wavelength)) ) - def err_real(smearing, cutoff): + def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor: + """ + The real space error of Ewald. + + :param smearing: see :class:`torchpme.EwaldCalculator` for details + :param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details + """ return ( - prefac - / torch.sqrt(cutoff * volume) + self.prefac + / torch.sqrt(cutoff * self.volume) * torch.exp(-(cutoff**2) / 2 / smearing**2) ) - def loss(smearing, k_cutoff, cutoff): + def forward( + self, smearing: float, lr_wavelength: float, cutoff: float + ) -> torch.Tensor: + r""" + Calculate the error bound of Ewald. + + :param smearing: see :class:`torchpme.EwaldCalculator` for details + :param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details + :param cutoff: see :class:`torchpme.EwaldCalculator` for details + """ + smearing = torch.tensor(smearing) + lr_wavelength = torch.tensor(lr_wavelength) + cutoff = torch.tensor(cutoff) return torch.sqrt( - err_Fourier(smearing, k_cutoff) ** 2 + err_real(smearing, cutoff) ** 2 + self.err_kspace(smearing, lr_wavelength) ** 2 + + self.err_rspace(smearing, cutoff) ** 2 ) - - params = [smearing_opt, k_cutoff_opt, cutoff_opt] - _optimize_parameters( - params=params, - loss=loss, - max_steps=max_steps, - accuracy=accuracy, - learning_rate=learning_rate, - ) - - return ( - float(smearing_opt), - {"lr_wavelength": TWO_PI / float(k_cutoff_opt)}, - float(cutoff_opt), - ) diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py index 538c7dd7..14dd0b88 100644 --- a/src/torchpme/tuning/p3m.py +++ b/src/torchpme/tuning/p3m.py @@ -1,14 +1,13 @@ import math -from typing import Optional +from itertools import product +from typing import Any, Optional +from warnings import warn import torch -from ..lib import get_ns_mesh -from ._utils import ( - _estimate_smearing_cutoff, - _optimize_parameters, - _validate_parameters, -) +from ..calculators import P3MCalculator +from ._utils import _validate_parameters +from .tuner import GridSearchTuner, TuningErrorBounds # Coefficients for the P3M Fourier error, # see Table II of http://dx.doi.org/10.1063/1.477415 @@ -69,54 +68,55 @@ def tune_p3m( - sum_squared_charges: float, + charges: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor, - smearing: Optional[float] = None, - mesh_spacing: Optional[float] = None, - cutoff: Optional[float] = None, - interpolation_nodes: int = 4, + cutoff: float, + neighbor_indices: torch.Tensor, + neighbor_distances: torch.Tensor, exponent: int = 1, + nodes_lo: int = 2, + nodes_hi: int = 5, + mesh_lo: int = 2, + mesh_hi: int = 7, accuracy: float = 1e-3, - max_steps: int = 50000, - learning_rate: float = 5e-3, -) -> tuple[float, dict[str, float], float]: + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, +) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.calculators.pme.PMECalculator`. - For the error formulas are given `here `_. - Note the difference notation between the parameters in the reference and ours: + For the error formulas are given `here `_. Note + the difference notation between the parameters in the reference and ours: .. math:: \alpha = \left(\sqrt{2}\,\mathrm{smearing} \right)^{-1} - .. hint:: - - Tuning uses an initial guess for the optimization, which can be applied by - setting ``max_steps = 0``. This can be useful if fast tuning is required. These - values typically result in accuracies around :math:`10^{-2}`. - - :param sum_squared_charges: accumulated squared charges, must be positive - :param cell: single tensor of shape (3, 3), describing the bounding - :param positions: single tensor of shape (``len(charges), 3``) containing the - Cartesian positions of all point charges in the system. - :param interpolation_nodes: The number ``n`` of nodes used in the interpolation per - coordinate axis. The total number of interpolation nodes in 3D will be ``n^3``. - In general, for ``n`` nodes, the interpolation will be performed by piecewise - polynomials of degree ``n`` (e.g. ``n = 3`` for cubic interpolation). Only - the values ``1, 2, 3, 4, 5`` are supported. - :param exponent: exponent :math:`p` in :math:`1/r^p` potentials + :param charges: torch.Tensor, atomic (pseudo-)charges + :param cell: torch.Tensor, periodic supercell for the system + :param positions: torch.Tensor, Cartesian coordinates of the particles within the + supercell. + :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors for + which the potential should be computed in real space. + :param cutoff: float, cutoff distance for the neighborlist supported + :param exponent: :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` + is + :param nodes_lo: Minimum number of interpolation nodes + :param nodes_hi: Maximum number of interpolation nodes + :param mesh_lo: Controls the minimum number of mesh points along the shortest axis, + :math:`2^{mesh_lo}` + :param mesh_hi: Controls the maximum number of mesh points along the shortest axis, + :math:`2^{mesh_hi}` :param accuracy: Recomended values for a balance between the accuracy and speed is :math:`10^{-3}`. For more accurate results, use :math:`10^{-6}`. - :param max_steps: maximum number of gradient descent steps - :param learning_rate: learning rate for gradient descent - :param verbose: whether to print the progress of gradient descent :return: Tuple containing a float of the optimal smearing for the :py:class: `CoulombPotential`, a dictionary with the parameters for - :py:class:`PMECalculator` and a float of the optimal cutoff value for the - neighborlist computation. + :py:class:`P3MCalculator` and a float of the optimal cutoff value for the + neighborlist computation, and the timing of this set of parameters. Example ------- @@ -126,62 +126,147 @@ def tune_p3m( >>> _ = torch.manual_seed(0) >>> positions = torch.tensor( - ... [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], dtype=torch.float64 + ... [[0.0, 0.0, 0.0], [0.4, 0.4, 0.4]], dtype=torch.float64 ... ) >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) >>> cell = torch.eye(3, dtype=torch.float64) - >>> smearing, parameter, cutoff = tune_p3m( - ... torch.sum(charges**2, dim=0), cell, positions, accuracy=1e-1 + >>> neighbor_distances = torch.tensor( + ... [0.9381, 0.9381, 0.8246, 0.9381, 0.8246, 0.8246, 0.6928], + ... dtype=torch.float64, + ... ) + >>> neighbor_indices = torch.tensor( + ... [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]] + ... ) + >>> smearing, parameter, timing = tune_p3m( + ... charges, + ... cell, + ... positions, + ... cutoff=1.0, + ... neighbor_distances=neighbor_distances, + ... neighbor_indices=neighbor_indices, + ... accuracy=1e-1, ... ) - You can check the values of the parameters + """ + _validate_parameters(charges, cell, positions, exponent) + min_dimension = float(torch.min(torch.linalg.norm(cell, dim=1))) + params = [ + { + "interpolation_nodes": interpolation_nodes, + "mesh_spacing": 2 * min_dimension / (2**ns - 1), + } + for interpolation_nodes, ns in product( + range(nodes_lo, nodes_hi + 1), range(mesh_lo, mesh_hi + 1) + ) + ] + + tuner = GridSearchTuner( + charges=charges, + cell=cell, + positions=positions, + cutoff=cutoff, + exponent=exponent, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, + calculator=P3MCalculator, + error_bounds=P3MErrorBounds(charges=charges, cell=cell, positions=positions), + params=params, + dtype=dtype, + device=device, + ) + smearing = tuner.estimate_smearing(accuracy) + errs, timings = tuner.tune(accuracy) + + # There are multiple errors below the accuracy, return the one with the shortest + # calculation time. The timing of those parameters leading to an higher error + # than the accuracy are set to infinity + if any(err < accuracy for err in errs): + return smearing, params[timings.index(min(timings))], min(timings) + # No parameter meets the requirement, return the one with the smallest error, and + # throw a warning + warn( + f"No parameter meets the accuracy requirement.\n" + f"Returning the parameter with the smallest error, which is {min(errs)}.\n", + stacklevel=1, + ) + return smearing, params[errs.index(min(errs))], timings[errs.index(min(errs))] - >>> print(smearing) - 0.5084014996119913 - >>> print(parameter) - {'mesh_spacing': 0.546694745583215, 'interpolation_nodes': 4} +class P3MErrorBounds(TuningErrorBounds): + r""" + " Error bounds for :class:`torchpme.calculators.pme.P3MCalculator`. + + .. note:: + + The :func:`torchpme.tuning.p3m.P3MErrorBounds.forward` method takes floats as + the input, in order to be in consistency with the rest of the package -- these + parameters are always ``float`` but not ``torch.Tensor``. This design, however, + prevents the utilization of ``torch.autograd`` and other ``torch`` features. To + take advantage of these features, one can use the + :func:`torchpme.tuning.p3m.P3MErrorBounds.err_rspace` and + :func:`torchpme.tuning.p3m.P3MErrorBounds.err_kspace`, which takes + ``torch.Tensor`` as parameters. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. - >>> print(cutoff) - 2.6863848597963442 + Example + ------- + >>> import torch + >>> positions = torch.tensor( + ... [[0.0, 0.0, 0.0], [0.4, 0.4, 0.4]], dtype=torch.float64 + ... ) + >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) + >>> cell = torch.eye(3, dtype=torch.float64) + >>> error_bounds = P3MErrorBounds(charges, cell, positions) + >>> print( + ... error_bounds( + ... smearing=1.0, mesh_spacing=0.5, cutoff=4.4, interpolation_nodes=3 + ... ) + ... ) + tensor(0.0005, dtype=torch.float64) """ - _validate_parameters(sum_squared_charges, cell, positions, exponent, accuracy) - smearing_opt, cutoff_opt = _estimate_smearing_cutoff( - cell=cell, - smearing=smearing, - cutoff=cutoff, - accuracy=accuracy, - ) - # We choose only one mesh as initial guess - if mesh_spacing is None: - ns_mesh_opt = torch.tensor( - [1, 1, 1], - device=cell.device, - dtype=cell.dtype, - requires_grad=True, - ) - else: - ns_mesh_opt = get_ns_mesh(cell, mesh_spacing) + def __init__( + self, charges: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor + ): + super().__init__(charges, cell, positions) - cell_dimensions = torch.linalg.norm(cell, dim=1) - volume = torch.abs(cell.det()) - prefac = 2 * sum_squared_charges / math.sqrt(len(positions)) + self.volume = torch.abs(torch.det(cell)) + self.sum_squared_charges = (charges**2).sum() + self.prefac = 2 * self.sum_squared_charges / math.sqrt(len(positions)) + self.cell_dimensions = torch.linalg.norm(cell, dim=1) + self.cell = cell + self.positions = positions - interpolation_nodes = torch.tensor(interpolation_nodes, device=cell.device) + def err_kspace( + self, + smearing: torch.Tensor, + mesh_spacing: torch.Tensor, + interpolation_nodes: torch.Tensor, + ) -> torch.Tensor: + """ + The Fourier space error of P3M. - def err_Fourier(smearing, ns_mesh): - spacing = cell_dimensions / ns_mesh - h = torch.prod(spacing) ** (1 / 3) + :param smearing: see :class:`torchpme.P3MCalculator` for details + :param mesh_spacing: see :class:`torchpme.P3MCalculator` for details + :param interpolation_nodes: see :class:`torchpme.P3MCalculator` for details + """ + actual_spacing = self.cell_dimensions / ( + 2 * self.cell_dimensions / mesh_spacing + 1 + ) + h = torch.prod(actual_spacing) ** (1 / 3) return ( - prefac - / volume ** (2 / 3) + self.prefac + / self.volume ** (2 / 3) * (h * (1 / 2**0.5 / smearing)) ** interpolation_nodes * torch.sqrt( (1 / 2**0.5 / smearing) - * volume ** (1 / 3) + * self.volume ** (1 / 3) * math.sqrt(2 * torch.pi) * sum( A_COEF[m][interpolation_nodes] @@ -191,32 +276,43 @@ def err_Fourier(smearing, ns_mesh): ) ) - def err_real(smearing, cutoff): + def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor: + """ + The real space error of P3M. + + :param smearing: see :class:`torchpme.P3MCalculator` for details + :param cutoff: see :class:`torchpme.P3MCalculator` for details + """ return ( - prefac - / torch.sqrt(cutoff * volume) + self.prefac + / torch.sqrt(cutoff * self.volume) * torch.exp(-(cutoff**2) / 2 / smearing**2) ) - def loss(smearing, ns_mesh, cutoff): - return torch.sqrt( - err_Fourier(smearing, ns_mesh) ** 2 + err_real(smearing, cutoff) ** 2 - ) + def forward( + self, + smearing: float, + mesh_spacing: float, + cutoff: float, + interpolation_nodes: int, + ) -> torch.Tensor: + r""" + Calculate the error bound of P3M. - params = [smearing_opt, ns_mesh_opt, cutoff_opt] - _optimize_parameters( - params=params, - loss=loss, - max_steps=max_steps, - accuracy=accuracy, - learning_rate=learning_rate, - ) + .. math:: + \text{Error}_{\text{total}} = \sqrt{\text{Error}_{\text{real space}}^2 + + \text{Error}_{\text{Fourier space}}^2 - return ( - float(smearing_opt), - { - "mesh_spacing": float(torch.min(cell_dimensions / ns_mesh_opt)), - "interpolation_nodes": int(interpolation_nodes), - }, - float(cutoff_opt), - ) + :param smearing: see :class:`torchpme.P3MCalculator` for details + :param mesh_spacing: see :class:`torchpme.P3MCalculator` for details + :param cutoff: see :class:`torchpme.P3MCalculator` for details + :param interpolation_nodes: see :class:`torchpme.P3MCalculator` for details + """ + smearing = torch.tensor(smearing) + mesh_spacing = torch.tensor(mesh_spacing) + cutoff = torch.tensor(cutoff) + interpolation_nodes = torch.tensor(interpolation_nodes) + return torch.sqrt( + self.err_kspace(smearing, mesh_spacing, interpolation_nodes) ** 2 + + self.err_rspace(smearing, cutoff) ** 2 + ) diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py index c072194d..e10cdda6 100644 --- a/src/torchpme/tuning/pme.py +++ b/src/torchpme/tuning/pme.py @@ -1,29 +1,31 @@ import math -from typing import Optional +from itertools import product +from typing import Any, Optional +from warnings import warn import torch -from ..lib import get_ns_mesh -from ._utils import ( - _estimate_smearing_cutoff, - _optimize_parameters, - _validate_parameters, -) +from ..calculators import PMECalculator +from ._utils import _validate_parameters +from .tuner import GridSearchTuner, TuningErrorBounds def tune_pme( - sum_squared_charges: float, + charges: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor, - smearing: Optional[float] = None, - mesh_spacing: Optional[float] = None, - cutoff: Optional[float] = None, - interpolation_nodes: int = 4, + cutoff: float, + neighbor_indices: torch.Tensor, + neighbor_distances: torch.Tensor, exponent: int = 1, + nodes_lo: int = 3, + nodes_hi: int = 7, + mesh_lo: int = 2, + mesh_hi: int = 7, accuracy: float = 1e-3, - max_steps: int = 50000, - learning_rate: float = 0.1, -): + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, +) -> tuple[float, dict[str, Any], float]: r""" Find the optimal parameters for :class:`torchpme.PMECalculator`. @@ -34,38 +36,30 @@ def tune_pme( \alpha = \left(\sqrt{2}\,\mathrm{smearing} \right)^{-1} - For the optimization we use the :class:`torch.optim.Adam` optimizer. By default this - function optimize the ``smearing``, ``mesh_spacing`` and ``cutoff`` based on the - error formula given `elsewhere`_. You can limit the optimization by giving one or - more parameters to the function. For example in usual ML workflows the cutoff is - fixed and one wants to optimize only the ``smearing`` and the ``mesh_spacing`` with - respect to the minimal error and fixed cutoff. - - :param sum_squared_charges: accumulated squared charges, must be positive - :param cell: single tensor of shape (3, 3), describing the bounding - :param positions: single tensor of shape (``len(charges), 3``) containing the - Cartesian positions of all point charges in the system. - :param smearing: if its value is given, it will not be tuned, see - :class:`torchpme.PMECalculator` for details - :param mesh_spacing: if its value is given, it will not be tuned, see - :class:`torchpme.PMECalculator` for details - :param cutoff: if its value is given, it will not be tuned, see - :class:`torchpme.PMECalculator` for details - :param interpolation_nodes: The number ``n`` of nodes used in the interpolation per - coordinate axis. The total number of interpolation nodes in 3D will be ``n^3``. - In general, for ``n`` nodes, the interpolation will be performed by piecewise - polynomials of degree ``n - 1`` (e.g. ``n = 4`` for cubic interpolation). Only - the values ``3, 4, 5, 6, 7`` are supported. - :param exponent: exponent :math:`p` in :math:`1/r^p` potentials + :param charges: torch.Tensor, atomic (pseudo-)charges + :param cell: torch.Tensor, periodic supercell for the system + :param positions: torch.Tensor, Cartesian coordinates of the particles within the + supercell. + :param cutoff: float, cutoff distance for the neighborlist + :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors for + which the potential should be computed in real space. + :param exponent: :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` + is supported + :param nodes_lo: Minimum number of interpolation nodes + :param nodes_hi: Maximum number of interpolation nodes + :param mesh_lo: Controls the minimum number of mesh points along the shortest axis, + :math:`2^{mesh_lo}` + :param mesh_hi: Controls the maximum number of mesh points along the shortest axis, + :math:`2^{mesh_hi}` :param accuracy: Recomended values for a balance between the accuracy and speed is :math:`10^{-3}`. For more accurate results, use :math:`10^{-6}`. - :param max_steps: maximum number of gradient descent steps - :param learning_rate: learning rate for gradient descent :return: Tuple containing a float of the optimal smearing for the :class: - `CoulombPotential`, a dictionary with the parameters for - :class:`PMECalculator` and a float of the optimal cutoff value for the - neighborlist computation. + `CoulombPotential`, a dictionary with the parameters for :class:`PMECalculator` + and a float of the optimal cutoff value for the neighborlist computation, and + the timing of this set of parameters. Example ------- @@ -79,195 +73,193 @@ def tune_pme( ... ) >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) >>> cell = torch.eye(3, dtype=torch.float64) - >>> smearing, parameter, cutoff = tune_pme( - ... torch.sum(charges**2, dim=0), cell, positions, accuracy=1e-1 + >>> neighbor_distances = torch.tensor( + ... [0.9381, 0.9381, 0.8246, 0.9381, 0.8246, 0.8246, 0.6928], + ... dtype=torch.float64, ... ) - - You can check the values of the parameters - - >>> print(smearing) - 0.6768985898318037 - - >>> print(parameter) - {'mesh_spacing': 0.6305733973385922, 'interpolation_nodes': 4} - - >>> print(cutoff) - 2.243154348782357 - - You can give one parameter to the function to tune only other parameters, for - example, fixing the cutoff to 0.1 - - >>> smearing, parameter, cutoff = tune_pme( - ... torch.sum(charges**2, dim=0), cell, positions, cutoff=0.6, accuracy=1e-1 + >>> neighbor_indices = torch.tensor( + ... [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]] + ... ) + >>> smearing, parameter, timing = tune_pme( + ... charges, + ... cell, + ... positions, + ... cutoff=1.0, + ... neighbor_distances=neighbor_distances, + ... neighbor_indices=neighbor_indices, + ... accuracy=1e-1, ... ) - - You can check the values of the parameters, now the cutoff is fixed - - >>> print(smearing) - 0.22038829671671745 - - >>> print(parameter) - {'mesh_spacing': 0.5006356677116188, 'interpolation_nodes': 4} - - >>> print(cutoff) - 0.6 """ - _validate_parameters(sum_squared_charges, cell, positions, exponent, accuracy) + _validate_parameters(charges, cell, positions, exponent) + min_dimension = float(torch.min(torch.linalg.norm(cell, dim=1))) + params = [ + { + "interpolation_nodes": interpolation_nodes, + "mesh_spacing": 2 * min_dimension / (2**ns - 1), + } + for interpolation_nodes, ns in product( + range(nodes_lo, nodes_hi + 1), range(mesh_lo, mesh_hi + 1) + ) + ] - smearing_opt, cutoff_opt = _estimate_smearing_cutoff( + tuner = GridSearchTuner( + charges=charges, cell=cell, - smearing=smearing, + positions=positions, cutoff=cutoff, - accuracy=accuracy, + exponent=exponent, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, + calculator=PMECalculator, + error_bounds=PMEErrorBounds(charges=charges, cell=cell, positions=positions), + params=params, + dtype=dtype, + device=device, + ) + smearing = tuner.estimate_smearing(accuracy) + errs, timings = tuner.tune(accuracy) + + # There are multiple errors below the accuracy, return the one with the shortest + # calculation time. The timing of those parameters leading to an higher error + # than the accuracy are set to infinity + if any(err < accuracy for err in errs): + return smearing, params[timings.index(min(timings))], min(timings) + # No parameter meets the requirement, return the one with the smallest error, and + # throw a warning + warn( + f"No parameter meets the accuracy requirement.\n" + f"Returning the parameter with the smallest error, which is {min(errs)}.\n", + stacklevel=1, ) + return smearing, params[errs.index(min(errs))], timings[errs.index(min(errs))] - # We choose only one mesh as initial guess - if mesh_spacing is None: - ns_mesh_opt = torch.tensor( - [1, 1, 1], - device=cell.device, - dtype=cell.dtype, - requires_grad=True, - ) - else: - ns_mesh_opt = get_ns_mesh(cell, mesh_spacing) - cell_dimensions = torch.linalg.norm(cell, dim=1) - volume = torch.abs(cell.det()) - prefac = 2 * sum_squared_charges / math.sqrt(len(positions)) +class PMEErrorBounds(TuningErrorBounds): + r""" + Error bounds for :class:`torchpme.PMECalculator`. - interpolation_nodes = torch.tensor(interpolation_nodes, device=cell.device) + .. note:: - def err_Fourier(smearing, ns_mesh): - def H(ns_mesh): - return torch.prod(1 / ns_mesh) ** (1 / 3) + The :func:`torchpme.tuning.pme.PMEErrorBounds.forward` method takes floats as + the input, in order to be in consistency with the rest of the package -- these + parameters are always ``float`` but not ``torch.Tensor``. This design, however, + prevents the utilization of ``torch.autograd`` and other ``torch`` features. To + take advantage of these features, one can use the + :func:`torchpme.tuning.pme.PMEErrorBounds.err_rspace` and + :func:`torchpme.tuning.pme.PMEErrorBounds.err_kspace`, which takes + ``torch.Tensor`` as parameters. - def RMS_phi(ns_mesh): - return torch.linalg.norm( - _compute_RMS_phi(cell, interpolation_nodes, ns_mesh, positions) - ) + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. - def log_factorial(x): - return torch.lgamma(x + 1) + Example + ------- + >>> import torch + >>> positions = torch.tensor( + ... [[0.0, 0.0, 0.0], [0.4, 0.4, 0.4]], dtype=torch.float64 + ... ) + >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) + >>> cell = torch.eye(3, dtype=torch.float64) + >>> error_bounds = PMEErrorBounds(charges, cell, positions) + >>> print( + ... error_bounds( + ... smearing=1.0, mesh_spacing=0.5, cutoff=4.4, interpolation_nodes=3 + ... ) + ... ) + tensor(0.0011, dtype=torch.float64) + + """ - def factorial(x): - return torch.exp(log_factorial(x)) + def __init__( + self, charges: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor + ): + super().__init__(charges, cell, positions) + + self.volume = torch.abs(torch.det(cell)) + self.sum_squared_charges = (charges**2).sum() + self.prefac = 2 * self.sum_squared_charges / math.sqrt(len(positions)) + self.cell_dimensions = torch.linalg.norm(cell, dim=1) + + def err_kspace( + self, + smearing: torch.Tensor, + mesh_spacing: torch.Tensor, + interpolation_nodes: torch.Tensor, + ) -> torch.Tensor: + """ + The Fourier space error of PME. + + :param smearing: see :class:`torchpme.PMECalculator` for details + :param mesh_spacing: see :class:`torchpme.PMECalculator` for details + :param interpolation_nodes: see :class:`torchpme.PMECalculator` for details + """ + actual_spacing = self.cell_dimensions / ( + 2 * self.cell_dimensions / mesh_spacing + 1 + ) + h = torch.prod(actual_spacing) ** (1 / 3) + i_n_factorial = torch.exp(torch.lgamma(interpolation_nodes + 1)) + RMS_phi = [None, None, 0.246, 0.404, 0.950, 2.51, 8.42] return ( - prefac + self.prefac * torch.pi**0.25 * (6 * (1 / 2**0.5 / smearing) / (2 * interpolation_nodes + 1)) ** 0.5 - / volume ** (2 / 3) - * (2**0.5 / smearing * H(ns_mesh)) ** interpolation_nodes - / factorial(interpolation_nodes) + / self.volume ** (2 / 3) + * (2**0.5 / smearing * h) ** interpolation_nodes + / i_n_factorial * torch.exp( - (interpolation_nodes) * (torch.log(interpolation_nodes / 2) - 1) / 2 + interpolation_nodes * (torch.log(interpolation_nodes / 2) - 1) / 2 ) - * RMS_phi(ns_mesh) + * RMS_phi[interpolation_nodes - 1] ) - def err_real(smearing, cutoff): + def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor: + """ + The real space error of PME. + + :param smearing: see :class:`torchpme.PMECalculator` for details + :param cutoff: see :class:`torchpme.PMECalculator` for details + """ return ( - prefac - / torch.sqrt(cutoff * volume) + self.prefac + / torch.sqrt(cutoff * self.volume) * torch.exp(-(cutoff**2) / 2 / smearing**2) ) - def loss(smearing, ns_mesh, cutoff): + def error( + self, + cutoff: float, + smearing: float, + mesh_spacing: float, + interpolation_nodes: float, + ) -> torch.Tensor: + r""" + Calculate the error bound of PME. + + .. math:: + \text{Error}_{\text{total}} = \sqrt{\text{Error}_{\text{real space}}^2 + + \text{Error}_{\text{Fourier space}}^2 + + :param smearing: if its value is given, it will not be tuned, see + :class:`torchpme.PMECalculator` for details + :param mesh_spacing: if its value is given, it will not be tuned, see + :class:`torchpme.PMECalculator` for details + :param cutoff: if its value is given, it will not be tuned, see + :class:`torchpme.PMECalculator` for details + :param interpolation_nodes: The number ``n`` of nodes used in the interpolation + per coordinate axis. The total number of interpolation nodes in 3D will be + ``n^3``. In general, for ``n`` nodes, the interpolation will be performed by + piecewise polynomials of degree ``n - 1`` (e.g. ``n = 4`` for cubic + interpolation). Only the values ``3, 4, 5, 6, 7`` are supported. + """ + smearing = torch.tensor(smearing) + mesh_spacing = torch.tensor(mesh_spacing) + cutoff = torch.tensor(cutoff) + interpolation_nodes = torch.tensor(interpolation_nodes) return torch.sqrt( - err_Fourier(smearing, ns_mesh) ** 2 + err_real(smearing, cutoff) ** 2 + self.err_rspace(smearing, cutoff) ** 2 + + self.err_kspace(smearing, mesh_spacing, interpolation_nodes) ** 2 ) - - params = [smearing_opt, ns_mesh_opt, cutoff_opt] - _optimize_parameters( - params=params, - loss=loss, - max_steps=max_steps, - accuracy=accuracy, - learning_rate=learning_rate, - ) - - return ( - float(smearing_opt), - { - "mesh_spacing": float(torch.min(cell_dimensions / ns_mesh_opt)), - "interpolation_nodes": int(interpolation_nodes), - }, - float(cutoff_opt), - ) - - -def _compute_RMS_phi( - cell: torch.Tensor, - interpolation_nodes: torch.Tensor, - ns_mesh: torch.Tensor, - positions: torch.Tensor, -) -> torch.Tensor: - inverse_cell = torch.linalg.inv(cell) - # Compute positions relative to the mesh basis vectors - positions_rel = ns_mesh * torch.matmul(positions, inverse_cell) - - # Calculate positions and distances based on interpolation nodes - even = interpolation_nodes % 2 == 0 - if even: - # For Lagrange interpolation, when the number of interpolation - # is even, the relative position of a charge is the midpoint of - # the two nearest gridpoints. - positions_rel_idx = _Floor.apply(positions_rel) - else: - # For Lagrange interpolation, when the number of interpolation - # points is odd, the relative position of a charge is the nearest gridpoint. - positions_rel_idx = _Round.apply(positions_rel) - - # Calculate indices of mesh points on which the particle weights are - # interpolated. For each particle, its weight is "smeared" onto `order**3` mesh - # points, which can be achived using meshgrid below. - indices_to_interpolate = torch.stack( - [ - (positions_rel_idx + i) - for i in range( - 1 - (interpolation_nodes + 1) // 2, - 1 + interpolation_nodes // 2, - ) - ], - dim=0, - ) - positions_rel = positions_rel[torch.newaxis, :, :] - positions_rel += 1e-10 * torch.randn( - positions_rel.shape, dtype=cell.dtype, device=cell.device - ) # Noises help the algorithm work for tiny systems (<100 atoms) - return ( - torch.mean( - (torch.prod(indices_to_interpolate - positions_rel, dim=0)) ** 2, dim=0 - ) - ** 0.5 - ) - - -class _Floor(torch.autograd.Function): - """floor function with non-zero gradient""" - - @staticmethod - def forward(ctx, input): - result = torch.floor(input) - ctx.save_for_backward(result) - return result - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class _Round(torch.autograd.Function): - """round function with non-zero gradient""" - - @staticmethod - def forward(ctx, input): - result = torch.round(input) - ctx.save_for_backward(result) - return result - - @staticmethod - def backward(ctx, grad_output): - return grad_output diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py new file mode 100644 index 00000000..c3498c00 --- /dev/null +++ b/src/torchpme/tuning/tuner.py @@ -0,0 +1,324 @@ +import math +import time +from typing import Optional + +import torch + +from ..calculators import Calculator +from ..potentials import InversePowerLawPotential +from ._utils import _validate_parameters + + +class TuningErrorBounds(torch.nn.Module): + """ + Base class for error bounds. This class calculates the real space error and the + Fourier space error based on the error formula. This class is used in the tuning + process. It can also be used with the :class:`torchpme.tuning.tuner.TunerBase` to + build up a custom parameter tuner. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + """ + + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + ): + super().__init__() + self._charges = charges + self._cell = cell + self._positions = positions + + def forward(self, *args, **kwargs): + return self.error(*args, **kwargs) + + def error(self, *args, **kwargs): + raise NotImplementedError + + +class TunerBase: + """ + Base class defining the interface for a parameter tuner. + + This class provides a framework for tuning the parameters of a calculator. The class + itself supports estimating the ``smearing`` from the real space cutoff based on the + real space error formula. The :func:`TunerBase.tune` defines the interface for a + sophisticated tuning process, which takes a value of the desired accuracy. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + :param cutoff: real space cutoff, serves as a hyperparameter here. + :param calculator: the calculator to be tuned + :param exponent: exponent of the potential, only exponent = 1 is supported + + Example + ------- + >>> import torch + >>> import torchpme + >>> positions = torch.tensor( + ... [[0.0, 0.0, 0.0], [0.4, 0.4, 0.4]], dtype=torch.float64 + ... ) + >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) + >>> cell = torch.eye(3, dtype=torch.float64) + >>> tuner = TunerBase(charges, cell, positions, 4.4, torchpme.EwaldCalculator) + >>> smearing = tuner.estimate_smearing(1e-3) + >>> print(smearing) + 1.1069526756106463 + + """ + + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + cutoff: float, + calculator: type[Calculator], + exponent: int = 1, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + _validate_parameters(charges, cell, positions, exponent) + self.charges = charges + self.cell = cell + self.positions = positions + self.cutoff = cutoff + self.calculator = calculator + self.exponent = exponent + self.device = "cpu" if device is None else device + self.dtype = torch.get_default_dtype() if dtype is None else dtype + + self._prefac = 2 * float((charges**2).sum()) / math.sqrt(len(positions)) + + def tune(self, accuracy: float = 1e-3): + raise NotImplementedError + + def estimate_smearing( + self, + accuracy: float, + ) -> float: + """ + Estimate the smearing based on the error formula of the real space. The + smearing is set as leading to a real space error of ``accuracy/4``. + + :param accuracy: a float, the desired accuracy + :return: a float, the estimated smearing + """ + if not isinstance(accuracy, float): + raise ValueError(f"'{accuracy}' is not a float.") + ratio = math.sqrt( + -2 + * math.log( + accuracy + / 2 + / self._prefac + * math.sqrt(self.cutoff * float(torch.abs(self.cell.det()))) + ) + ) + smearing = self.cutoff / ratio + + return float(smearing) + + +class GridSearchTuner(TunerBase): + """ + Tuner using grid search. + + The tuner uses the error formula to estimate the error of a given parameter set. If + the error is smaller than the accuracy, the timing is measured and returned. If the + error is larger than the accuracy, the timing is set to infinity and the parameter + is skipped. + + .. note:: + + The cutoff is treated as a hyperparameter here. In case one wants to tune the + cutoff, one could instantiate the tuner with different cutoff values and + manually pick the best from the tuning results. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + :param cutoff: real space cutoff, serves as a hyperparameter here. + :param calculator: the calculator to be tuned + :param error_bounds: error bounds for the calculator + :param params: list of Fourier space parameter sets for which the error is estimated + :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors for + which the potential should be computed in real space. + :param exponent: exponent of the potential, only exponent = 1 is supported + """ + + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + cutoff: float, + calculator: type[Calculator], + error_bounds: type[TuningErrorBounds], + params: list[dict], + neighbor_indices: torch.Tensor, + neighbor_distances: torch.Tensor, + exponent: int = 1, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + super().__init__( + charges=charges, + cell=cell, + positions=positions, + cutoff=cutoff, + calculator=calculator, + exponent=exponent, + dtype=dtype, + device=device, + ) + self.error_bounds = error_bounds + self.params = params + self.time_func = TuningTimings( + charges, + cell, + positions, + neighbor_indices, + neighbor_distances, + True, + dtype=dtype, + device=device, + ) + + def tune(self, accuracy: float = 1e-3) -> tuple[list[float], list[float]]: + """ + Estimate the error and timing for each parameter set. Only parameters for + which the error is smaller than the accuracy are timed, the others' timing is + set to infinity. + + :param accuracy: a float, the desired accuracy + :return: a list of errors and a list of timings + """ + if not isinstance(accuracy, float): + raise ValueError(f"'{accuracy}' is not a float.") + smearing = self.estimate_smearing(accuracy) + param_errors = [] + param_timings = [] + for param in self.params: + error = self.error_bounds(smearing=smearing, cutoff=self.cutoff, **param) # type: ignore[call-arg] + param_errors.append(float(error)) + # only computes timings for parameters that meet the accuracy requirements + param_timings.append( + self._timing(smearing, param) if error <= accuracy else float("inf") + ) + + return param_errors, param_timings + + def _timing(self, smearing: float, k_space_params: dict): + calculator = self.calculator( + potential=InversePowerLawPotential( + exponent=self.exponent, # but only exponent = 1 is supported + smearing=smearing, + device=self.device, + dtype=self.dtype, + ), + device=self.device, + dtype=self.dtype, + **k_space_params, + ) + + return self.time_func(calculator) + + +class TuningTimings(torch.nn.Module): + """ + Class for timing a calculator. + + The class estimates the average execution time of a given calculater after several + warmup runs. The class takes the information of the structure that one wants to + benchmark on, and the configuration of the timing process as inputs. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + :param cutoff: real space cutoff, serves as a hyperparameter here. + :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors for + which the potential should be computed in real space. + :param n_repeat: number of times to repeat to estimate the average timing + :param n_warmup: number of warmup runs, recommended to be at least 4 + :param run_backward: whether to run the backward pass + """ + + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + neighbor_indices: torch.Tensor, + neighbor_distances: torch.Tensor, + n_repeat: int = 4, + n_warmup: int = 4, + run_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.charges = charges + self.cell = cell + self.positions = positions + self.dtype = dtype + self.device = device + self.n_repeat = n_repeat + self.n_warmup = n_warmup + self.run_backward = run_backward + self.neighbor_indices = neighbor_indices.to(device=self.device) + self.neighbor_distances = neighbor_distances.to( + dtype=self.dtype, device=self.device + ) + + def forward(self, calculator: torch.nn.Module): + """ + Estimate the execution time of a given calculator for the structure + to be used as benchmark. + + :param calculator: the calculator to be tuned + :return: a float, the average execution time + """ + # measure time + execution_time = 0.0 + + for _ in range(self.n_repeat + self.n_warmup): + if _ == self.n_warmup: + execution_time = 0.0 + positions = self.positions.clone() + cell = self.cell.clone() + charges = self.charges.clone() + # nb - this won't compute gradiens involving the distances + if self.run_backward: + positions.requires_grad_(True) + cell.requires_grad_(True) + charges.requires_grad_(True) + execution_time -= time.monotonic() + result = calculator.forward( + positions=positions, + charges=charges, + cell=cell, + neighbor_indices=self.neighbor_indices, + neighbor_distances=self.neighbor_distances, + ) + value = result.sum() + if self.run_backward: + value.backward(retain_graph=True) + + if self.device is torch.device("cuda"): + torch.cuda.synchronize() + execution_time += time.monotonic() + + return execution_time / self.n_repeat diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index 7858b395..4af4b4ab 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -223,6 +223,6 @@ def test_potential_and_calculator_incompatability( params["potential"].device = device params["potential"] = torch.jit.script(params["potential"]) with pytest.raises( - AssertionError, match="Potential must be an instance of Potential, got.*" + TypeError, match="Potential must be an instance of Potential, got.*" ): CalculatorClass(**params) diff --git a/tests/requirements.txt b/tests/requirements.txt index 582714a8..fab921e5 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -4,3 +4,4 @@ pytest pytest-cov scipy vesin >= 0.3.0 +vesin[torch] >= 0.3.0 diff --git a/tests/tuning/test_error_bounds.py b/tests/tuning/test_error_bounds.py new file mode 100644 index 00000000..8b2a7110 --- /dev/null +++ b/tests/tuning/test_error_bounds.py @@ -0,0 +1,35 @@ +import pytest +import torch + +from torchpme.tuning.ewald import EwaldErrorBounds +from torchpme.tuning.p3m import P3MErrorBounds +from torchpme.tuning.pme import PMEErrorBounds + + +@pytest.mark.parametrize( + ("error_bound", "params", "ref_err"), + [ + ( + EwaldErrorBounds, + dict(smearing=1.0, lr_wavelength=0.5, cutoff=4.4), + torch.tensor(8.4304e-05), + ), + ( + PMEErrorBounds, + dict(smearing=1.0, mesh_spacing=0.5, cutoff=4.4, interpolation_nodes=3), + torch.tensor(0.0011180), + ), + ( + P3MErrorBounds, + dict(smearing=1.0, mesh_spacing=0.5, cutoff=4.4, interpolation_nodes=3), + torch.tensor(4.5961e-04), + ), + ], +) +def test_error_bounds(error_bound, params, ref_err): + charges = torch.tensor([[1.0], [-1.0]]) + cell = torch.eye(3) + positions = torch.tensor([[0.0, 0.0, 0.0], [0.4, 0.4, 0.4]]) + error_bound = error_bound(charges, cell, positions) + print(float(error_bound(**params))) + torch.testing.assert_close(error_bound(**params), ref_err) diff --git a/tests/tuning/test_timer.py b/tests/tuning/test_timer.py new file mode 100644 index 00000000..44adac62 --- /dev/null +++ b/tests/tuning/test_timer.py @@ -0,0 +1,77 @@ +import sys +from pathlib import Path + +import torch + +from torchpme import ( + CoulombPotential, + EwaldCalculator, +) +from torchpme.tuning.tuner import TuningTimings + +sys.path.append(str(Path(__file__).parents[1])) +from helpers import compute_distances, define_crystal, neighbor_list + +DTYPE = torch.float32 +DEFAULT_CUTOFF = 4.4 +CHARGES_1 = torch.ones((4, 1), dtype=DTYPE) +POSITIONS_1 = 0.3 * torch.arange(12, dtype=DTYPE).reshape((4, 3)) +CELL_1 = torch.eye(3, dtype=DTYPE) + + +def _nl_calculation(pos, cell): + neighbor_indices, neighbor_shifts = neighbor_list( + positions=pos, + periodic=True, + box=cell, + cutoff=DEFAULT_CUTOFF, + neighbor_shifts=True, + ) + + neighbor_distances = compute_distances( + positions=pos, + neighbor_indices=neighbor_indices, + cell=cell, + neighbor_shifts=neighbor_shifts, + ) + + return neighbor_indices, neighbor_distances + + +def test_timer(): + n_repeat_1 = 8 + n_repeat_2 = 16 + pos, charges, cell, madelung_ref, num_units = define_crystal() + neighbor_indices, neighbor_distances = _nl_calculation(pos, cell) + + calculator = EwaldCalculator( + potential=CoulombPotential(smearing=1.0), + lr_wavelength=1.0, + dtype=DTYPE, + ) + + timing_1 = TuningTimings( + charges=charges, + cell=cell, + positions=pos, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, + dtype=DTYPE, + n_repeat=n_repeat_1, + ) + + timing_2 = TuningTimings( + charges=charges, + cell=cell, + positions=pos, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, + dtype=DTYPE, + n_repeat=n_repeat_2, + ) + + time_1 = timing_1.forward(calculator) + time_2 = timing_2.forward(calculator) + + assert time_1 > 0 + assert time_1 * n_repeat_1 < time_2 * n_repeat_2 diff --git a/tests/test_tuning.py b/tests/tuning/test_tuning.py similarity index 51% rename from tests/test_tuning.py rename to tests/tuning/test_tuning.py index a69f9b57..94710425 100644 --- a/tests/test_tuning.py +++ b/tests/tuning/test_tuning.py @@ -1,5 +1,4 @@ import sys -import warnings from pathlib import Path import pytest @@ -14,15 +13,35 @@ from torchpme.tuning import tune_ewald, tune_p3m, tune_pme sys.path.append(str(Path(__file__).parents[1])) -from helpers import define_crystal, neighbor_list +from helpers import compute_distances, define_crystal, neighbor_list DTYPE = torch.float32 DEVICE = "cpu" +DEFAULT_CUTOFF = 4.4 CHARGES_1 = torch.ones((4, 1), dtype=DTYPE, device=DEVICE) POSITIONS_1 = 0.3 * torch.arange(12, dtype=DTYPE, device=DEVICE).reshape((4, 3)) CELL_1 = torch.eye(3, dtype=DTYPE, device=DEVICE) +def _nl_calculation(pos, cell): + neighbor_indices, neighbor_shifts = neighbor_list( + positions=pos, + periodic=True, + box=cell, + cutoff=DEFAULT_CUTOFF, + neighbor_shifts=True, + ) + + neighbor_distances = compute_distances( + positions=pos, + neighbor_indices=neighbor_indices, + cell=cell, + neighbor_shifts=neighbor_shifts, + ) + + return neighbor_indices, neighbor_distances + + @pytest.mark.parametrize( ("calculator", "tune", "param_length"), [ @@ -40,21 +59,21 @@ def test_parameter_choose(calculator, tune, param_length, accuracy): # Get input parameters and adjust to account for scaling pos, charges, cell, madelung_ref, num_units = define_crystal() - smearing, params, sr_cutoff = tune( - sum_squared_charges=float(torch.sum(charges**2)), - cell=cell, - positions=pos, + # Compute neighbor list + neighbor_indices, neighbor_distances = _nl_calculation(pos, cell) + + smearing, params, _ = tune( + charges, + cell, + pos, + DEFAULT_CUTOFF, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, accuracy=accuracy, - learning_rate=0.75, ) assert len(params) == param_length - # Compute neighbor list - neighbor_indices, neighbor_distances = neighbor_list( - positions=pos, periodic=True, box=cell, cutoff=sr_cutoff - ) - # Compute potential and compare against target value using default hypers calc = calculator( potential=(CoulombPotential(smearing=smearing)), @@ -73,108 +92,22 @@ def test_parameter_choose(calculator, tune, param_length, accuracy): torch.testing.assert_close(madelung, madelung_ref, atol=0, rtol=accuracy) -def test_odd_interpolation_nodes(): - pos, charges, cell, madelung_ref, num_units = define_crystal() - - smearing, params, sr_cutoff = tune_pme( - sum_squared_charges=float(torch.sum(charges**2)), - cell=cell, - positions=pos, - interpolation_nodes=5, - learning_rate=0.75, - ) - - neighbor_indices, neighbor_distances = neighbor_list( - positions=pos, periodic=True, box=cell, cutoff=sr_cutoff - ) - - calc = PMECalculator(potential=CoulombPotential(smearing=smearing), **params) - potentials = calc.forward( - positions=pos, - charges=charges, - cell=cell, - neighbor_indices=neighbor_indices, - neighbor_distances=neighbor_distances, - ) - energies = potentials * charges - madelung = -torch.sum(energies) / num_units - - torch.testing.assert_close(madelung, madelung_ref, atol=0, rtol=1e-3) - - -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_fix_parameters(tune): - """Test that the parameters are fixed when they are passed as arguments.""" - pos, charges, cell, _, _ = define_crystal() - - kwargs_ref = { - "sum_squared_charges": float(torch.sum(charges**2)), - "cell": cell, - "positions": pos, - "max_steps": 5, - } - - kwargs = kwargs_ref.copy() - kwargs["smearing"] = 0.1 - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - smearing, _, _ = tune(**kwargs) - pytest.approx(smearing, 0.1) - - kwargs = kwargs_ref.copy() - if tune.__name__ in ["tune_pme", "tune_p3m"]: - kwargs["mesh_spacing"] = 0.1 - else: - kwargs["lr_wavelength"] = 0.1 - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - _, kspace_param, _ = tune(**kwargs) - - kspace_param = list(kspace_param.values())[0] - pytest.approx(kspace_param, 0.1) - - kwargs = kwargs_ref.copy() - kwargs["cutoff"] = 0.1 - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - _, _, sr_cutoff = tune(**kwargs) - pytest.approx(sr_cutoff, 1.0) - - -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_non_positive_charge_error(tune): - pos, _, cell, _, _ = define_crystal() - - match = "sum of squared charges must be positive, got -1.0" - with pytest.raises(ValueError, match=match): - tune(-1.0, cell, pos) - - match = "sum of squared charges must be positive, got 0.0" - with pytest.raises(ValueError, match=match): - tune(0.0, cell, pos) - - @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_accuracy_error(tune): pos, charges, cell, _, _ = define_crystal() match = "'foo' is not a float." + neighbor_indices, neighbor_distances = _nl_calculation(pos, cell) with pytest.raises(ValueError, match=match): - tune(float(torch.sum(charges**2)), cell, pos, accuracy="foo") - - -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_loss_is_nan_error(tune): - pos, charges, cell, _, _ = define_crystal() - - match = ( - "The value of the estimated error is now nan, " - "consider using a smaller learning rate." - ) - with pytest.raises(ValueError, match=match): - tune(float(torch.sum(charges**2)), cell, pos, learning_rate=1e1000) + tune( + charges, + cell, + pos, + DEFAULT_CUTOFF, + neighbor_indices, + neighbor_distances, + accuracy="foo", + ) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) @@ -182,8 +115,17 @@ def test_exponent_not_1_error(tune): pos, charges, cell, _, _ = define_crystal() match = "Only exponent = 1 is supported" + neighbor_indices, neighbor_distances = _nl_calculation(pos, cell) with pytest.raises(NotImplementedError, match=match): - tune(float(torch.sum(charges**2)), cell, pos, exponent=2) + tune( + charges, + cell, + pos, + DEFAULT_CUTOFF, + neighbor_indices, + neighbor_distances, + exponent=2, + ) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) @@ -194,9 +136,12 @@ def test_invalid_shape_positions(tune): ) with pytest.raises(ValueError, match=match): tune( - sum_squared_charges=1.0, - positions=torch.ones((4, 5), dtype=DTYPE, device=DEVICE), - cell=CELL_1, + CHARGES_1, + CELL_1, + torch.ones((4, 5), dtype=DTYPE, device=DEVICE), + DEFAULT_CUTOFF, + None, # dummy neighbor indices + None, # dummy neighbor distances ) @@ -209,9 +154,12 @@ def test_invalid_shape_cell(tune): ) with pytest.raises(ValueError, match=match): tune( - sum_squared_charges=1.0, - positions=POSITIONS_1, - cell=torch.ones([2, 2], dtype=DTYPE, device=DEVICE), + CHARGES_1, + torch.ones([2, 2], dtype=DTYPE, device=DEVICE), + POSITIONS_1, + DEFAULT_CUTOFF, + None, # dummy neighbor indices + None, # dummy neighbor distances ) @@ -222,11 +170,7 @@ def test_invalid_cell(tune): "periodic calculation" ) with pytest.raises(ValueError, match=match): - tune( - sum_squared_charges=1.0, - positions=POSITIONS_1, - cell=torch.zeros(3, 3), - ) + tune(CHARGES_1, torch.zeros(3, 3), POSITIONS_1, DEFAULT_CUTOFF, None, None) @pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) @@ -237,9 +181,12 @@ def test_invalid_dtype_cell(tune): ) with pytest.raises(ValueError, match=match): tune( - sum_squared_charges=1.0, - positions=POSITIONS_1, - cell=torch.eye(3, dtype=torch.float64, device=DEVICE), + CHARGES_1, + torch.eye(3, dtype=torch.float64, device=DEVICE), + POSITIONS_1, + DEFAULT_CUTOFF, + None, + None, ) @@ -251,7 +198,10 @@ def test_invalid_device_cell(tune): ) with pytest.raises(ValueError, match=match): tune( - sum_squared_charges=1.0, - positions=POSITIONS_1, - cell=torch.eye(3, dtype=DTYPE, device="meta"), + CHARGES_1, + torch.eye(3, dtype=DTYPE, device="meta"), + POSITIONS_1, + DEFAULT_CUTOFF, + None, + None, )