Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change normalized to SphericalHarmonics and SolidHarmonics across all APIs #137

Merged
merged 12 commits into from
Aug 28, 2024
4 changes: 2 additions & 2 deletions benchmarks/cpp/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ template <typename DTYPE> void run_timings(int l_max, int n_tries, int n_samples
auto ddsph1 = std::vector<DTYPE>(n_samples * 9 * (l_max + 1) * (l_max + 1), 0.0);

{
SphericalHarmonics<DTYPE> calculator(l_max, false);
SolidHarmonics<DTYPE> calculator(l_max);
sxyz[0] = xyz[0];
sxyz[1] = xyz[1];
sxyz[2] = xyz[2];
Expand Down Expand Up @@ -143,7 +143,7 @@ template <typename DTYPE> void run_timings(int l_max, int n_tries, int n_samples
}

{
SphericalHarmonics<DTYPE> calculator(l_max, true);
SphericalHarmonics<DTYPE> calculator(l_max);
benchmark("Call without derivatives (normalized)", n_samples, n_tries, [&]() {
calculator.compute(xyz, sph1);
});
Expand Down
26 changes: 15 additions & 11 deletions examples/c/example.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,32 +42,36 @@ int main(int argc, char* argv[]) {
/* ===== API calls ===== */

// opaque pointer declaration: initializes buffers and numerical factors
sphericart_calculator_t* calculator = sphericart_new(l_max, 0);
sphericart_spherical_harmonics_calculator_t* calculator =
sphericart_spherical_harmonics_new(l_max);

// function calls
// without derivatives
sphericart_compute_array(calculator, xyz, 3 * n_samples, sph, sph_size);
sphericart_spherical_harmonics_compute_array(calculator, xyz, 3 * n_samples, sph, sph_size);
// with derivatives
sphericart_compute_array_with_gradients(
sphericart_spherical_harmonics_compute_array_with_gradients(
calculator, xyz, 3 * n_samples, sph, sph_size, dsph, dsph_size
);
// with second derivatives
sphericart_compute_array_with_hessians(
sphericart_spherical_harmonics_compute_array_with_hessians(
calculator, xyz, 3 * n_samples, sph, sph_size, dsph, dsph_size, ddsph, ddsph_size
);

// per-sample calculation - we reuse the same arrays for simplicity, but
// only the first item is computed
sphericart_compute_sample(calculator, xyz, 3, sph, sph_size);
sphericart_compute_sample_with_gradients(calculator, xyz, 3, sph, sph_size, dsph, dsph_size);
sphericart_compute_sample_with_hessians(
sphericart_spherical_harmonics_compute_sample(calculator, xyz, 3, sph, sph_size);
sphericart_spherical_harmonics_compute_sample_with_gradients(
calculator, xyz, 3, sph, sph_size, dsph, dsph_size
);
sphericart_spherical_harmonics_compute_sample_with_hessians(
calculator, xyz, 3, sph, sph_size, dsph, dsph_size, ddsph, ddsph_size
);

// float version
sphericart_calculator_f_t* calculator_f = sphericart_new_f(l_max, 0);
sphericart_spherical_harmonics_calculator_f_t* calculator_f =
sphericart_spherical_harmonics_new_f(l_max);

sphericart_compute_array_with_gradients_f(
sphericart_spherical_harmonics_compute_array_with_gradients_f(
calculator_f, xyz_f, 3 * n_samples, sph_f, sph_size, dsph_f, dsph_size
);

Expand All @@ -83,12 +87,12 @@ int main(int argc, char* argv[]) {
/* ===== clean up ===== */

// frees up data arrays and sph object pointers
sphericart_delete(calculator);
sphericart_spherical_harmonics_delete(calculator);
free(xyz);
free(sph);
free(dsph);

sphericart_delete_f(calculator_f);
sphericart_spherical_harmonics_delete_f(calculator_f);
free(xyz_f);
free(sph_f);
free(dsph_f);
Expand Down
51 changes: 23 additions & 28 deletions examples/jax/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,63 +6,58 @@
key = jax.random.PRNGKey(0)
xyz = 6 * jax.random.normal(key, (10, 3))
l_max = 3 # set l_max to 3
normalized = True # in this example, we always compute normalized spherical harmonics

# calculate the spherical harmonics with the corresponding function
sph = sphericart.jax.spherical_harmonics(xyz, l_max, normalized)
# calculate the spherical harmonics with the corresponding function,
# we could also compute the solid harmonics with sphericart.jax.solid_harmonics
sph = sphericart.jax.spherical_harmonics(xyz, l_max)

# jit the function with jax.jit()
# the l_max and normalized arguments (positions 1 and 2 in the signature) must be static
jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
# the l_max argument (position 1 in the signature) must be static
jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1,))

# compute the spherical harmonics with the jitted function and check their values
# against the non-jitted version
jitted_sph = jitted_sph_function(xyz, l_max, normalized)
jitted_sph = jitted_sph_function(xyz, l_max)
assert jax.numpy.allclose(sph, jitted_sph)


# calculate a scalar function of the spherical harmonics and take its gradient
# with respect to the input Cartesian coordinates, as well as its hessian
def scalar_output(xyz, l_max, normalized):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized))
def scalar_output(xyz, l_max):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max))


grad = jax.grad(scalar_output)(xyz, l_max, normalized)
grad = jax.grad(scalar_output)(xyz, l_max)

# NB: this computes a (n_samples,3,n_samples,3) hessian, i.e. includes cross terms
# between samples.
hessian = jax.hessian(scalar_output)(xyz, l_max, normalized)

# usually you want a (n_samples,3,3), taking derivatives wrt the coordinates
# of the same sample. one way to achieve this is as follows
hessian = jax.hessian(scalar_output)(xyz, l_max)

# usually you want a hessian in the shape (n_samples, 3, 3), taking derivatives
# wrt the coordinates of the same sample. one way to achieve this is as follows

def single_scalar_output(xyz, l_max, normalized):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized))
def single_scalar_output(xyz, l_max):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max))


