Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Documentation review #143

Merged
merged 15 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 113 additions & 109 deletions docs/src/_static/sphericart_icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
15 changes: 3 additions & 12 deletions docs/src/api.rst
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
API documentation
=================

The core implementation of ``sphericart`` is written in C++. It relies on templates and
C++17 features such as ``if constexpr`` to reduce the runtime overhead of implementing
different normalization styles, and providing both a version with and without derivatives.

Spherical harmonics and their derivatives are computed with optimized hard-coded expressions
for low values of the principal angular momentum number :math:`l`, then switch to an efficient
recursive evaluation. The API involves initializing a calculator that allocates buffer space
and computes some constant factors, and then using it to compute :math:`Y_l^m` (and possibly its
first and/or second derivatives) for one or more points in 3D space.

This core C++ library is then made available to different environments through a C API.
The core implementation of ``sphericart`` is written in C++ and CUDA. This core library is
then also made available to different environments (C, Python, PyTorch, JAX).
This section contains a description of the interface of the ``sphericart`` library for the
different languages it supports.
different languages and frameworks it supports.

.. toctree::
:maxdepth: 1
Expand Down
11 changes: 6 additions & 5 deletions docs/src/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ sphericart documentation
``sphericart`` is a multi-language library for the efficient calculation of
spherical harmonics and their derivatives in Cartesian coordinates.

The theory behind this efficient implementation is detailed in this
`paper <https://arxiv.org/abs/2302.08381>`_.
The theory behind this efficient implementation is detailed in
`this paper <https://doi.org/10.1063/5.0156307>`_.

The core library is implemented in C++ (with OpenMP parallelism) and CUDA.
It provides APIs for C, Python (NumPy), PyTorch and JAX. The torch and JAX
implementations provide fast spherical harmonics evaluations on GPUs.
It also provides APIs for C, Python (NumPy), PyTorch and JAX. The torch and JAX
implementations provide fast spherical harmonics on GPUs.

A native Julia package is also available.
A native Julia package is also available, contributed by
`Christoph Ortner <https://personal.math.ubc.ca/~ortner/>`_.

This documentation contains an installation guide, an API overview, some examples
of how to use the library, and a brief explanation of the mathematics involved.
Expand Down
16 changes: 11 additions & 5 deletions docs/src/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ This basic package makes use of NumPy. A PyTorch-based implementation can be ins

This pre-built version available on PyPI sacrifices some performance to ensure it
can run on all systems, and it does not include GPU support.
If you need an extra 5-10% of performance or you want to evaluate the spherical harmonics on GPUs,
you should build the code from source:
If you need an extra 5-10% of performance, you want to evaluate the spherical harmonics on GPUs,
and/or you want to use it in JAX, you should build the code from source:

.. code-block:: bash

Expand All @@ -35,15 +35,21 @@ you should build the code from source:
# torch bindings (CPU-only)
pip install --extra-index-url https://download.pytorch.org/whl/cpu .[torch]

Before installing the JAX version of ``sphericart``, you should already have the JAX
library installed according to the official JAX installation instructions.
Before installing the JAX version of ``sphericart``, make sure you already have the JAX
library installed according to the `official JAX installation instructions
<https://jax.readthedocs.io/en/latest/installation.html>`_.

In addition, if you want to use the CUDA functionalities of sphericart (either with torch
or JAX), make sure you have installed the CUDA toolkit and set up the environment variables
``CUDA_HOME``, ``LD_LIBRARY_FLAGS``, and ``PATH`` accordingly.


Julia package
-------------

The native Julia package can be installed by opening a REPL,
switching to the package manager by typing ``]`` and then ``add SpheriCart``.
switching to the package manager by typing ``]`` and then executing
the command ``add SpheriCart``.


C/C++/CUDA library
Expand Down
2 changes: 2 additions & 0 deletions docs/src/jax-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ using the CPU or CUDA implementation.

.. autofunction:: sphericart.jax.spherical_harmonics

.. autofunction:: sphericart.jax.solid_harmonics

32 changes: 20 additions & 12 deletions docs/src/maths.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,26 @@ There are multiple conventions for choosing normalization and phases, and it is
possible to reformulate the spherical harmonics in a real-valued form, which leads
to even further ambiguity in the definitions.

