Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamped tuning #130

Merged
merged 86 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
519ebd3
Initial version of `grid_search`
GardevoirX Nov 26, 2024
830e0f7
Remove error
GardevoirX Nov 26, 2024
1b80e93
Allow a precomputed nl
GardevoirX Nov 26, 2024
1e346b8
Renamed examples, and added a tuning playground
ceriottm Nov 23, 2024
3e35d74
Nelder mead (doesn't work because actual error is not a good target)
ceriottm Nov 24, 2024
4a935f1
Added a tuning class
ceriottm Nov 24, 2024
628a11a
I'm not a morning person it seems
ceriottm Nov 24, 2024
8f1f543
Examples
ceriottm Nov 24, 2024
4aa010c
Better plotting
ceriottm Nov 24, 2024
9f6e2fc
Fixes on `H` and `RMS_phi`
GardevoirX Nov 25, 2024
cc69e8c
Some cleaning and test fix
GardevoirX Nov 25, 2024
68057b4
Further clean
GardevoirX Nov 26, 2024
658a268
Replace `loss` in tuning with `ErrorBounds` and draft for `Tuner`
GardevoirX Nov 27, 2024
80951af
Supress output
GardevoirX Nov 27, 2024
fc598ac
Update `grid_search`
GardevoirX Nov 28, 2024
bc60428
Return something when is cannot reach desired accuracy
GardevoirX Nov 28, 2024
a73ea6a
Supress output
GardevoirX Nov 28, 2024
019545e
Repair some errors of the example
GardevoirX Nov 28, 2024
904b310
Add a warning for the case that no parameter can meet the accuracy re…
GardevoirX Dec 5, 2024
a85e918
Update warning
GardevoirX Dec 5, 2024
9b7660c
Documentations and pytests update
GardevoirX Dec 18, 2024
f68e739
Added a TIP4P example
ceriottm Dec 20, 2024
104f1f8
Started to change the API to use full charges rather than the sum of …
ceriottm Dec 20, 2024
66cbacb
Move from `sum_squared_charges` to `charges`
GardevoirX Dec 28, 2024
fac45f2
Refactor the tuning methods with a base class
GardevoirX Dec 28, 2024
efc4cd0
Fix pytests and make linter happy
GardevoirX Dec 28, 2024
1c0fcaf
Mini cleanups
ceriottm Dec 29, 2024
a184b53
Docs fix
GardevoirX Dec 29, 2024
f36e751
Separate timings calculator
ceriottm Dec 29, 2024
c57589d
Linting
ceriottm Dec 29, 2024
e5bbbfa
Try fix github action failures
GardevoirX Dec 29, 2024
c4bb2f8
Add tuning functions back
GardevoirX Jan 7, 2025
670ec3a
Allow doctests
GardevoirX Jan 7, 2025
7c54f87
Fix doctests and remove orphan functions
GardevoirX Jan 7, 2025
46e0837
Fix ewald doctest again and remove unused members
GardevoirX Jan 7, 2025
2fbc904
Formatting
GardevoirX Jan 7, 2025
30de1c6
Draft for renovated tuning
GardevoirX Jan 13, 2025
8ac6b3b
For now move back to `CoulombPotential`
GardevoirX Jan 13, 2025
411db71
Rearange the tuning stuff
GardevoirX Jan 13, 2025
b7bf21a
Rearrange again
GardevoirX Jan 15, 2025
c0a49e4
An initial version of refurnished documentation
GardevoirX Jan 15, 2025
9c4ff6f
Minor modification
GardevoirX Jan 16, 2025
fb45c97
ErrorBounds related updates
GardevoirX Jan 16, 2025
68f17fa
Update error formulas
GardevoirX Jan 16, 2025
388bdd9
Update tuning tests
GardevoirX Jan 16, 2025
3851726
Lint manually
GardevoirX Jan 16, 2025
b97e0b3
Update tuning doctests
GardevoirX Jan 16, 2025
0d335c7
Update example
GardevoirX Jan 16, 2025
7b36cf1
some minor cleanup
PicoCentauri Jan 17, 2025
410bf74
update some docs
PicoCentauri Jan 19, 2025
f6769ac
Documentation and minor fixes
GardevoirX Jan 19, 2025
39db980
Fix test
GardevoirX Jan 19, 2025
ac70152
Update documentation
GardevoirX Jan 19, 2025
12649f6
Fix vesin
GardevoirX Jan 20, 2025
e929b3b
Skip tuning tests if on win32
GardevoirX Jan 20, 2025
6d193a6
Remove vesin from src
GardevoirX Jan 20, 2025
6d6c0ba
Minor
GardevoirX Jan 20, 2025
0514c24
some cleanups
PicoCentauri Jan 20, 2025
ce4ed1f
More tests for base classes of tuner
GardevoirX Jan 20, 2025
1118dee
Test timer on different devices
GardevoirX Jan 20, 2025
097d445
Add some text to the example
GardevoirX Jan 20, 2025
8232294
Minor fix
GardevoirX Jan 20, 2025
d698d41
Formatting
GardevoirX Jan 20, 2025
adc7b4e
Minor
GardevoirX Jan 20, 2025
1485e09
Formatting
GardevoirX Jan 20, 2025
2b4cca7
Update default warmup rounds
GardevoirX Jan 20, 2025
4241a38
Started work on the tuning example
ceriottm Jan 21, 2025
96ba959
Merge branch 'revamped-tuning' of github.com:lab-cosmo/torch-pme into…
ceriottm Jan 21, 2025
74caf9f
WIP
ceriottm Jan 22, 2025
3d00e44
More work
ceriottm Jan 22, 2025
1cbffed
Helpstring update and tuning function return update
GardevoirX Jan 22, 2025
42a5afc
Cleaning and finishing up the example
ceriottm Jan 22, 2025
8ef6b73
Merge branch 'main' into revamped-tuning
GardevoirX Jan 22, 2025
1bac763
Cleaning and finishing up the example
ceriottm Jan 22, 2025
095ea3f
Fix pytests
GardevoirX Jan 22, 2025
e6ce9c1
Merge branch 'revamped-tuning' of github.com:lab-cosmo/torch-pme into…
ceriottm Jan 22, 2025
f4a7be3
Merge branch 'revamped-tuning' of github.com:lab-cosmo/torch-pme into…
ceriottm Jan 22, 2025
374e226
Make example faster by reducing replication of cell
ceriottm Jan 22, 2025
5dfa218
fix docs
PicoCentauri Jan 22, 2025
13481d6
fix tests
PicoCentauri Jan 22, 2025
a14beb6
fix linter
PicoCentauri Jan 22, 2025
4a9c546
Fix timing tests
GardevoirX Jan 22, 2025
1545c97
Disable timing test
GardevoirX Jan 22, 2025
6daacf3
Update changelog.rst
GardevoirX Jan 22, 2025
6621663
Add gallery for tuning classes
PicoCentauri Jan 23, 2025
0c02558
Merge branch 'main' into revamped-tuning
PicoCentauri Jan 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/src/references/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@ Added

* Added a PyTorch implementation of the exponential integral function
* Added ``dtype`` and ``device`` for ``Calculator`` classses
* Added an example on the tuning scheme and usage, and how to optimize the ``cutoff``

Changed
#######

* Removed ``utils`` module. ``utils.tuning`` and ``utils.prefactor`` are now in the root
of the package. ``utils.splines`` is now in the ``lib`` module.
* The tuning now uses a grid-search based scheme, instead of a gradient based scheme.
* The tuning functions no longer takes the ``cutoff`` parameter, and thus does not
support a built-in NL calculation.

Fixed
#####
Expand Down
2 changes: 1 addition & 1 deletion docs/src/references/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ refer to the :ref:`userdoc-how-to` section.

potentials/index
calculators/index
tuning
tuning/index
prefactors
metatensor
lib/index
Expand Down
22 changes: 0 additions & 22 deletions docs/src/references/tuning.rst

This file was deleted.

52 changes: 52 additions & 0 deletions docs/src/references/tuning/base_classes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
Base Classes
############

GardevoirX marked this conversation as resolved.
Show resolved Hide resolved
Current scheme behind all tuning functions is grid-searching based, focusing on the Fourier
space parameters like ``lr_wavelength``, ``mesh_spacing`` and ``interpolation_nodes``.
For real space parameter ``cutoff``, it is treated as a hyperparameter here, which
should be manually specified by the user. The parameter ``smearing`` is determined by
the real space error formula and is set to achieve a real space error of
``desired_accuracy / 4``.

The Fourier space parameters are all discrete, so it's convenient to do the grid-search.
Default searching-ranges are provided for those parameters. For ``lr_wavelength``, the
values are chosen to be with a minimum of 1 and a maximum of 13 mesh points in each
spatial direction ``(x, y, z)``. For ``mesh_spacing``, the values are set to have
minimally 2 and maximally 7 mesh points in each spatial direction, for both the P3M and
PME method. The values of ``interpolation_nodes`` are the same as those supported in
:class:`torchpme.lib.MeshInterpolator`.

In the grid-searching, all possible parameter combinations are evaluated. The error
associated with the parameter is estimated by the error formulas implemented in the
subclasses of :class:`torchpme.tuning.tuner.TuningErrorBounds`. Parameter with
the error within the desired accuracy are benchmarked for computational time by
:class:`torchpme.tuning.tuner.TuningTimings` The timing of the other parameters are
not tested and set to infinity.

The return of these tuning functions contains the ``smearing`` and a dictionary, in
which there is parameter for the Fourier space. The parameter is that of the desired
accuracy and the shortest timing. The parameter of the smallest error will be returned
in the case that no parameter can fulfill the accuracy requirement.


.. autoclass:: torchpme.tuning.tuner.TunerBase
:members:

.. autoclass:: torchpme.tuning.tuner.GridSearchTuner
:members:

.. autoclass:: torchpme.tuning.tuner.TuningTimings
:members:

.. autoclass:: torchpme.tuning.tuner.TuningErrorBounds
:members:

Examples using Tuning Classes
-----------------------------

.. minigallery::

torchpme.tuning.tuner.TunerBase
torchpme.tuning.tuner.GridSearchTuner
torchpme.tuning.tuner.TuningTimings
torchpme.tuning.tuner.TuningErrorBounds
24 changes: 24 additions & 0 deletions docs/src/references/tuning/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Tuning
######

The choice of parameters like the neighborlist ``cutoff``, the ``smearing`` or the
``lr_wavelength``/``mesh_spacing`` has a large influence one the accuracy of the
calculation. To help find the parameters that meet the accuracy requirements, this
module offers tuning methods for the calculators.

For usual tuning procedures we provide simple functions like
:func:`torchpme.tuning.tune_ewald` that returns for a given system the optimal
parameters for the Ewald summation. For more complex tuning procedures, we provide
classes like :class:`torchpme.tuning.ewald.EwaldErrorBounds` that can be used to
implement custom tuning procedures.

.. important::

Current tuning methods are only implemented for the :class:`Coulomb potential
<torchpme.CoulombPotential>`.

.. toctree::
:maxdepth: 1
:glob:

./*
24 changes: 24 additions & 0 deletions docs/src/references/tuning/tune_ewald.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Tune Ewald
##########

GardevoirX marked this conversation as resolved.
Show resolved Hide resolved
The tuning is based on the following error formulas:

.. math::
\Delta F_\mathrm{real}
\approx \frac{Q^2}{\sqrt{N}}
\frac{2}{\sqrt{r_{\text{cutoff}} V}}
e^{-r_{\text{cutoff}}^2 / 2 \sigma^2}

.. math::
\Delta F_\mathrm{Fourier}^\mathrm{Ewald}
\approx \frac{Q^2}{\sqrt{N}}
\frac{\sqrt{2} / \sigma}{\pi\sqrt{2 V / h}} e^{-2\pi^2 \sigma^2 / h ^ 2}

where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared
charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the
simulation box and :math:`h^2` is the long range wavelength.

.. autofunction:: torchpme.tuning.tune_ewald

.. autoclass:: torchpme.tuning.ewald.EwaldErrorBounds
:members:
27 changes: 27 additions & 0 deletions docs/src/references/tuning/tune_p3m.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Tune P3M
#########

The tuning is based on the following error formulas:

.. math::
\Delta F_\mathrm{real}
\approx \frac{Q^2}{\sqrt{N}}
\frac{2}{\sqrt{r_{\text{cutoff}} V}}
e^{-r_{\text{cutoff}}^2 / 2 \sigma^2}

.. math::
\Delta F_\mathrm{Fourier}^\mathrm{P3M}
\approx \frac{Q^2}{L^2}(\frac{\sqrt{2}H}{\sigma})^p
\sqrt{\frac{\sqrt{2}L}{N\sigma}
\sqrt{2\pi}\sum_{m=0}^{p-1}a_m^{(p)}(\frac{\sqrt{2}H}{\sigma})^{2m}}

where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared
charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the
simulation box, :math:`p` is the order of the interpolation scheme, :math:`H` is the spacing of mesh
points and :math:`a_m^{(p)}` is an expansion coefficient.


.. autofunction:: torchpme.tuning.tune_p3m

.. autoclass:: torchpme.tuning.p3m.P3MErrorBounds
:members:
27 changes: 27 additions & 0 deletions docs/src/references/tuning/tune_pme.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Tune PME
#########

The tuning is based on the following error formulas:

.. math::
\Delta F_\mathrm{real}
\approx \frac{Q^2}{\sqrt{N}}
\frac{2}{\sqrt{r_{\text{cutoff}} V}}
e^{-r_{\text{cutoff}}^2 / 2 \sigma^2}

.. math::
\Delta F_\mathrm{Fourier}^\mathrm{PME}
\approx 2\pi^{1/4}\sqrt{\frac{3\sqrt{2} / \sigma}{N(2p+3)}}
\frac{Q^2}{L^2}\frac{(\sqrt{2}H/\sigma)^{p+1}}{(p+1)!} \times
\exp{\frac{(p+1)[\log{(p+1)} - \log 2 - 1]}{2}} \left< \phi_p^2 \right> ^{1/2}

where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared
charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the
simulation box, :math:`p` is the order of the interpolation scheme, :math:`H` is the spacing of mesh
points, and :math:`\phi_p^2 = H^{-(p+1)}\prod_{s\in S_H^{(p)}}(x - s)`, in which :math:`S_H^{(p)}` is
the :math:`p+1` mesh points closest to the point :math:`x`.

.. autofunction:: torchpme.tuning.tune_pme

.. autoclass:: torchpme.tuning.pme.PMEErrorBounds
:members:
19 changes: 17 additions & 2 deletions examples/1-charges-example.py → examples/01-charges-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@
from metatensor.torch.atomistic import NeighborListOptions, System

import torchpme
from torchpme.tuning import tune_pme

# %%
#
# Create the properties CsCl unit cell

symbols = ("Cs", "Cl")
types = torch.tensor([55, 17])
charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64)
pbc = torch.tensor([True, True, True])
Expand All @@ -55,8 +57,21 @@
# The ``sum_squared_charges`` is equal to ``2.0`` becaue each atom either has a charge
# of 1 or -1 in units of elementary charges.

smearing, pme_params, cutoff = torchpme.tuning.tune_pme(
sum_squared_charges=2.0, cell=cell, positions=positions
cutoff = 4.4
nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False)
neighbor_indices, neighbor_distances = nl.compute(
points=positions.to(dtype=torch.float64, device="cpu"),
box=cell.to(dtype=torch.float64, device="cpu"),
periodic=True,
quantities="Pd",
)
smearing, pme_params, _ = tune_pme(
charges=charges,
cell=cell,
positions=positions,
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
)

# %%
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import vesin.torch

import torchpme
from torchpme.tuning import tune_pme

# %%
#
Expand Down Expand Up @@ -92,9 +93,22 @@
cell = torch.from_numpy(atoms.cell.array)

sum_squared_charges = float(torch.sum(charges**2))
cutoff = 4.4
nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False)
neighbor_indices, neighbor_distances = nl.compute(
points=positions.to(dtype=torch.float64, device="cpu"),
box=cell.to(dtype=torch.float64, device="cpu"),
periodic=True,
quantities="Pd",
)

smearing, pme_params, cutoff = torchpme.tuning.tune_pme(
sum_squared_charges=sum_squared_charges, cell=cell, positions=positions
smearing, pme_params, _ = tune_pme(
charges=charges,
cell=cell,
positions=positions,
cutoff=cutoff,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
)

# %%
Expand Down
File renamed without changes.
File renamed without changes.
9 changes: 5 additions & 4 deletions examples/5-autograd-demo.py → examples/05-autograd-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
exercise to the reader.
"""

# %%

from time import time

import ase
Expand Down Expand Up @@ -477,10 +479,11 @@ def forward(self, positions, cell, charges):
)

# %%
# We can also time the difference in execution
# We can also evaluate the difference in execution
# time between the Pytorch and scripted versions of the
# module (depending on the system, the relative efficiency
# of the two evaluations could go either way!)
# of the two evaluations could go either way, as this is
# a too small system to make a difference!)

duration = 0.0
for _i in range(20):
Expand Down Expand Up @@ -513,5 +516,3 @@ def forward(self, positions, cell, charges):

# %%
print(f"Evaluation time:\nPytorch: {time_python}ms\nJitted: {time_jit}ms")

# %%
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading
Loading