# Compute the Hessian for a single (3,) input
# define a function that computes the Hessian for a single (3,) input
single_hessian = jax.hessian(single_scalar_output)

# Use vmap to vectorize the Hessian computation over the first axis
sh_hess = jax.vmap(single_hessian, in_axes=(0, None, None))
# use vmap to vectorize the Hessian computation over the first axis
sh_hess = jax.vmap(single_hessian, in_axes=(0, None))


# calculate a function of the spherical harmonics that returns an array
# and take its jacobian with respect to the input Cartesian coordinates,
# both in forward mode and in reverse mode
def array_output(xyz, l_max, normalized):
def array_output(xyz, l_max):
return jax.numpy.sum(
sphericart.jax.spherical_harmonics(xyz, l_max, normalized), axis=0
sphericart.jax.spherical_harmonics(xyz, l_max), axis=0
)


jacfwd = jax.jacfwd(array_output)(xyz, l_max, normalized)
jacrev = jax.jacrev(array_output)(xyz, l_max, normalized)
assert jax.numpy.allclose(jacfwd, jacrev) # check that the two are the same
jacfwd = jax.jacfwd(array_output)(xyz, l_max)
jacrev = jax.jacrev(array_output)(xyz, l_max)

# use vmap and compare the result with the original result:
vmapped_sph = jax.vmap(sphericart.jax.spherical_harmonics, in_axes=(0, None, None))(
xyz, l_max, normalized
vmapped_sph = jax.vmap(sphericart.jax.spherical_harmonics, in_axes=(0, None))(
xyz, l_max
)
assert jax.numpy.allclose(sph, vmapped_sph)
17 changes: 5 additions & 12 deletions examples/python/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
"""


def sphericart_example(l_max=10, n_samples=10000, normalized=False):
def sphericart_example(l_max=10, n_samples=10000):
# `sphericart` provides a SphericalHarmonics object that initializes the
# calculation and then can be called on any n x 3 arrays of Cartesian
# coordinates. It computes _all_ SPH up to a given l_max, and can compute
# scaled (default) and normalized (standard Ylm) harmonics.
# coordinates. It computes _all_ SPH up to a given l_max.

# ===== set up the calculation =====

Expand All @@ -27,8 +26,8 @@ def sphericart_example(l_max=10, n_samples=10000, normalized=False):
xyz_f = np.array(xyz, dtype=np.float32)

# ===== API calls =====

sh_calculator = sphericart.SphericalHarmonics(l_max, normalized=normalized)
# we could also compute the solid harmonics with sphericart.SolidHarmonics
sh_calculator = sphericart.SphericalHarmonics(l_max)

# without gradients
sh_sphericart = sh_calculator.compute(xyz)
Expand Down Expand Up @@ -62,14 +61,8 @@ def sphericart_example(l_max=10, n_samples=10000, normalized=False):

parser.add_argument("-l", type=int, default=10, help="maximum angular momentum")
parser.add_argument("-s", type=int, default=1000, help="number of samples")
parser.add_argument(
"--normalized",
action="store_true",
default=False,
help="compute normalized spherical harmonics",
)

args = parser.parse_args()

# Process everything.
sphericart_example(args.l, args.s, args.normalized)
sphericart_example(args.l, args.s)
24 changes: 9 additions & 15 deletions examples/pytorch/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@ class SHModule(torch.nn.Module):
"""Example of how to use SphericalHarmonics from within a
`torch.nn.Module`"""

def __init__(self, l_max, normalized=False):
self.spherical_harmonics = sphericart.torch.SphericalHarmonics(l_max, normalized)
def __init__(self, l_max):
self.spherical_harmonics = sphericart.torch.SphericalHarmonics(l_max)
# or SolidHarmonics if we wanted to compute solid harmonics
super().__init__()

def forward(self, xyz):
sh = self.spherical_harmonics(xyz) # or self.spherical_harmonics.compute(xyz)
return sh


def sphericart_example(l_max=10, n_samples=10000, normalized=False):
def sphericart_example(l_max=10, n_samples=10000):
# `sphericart` provides a SphericalHarmonics object that initializes the
# calculation and then can be called on any n x 3 arrays of Cartesian
# coordinates. It computes _all_ SPH up to a given l_max, and can compute
# scaled (default) and normalized (standard Ylm) harmonics.
# coordinates. It computes _all_ SPH up to a given l_max.

# ===== set up the calculation =====

Expand All @@ -43,7 +43,7 @@ def sphericart_example(l_max=10, n_samples=10000, normalized=False):

# ===== API calls =====

sh_calculator = sphericart.torch.SphericalHarmonics(l_max, normalized=normalized)
sh_calculator = sphericart.torch.SphericalHarmonics(l_max)

# the interface allows to return directly the forward derivatives (up to second
# order), similar to the Python version
Expand Down Expand Up @@ -92,7 +92,7 @@ def sphericart_example(l_max=10, n_samples=10000, normalized=False):
# double derivatives. In order to access them via back-propagation, an additional
# flag must be specified at class instantiation:
sh_calculator_2 = sphericart.torch.SphericalHarmonics(
l_max, normalized=normalized, backward_second_derivatives=True
l_max, backward_second_derivatives=True
)

# double grad() call:
Expand All @@ -116,7 +116,7 @@ def func(xyz):
# ===== torchscript integration =====
xyz_jit = xyz.clone().detach().type(torch.float64).to("cpu").requires_grad_()

module = SHModule(l_max, normalized)
module = SHModule(l_max)

# JIT compilation of the module
script = torch.jit.script(module)
Expand Down Expand Up @@ -157,14 +157,8 @@ def func(xyz):

parser.add_argument("-l", type=int, default=10, help="maximum angular momentum")
parser.add_argument("-s", type=int, default=1000, help="number of samples")
parser.add_argument(
"--normalized",
action="store_true",
default=False,
help="compute normalized spherical harmonics",
)

args = parser.parse_args()

# Process everything.
sphericart_example(args.l, args.s, args.normalized)
sphericart_example(args.l, args.s)
2 changes: 1 addition & 1 deletion python/src/sphericart/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .spherical_harmonics import SphericalHarmonics # noqa
from .spherical_harmonics import SphericalHarmonics, SolidHarmonics # noqa
Loading
Loading