Within `sphericart` we take an opinionated stance: we compute only real-valued
harmonics, we express them as a function of the full Cartesian coordinates of a
point in three dimensions :math:`(x,y,z)` and compute by default "scaled"
versions :math:`\tilde{Y}^m_l(x, y, z)` which correspond to homogeneous polynomials
of the Cartesian coordinates:
Within `sphericart`, we compute only real-valued spherical harmonics and we express
them as a function of the full Cartesian coordinates of a point in three dimensions.
These correspond to the real spherical harmonics as defined in the corresponding
`Wikipedia article <https://en.wikipedia.org/wiki/Spherical_harmonics>`_, which we
refer to as :math:`Y^m_l`.
If you need complex spherical harmonics, or use a different convention for normalization
and storage order it is usually simple - if tedious and inefficient - to perform the
conversion manually, see :doc:`spherical-complex` for a simple example.


We also offer the possibility to compute "solid" harmonics, which are given by
:math:`\tilde{Y}^m_l = r^l\,{Y}_l^m`. Since these can be expressed as homogeneous
polynomials of the Cartesian coordinates :math:`(x,y,z)`, as opposed to
:math:`(x/r,y/r,z/r)`, they are less computationally expensive to evaluate.
Besides being slightly faster, they can also provide a more natural scaling if
used together with a radial expansion, and we recommend using them unless you
need the normalized version.

The formulas used to compute the solid harmonics (and, with few modifications,
also for the spherical harmonics) are:

.. math ::
\tilde{Y}_l^m(x, y, z) = r^l\,{Y}_l^m(x, y, z) = F_l^{|m|} Q_l^{|m|}(z, r) \times
Expand All @@ -41,13 +56,6 @@ If we neglect some constant normalization factors, these correspond to the
See also the `reference paper <https://arxiv.org/abs/2302.08381>`_ for further
implementation details.

The radially normalized version of the spherical harmonics can also be computed by providing
the appropriate flag when creating the `sphericart` calculators. These correspond to
the real spherical harmonics as defined in the corresponding
`Wikipedia article <https://en.wikipedia.org/wiki/Spherical_harmonics>`_.
However, we recommend using the scaled versions, which are slightly faster and
provide a more natural scaling if used together with a radial expansion.

The :math:`\tilde{Y}^m_l(x)` are stored contiguously in memory, e.g. as
:math:`\{ (l,m)=(0,0), (1,-1), (1,0), (1,1), (2,-2), \ldots \}`.
With zero-based indexing of the arrays, the ``(l,m)`` term is stored at
Expand Down
3 changes: 3 additions & 0 deletions docs/src/python-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ Python API

.. autoclass:: sphericart.SphericalHarmonics
:members:

.. autoclass:: sphericart.SolidHarmonics
:members:
37 changes: 37 additions & 0 deletions examples/cuda/example.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,29 @@ int main() {
// internal buffers and numerical factors are initalized at construction
sphericart::cuda::SphericalHarmonics<double> calculator_cuda(l_max);

// allcate device memory
double* xyz_cuda;
CUDA_CHECK(cudaMalloc(&xyz_cuda, n_samples * 3 * sizeof(double)));
CUDA_CHECK(
cudaMemcpy(xyz_cuda, xyz.data(), n_samples * 3 * sizeof(double), cudaMemcpyHostToDevice)
);
double* sph_cuda;
CUDA_CHECK(cudaMalloc(&sph_cuda, n_samples * (l_max + 1) * (l_max + 1) * sizeof(double)));
double* dsph_cuda;
CUDA_CHECK(cudaMalloc(&dsph_cuda, n_samples * 3 * (l_max + 1) * (l_max + 1) * sizeof(double)));
double* ddsph_cuda;
CUDA_CHECK(
cudaMalloc(&ddsph_cuda, n_samples * 3 * 3 * (l_max + 1) * (l_max + 1) * sizeof(double))
);

// calculation examples
calculator_cuda.compute(xyz_cuda, n_samples, sph_cuda); // no gradients
calculator_cuda.compute_with_gradients(
xyz_cuda, n_samples, sph_cuda, dsph_cuda
); // with gradients
calculator_cuda.compute_with_hessians(
xyz_cuda, n_samples, sph_cuda, dsph_cuda, ddsph_cuda
); // with gradients and hessians

CUDA_CHECK(cudaMemcpy(
sph.data(), sph_cuda, n_samples * (l_max + 1) * (l_max + 1) * sizeof(double), cudaMemcpyDeviceToHost
Expand All @@ -82,8 +96,21 @@ int main() {
);
float* sph_cuda_f;
CUDA_CHECK(cudaMalloc(&sph_cuda_f, n_samples * (l_max + 1) * (l_max + 1) * sizeof(float)));
float* dsph_cuda_f;
CUDA_CHECK(cudaMalloc(&dsph_cuda_f, n_samples * 3 * (l_max + 1) * (l_max + 1) * sizeof(float)));
float* ddsph_cuda_f;
CUDA_CHECK(
cudaMalloc(&ddsph_cuda_f, n_samples * 3 * 3 * (l_max + 1) * (l_max + 1) * sizeof(float))
);

// calculation examples (float)
calculator_cuda_f.compute(xyz_cuda_f, n_samples, sph_cuda_f); // no gradients
calculator_cuda_f.compute_with_gradients(
xyz_cuda_f, n_samples, sph_cuda_f, dsph_cuda_f
); // with gradients
calculator_cuda_f.compute_with_hessians(
xyz_cuda_f, n_samples, sph_cuda_f, dsph_cuda_f, ddsph_cuda_f
); // with gradients and hessians

CUDA_CHECK(cudaMemcpy(
sph_f.data(),
Expand All @@ -101,5 +128,15 @@ int main() {
}
printf("Float vs double relative error: %12.8e\n", sqrt(sph_error / sph_norm));

/* ===== free device memory ===== */
CUDA_CHECK(cudaFree(xyz_cuda));
CUDA_CHECK(cudaFree(sph_cuda));
CUDA_CHECK(cudaFree(dsph_cuda));
CUDA_CHECK(cudaFree(ddsph_cuda));
CUDA_CHECK(cudaFree(xyz_cuda_f));
CUDA_CHECK(cudaFree(sph_cuda_f));
CUDA_CHECK(cudaFree(dsph_cuda_f));
CUDA_CHECK(cudaFree(ddsph_cuda_f));

return 0;
}
113 changes: 44 additions & 69 deletions python/src/sphericart/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

class SphericalHarmonics:
"""
Spherical harmonics calculator, up to degree ``l_max``.
Spherical harmonics calculator, which computes the real spherical harmonics
:math:`Y^m_l` up to degree ``l_max``. The calculated spherical harmonics
are consistent with the definition of real spherical harmonics from Wikipedia.

This class computes the real spherical harmonics :math:`Y^l_m`.

In order to minimize the cost of each call, the `SphericalHarmonics` object
computes prefactors and initializes buffers upon creation
The `SphericalHarmonics` object computes prefactors and initializes buffers
upon creation

>>> import numpy as np
>>> import sphericart as sc
Expand All @@ -34,11 +34,11 @@ class SphericalHarmonics:
(10, 3, 81)

which returns the gradient as a tensor with size
`(n_samples, 3, (l_max+1)**2)`.
``(n_samples, 3, (l_max+1)**2)``.

:param l_max: the maximum degree of the spherical harmonics to be calculated

:return: a calculator, in the form of a `SphericalHarmonics` object
:return: a calculator, in the form of a ``SphericalHarmonics`` object
"""

def __init__(self, l_max: int):
Expand All @@ -64,7 +64,7 @@ def __del__(self):
def compute(self, xyz: np.ndarray) -> np.ndarray:
"""
Calculates the spherical harmonics for a set of 3D points, whose
coordinates are in the ``xyz`` array.
coordinates are given by the ``xyz`` array.

>>> import numpy as np
>>> import sphericart as sc
Expand Down Expand Up @@ -135,22 +135,21 @@ def compute_with_gradients(self, xyz: np.ndarray) -> Tuple[np.ndarray, np.ndarra
>>> sh_grads.shape
(10, 3, 81)

:param xyz:
The Cartesian coordinates of the 3D points, as an array with
:param xyz: The Cartesian coordinates of the 3D points, as an array with
shape ``(n_samples, 3)``.

:return:
A tuple containing:
* an array of shape ``(n_samples, (l_max+1)**2)`` containing all the
spherical harmonics up to degree `l_max` in lexicographic order.
For example, if ``l_max = 2``, The last axis will correspond to
spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1,
1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order.
* An array of shape ``(n_samples, 3, (l_max+1)**2)`` containing all
the spherical harmonics' derivatives up to degree ``l_max``. The
last axis is organized in the same way as in the spherical
harmonics return array, while the second-to-last axis refers to
derivatives in the the x, y, and z directions, respectively.
:return: A tuple containing:

- an array of shape ``(n_samples, (l_max+1)**2)`` containing all the
spherical harmonics up to degree ``l_max`` in lexicographic order.
For example, if ``l_max = 2``, The last axis will correspond to
spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1,
1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order.
- an array of shape ``(n_samples, 3, (l_max+1)**2)`` containing all
the spherical harmonics' derivatives up to degree ``l_max``. The
last axis is organized in the same way as in the spherical
harmonics return array, while the second-to-last axis refers to
derivatives in the the x, y, and z directions, respectively.

"""

Expand Down Expand Up @@ -225,27 +224,26 @@ def compute_with_hessians(
>>> sh_hessians.shape
(10, 3, 3, 81)

:param xyz:
The Cartesian coordinates of the 3D points, as an array with
:param xyz: The Cartesian coordinates of the 3D points, as an array with
shape ``(n_samples, 3)``.

:return:
A tuple containing:
* an array of shape ``(n_samples, (l_max+1)**2)`` containing all the
spherical harmonics up to degree `l_max` in lexicographic order.
For example, if ``l_max = 2``, The last axis will correspond to
spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1,
1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order.
* An array of shape ``(n_samples, 3, (l_max+1)**2)`` containing all
the spherical harmonics' derivatives up to degree ``l_max``. The
last axis is organized in the same way as in the spherical
harmonics return array, while the second-to-last axis refers to
derivatives in the the x, y, and z directions, respectively.
* An array of shape ``(n_samples, 3, 3, (l_max+1)**2)`` containing all
the spherical harmonics' second derivatives up to degree ``l_max``.
The last axis is organized in the same way as in the spherical
harmonics return array, while the two intermediate axes represent the
Hessian dimensions.
:return: A tuple containing:

- an array of shape ``(n_samples, (l_max+1)**2)`` containing all the
spherical harmonics up to degree ``l_max`` in lexicographic order.
For example, if ``l_max = 2``, The last axis will correspond to
spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1,
1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order.
- an array of shape ``(n_samples, 3, (l_max+1)**2)`` containing all
the spherical harmonics' derivatives up to degree ``l_max``. The
last axis is organized in the same way as in the spherical
harmonics return array, while the second-to-last axis refers to
derivatives in the the x, y, and z directions, respectively.
- an array of shape ``(n_samples, 3, 3, (l_max+1)**2)`` containing all
the spherical harmonics' second derivatives up to degree ``l_max``.
The last axis is organized in the same way as in the spherical
harmonics return array, while the two intermediate axes represent the
Hessian dimensions.

"""

Expand Down Expand Up @@ -314,36 +312,13 @@ class SolidHarmonics:
Solid harmonics calculator, up to degree ``l_max``.

This class computes the solid harmonics, a non-normalized form of the real
spherical harmonics, i.e. :math:`r^l Y^l_m`. These scaled spherical harmonics
are polynomials in the Cartesian coordinates of the input points.

In order to minimize the cost of each call, the `SolidHarmonics` object
computes prefactors and initializes buffers upon creation

>>> import numpy as np
>>> import sphericart as sc
>>> sh = sc.SolidHarmonics(l_max=8)

Then, the :py:func:`compute` method can be called on an array of 3D
Cartesian points to compute the solid harmonics

>>> xyz = np.random.normal(size=(10,3))
>>> sh_values = sh.compute(xyz)
>>> sh_values.shape
(10, 81)
spherical harmonics, i.e. :math:`r^l Y^m_l`. These scaled spherical harmonics
are polynomials in the Cartesian coordinates of the input points, and they
are therefore faster to compute.

In order to also compute derivatives, you can use

>>> sh_values, sh_grads = sh.compute_with_gradients(xyz)
>>> sh_grads.shape
(10, 3, 81)

which returns the gradient as a tensor with size
`(n_samples, 3, (l_max+1)**2)`.

:param l_max: the maximum degree of the spherical harmonics to be calculated
:param l_max: the maximum degree of the solid harmonics to be calculated

:return: a calculator, in the form of a `SolidHarmonics` object
:return: a calculator, in the form of a ``SolidHarmonics`` object
"""

def __init__(self, l_max: int):
Expand Down
Loading
Loading