Merge pull request #145 from lab-cosmo/torchexp1
Add PyTorch implementation of the exponential integral function
E-Rum authored Jan 22, 2025
2 parents 56e18b5 + 10e250a commit 324f6b3
Showing 6 changed files with 84 additions and 30 deletions.
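At a glance, the commit swaps the SciPy-backed `torch_exp1` wrapper for a pure-PyTorch `exp1` with a hand-written autograd backward, so the function now stays on the input's own device and dtype. A minimal usage sketch (my example, not code from the commit):

    import torch
    from torchpme.lib import exp1

    x = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64, requires_grad=True)
    y = exp1(x)          # E1(x), evaluated entirely in PyTorch
    y.sum().backward()   # uses the custom backward: dE1/dx = -exp(-x)/x
    print(y)             # approx. tensor([0.5598, 0.2194, 0.0489])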
1 change: 1 addition & 0 deletions docs/src/references/changelog.rst
@@ -27,6 +27,7 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
 Added
 #####

+* Added a PyTorch implementation of the exponential integral function
 * Added ``dtype`` and ``device`` for ``Calculator`` classes

 Changed
2 changes: 1 addition & 1 deletion docs/src/references/lib/math.rst
@@ -1,6 +1,6 @@
 Math
 ####

+.. autofunction:: torchpme.lib.exp1
 .. autofunction:: torchpme.lib.gamma
-.. autofunction:: torchpme.lib.torch_exp1
 .. autofunction:: torchpme.lib.gammaincc_over_powerlaw
5 changes: 2 additions & 3 deletions src/torchpme/lib/__init__.py
@@ -4,7 +4,7 @@
     generate_kvectors_for_mesh,
     get_ns_mesh,
 )
-from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1
+from .math import exp1, gamma, gammaincc_over_powerlaw
 from .mesh_interpolator import MeshInterpolator
 from .splines import (
     CubicSpline,
@@ -28,7 +28,6 @@
     "generate_kvectors_for_mesh",
     "get_ns_mesh",
     "gamma",
-    "CustomExp1",
     "gammaincc_over_powerlaw",
-    "torch_exp1",
+    "exp1",
 ]
80 changes: 64 additions & 16 deletions src/torchpme/lib/math.py
@@ -1,5 +1,4 @@
 import torch
-from scipy.special import exp1
 from torch.special import gammaln


@@ -14,40 +13,89 @@ def gamma(x: torch.Tensor) -> torch.Tensor:
     return torch.exp(gammaln(x))


-class CustomExp1(torch.autograd.Function):
-    """Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
-
+class _CustomExp1(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input):
-        ctx.save_for_backward(input)
-        input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
-        return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)
+    def forward(ctx, x):
+        # this implementation is inspired by the one in scipy:
+        # https://github.com/scipy/scipy/blob/34d91ce06d4d05e564b79bf65288284247b1f3e3/scipy/special/xsf/expint.h#L22
+        ctx.save_for_backward(x)
+
+        # Constants
+        SCIPY_EULER = (
+            0.577215664901532860606512090082402431  # Euler-Mascheroni constant
+        )
+        inf = torch.inf
+
+        # Handle case when x == 0
+        result = torch.full_like(x, inf)
+        mask = x > 0
+
+        # Compute for x <= 1
+        x_small = x[mask & (x <= 1)]
+        if x_small.numel() > 0:
+            e1 = torch.ones_like(x_small)
+            r = torch.ones_like(x_small)
+            for k in range(1, 26):
+                r = -r * k * x_small / (k + 1.0) ** 2
+                e1 += r
+                if torch.all(torch.abs(r) <= torch.abs(e1) * 1e-15):
+                    break
+            result[mask & (x <= 1)] = -SCIPY_EULER - torch.log(x_small) + x_small * e1
+
+        # Compute for x > 1
+        x_large = x[mask & (x > 1)]
+        if x_large.numel() > 0:
+            m = 20 + (80.0 / x_large).to(torch.int32)
+            t0 = torch.zeros_like(x_large)
+            for k in range(m.max(), 0, -1):
+                t0 = k / (1.0 + k / (x_large + t0))
+            t = 1.0 / (x_large + t0)
+            result[mask & (x > 1)] = torch.exp(-x_large) * t
+
+        return result

     @staticmethod
     def backward(ctx, grad_output):
-        (input,) = ctx.saved_tensors
-        return -grad_output * torch.exp(-input) / input
+        (x,) = ctx.saved_tensors
+        return -grad_output * torch.exp(-x) / x


-def torch_exp1(input):
-    """Wrapper for the custom exponential integral function."""
-    return CustomExp1.apply(input)
+def exp1(x):
+    r"""
+    Exponential integral E1.
+
+    For a real number :math:`x > 0` the exponential integral can be defined as
+
+    .. math::
+
+        E1(x) = \int_{x}^{\infty} \frac{e^{-t}}{t} dt
+
+    :param x: Input tensor (x > 0)
+    :return: Exponential integral E1(x)
+    """
+    return _CustomExp1.apply(x)


 def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
-    """Function to compute the regularized incomplete gamma function complement for integer exponents."""
+    """
+    Compute the regularized incomplete gamma function complement for integer exponents.
+
+    :param exponent: Exponent of the power law
+    :param z: Value at which to evaluate the function
+    :return: Regularized incomplete gamma function complement
+    """
     if exponent == 1:
         return torch.exp(-z) / z
     if exponent == 2:
         return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
     if exponent == 3:
-        return torch_exp1(z)
+        return exp1(z)
     if exponent == 4:
         return 2 * (
             torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
         )
     if exponent == 5:
-        return torch.exp(-z) - z * torch_exp1(z)
+        return torch.exp(-z) - z * exp1(z)
     if exponent == 6:
         return (
             (2 - 4 * z) * torch.exp(-z)
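As a reading aid for the hunk above: the new forward pass evaluates :math:`E_1` in the two standard regimes used by the SciPy routine linked in its comments. Reconstructing what the loops compute (my notation, not text from the commit):

    % power series, used for 0 < x <= 1; the loop accumulates at most 25 terms
    E_1(x) = -\gamma - \ln x + \sum_{k=1}^{\infty} \frac{(-1)^{k+1} x^k}{k \cdot k!}

    % continued fraction, used for x > 1; t0 is built by the backward recurrence
    E_1(x) = e^{-x} \cfrac{1}{x + \cfrac{1}{1 + \cfrac{1}{x + \cfrac{2}{1 + \cfrac{2}{x + \dotsb}}}}}

The gradient returned by ``backward``, :math:`E_1'(x) = -e^{-x}/x`, follows directly from the integral definition. The closed forms in ``gammaincc_over_powerlaw`` are consistent with :math:`\Gamma\bigl(\tfrac{3-p}{2}, z\bigr) / z^{(3-p)/2}` for exponent :math:`p`; that reading is inferred from the branches shown, not stated in the diff.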
4 changes: 2 additions & 2 deletions src/torchpme/potentials/inversepowerlaw.py
@@ -137,12 +137,12 @@ def background_correction(self) -> torch.Tensor:
         # "charge neutrality" correction for 1/r^p potential diverges for exponent p = 3
         # and is not needed for p > 3, so we set it to zero (see
         # https://doi.org/10.48550/arXiv.2412.03281 SI section)
-        if self.exponent >= 3:
-            return torch.tensor(0.0, dtype=self.dtype, device=self.device)
         if self.smearing is None:
             raise ValueError(
                 "Cannot compute background correction without specifying `smearing`."
             )
+        if self.exponent >= 3:
+            return self.smearing * 0.0
         prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2)
         prefac /= (3 - self.exponent) * gamma(self.exponent / 2)
         return prefac
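For exponents below 3, the surviving branch computes the prefactor :math:`\pi^{3/2} (2\sigma^2)^{(3-p)/2} / \bigl((3-p)\,\Gamma(p/2)\bigr)` with :math:`\sigma` the smearing. A standalone numeric sanity check of that expression (a hypothetical helper assuming `smearing` is the Gaussian width, not the class method itself):

    import math

    def background_prefac(exponent: float, smearing: float) -> float:
        # mirrors: pi**1.5 * (2 * smearing**2)**((3 - exponent) / 2)
        #          / ((3 - exponent) * gamma(exponent / 2))
        return (
            math.pi**1.5
            * (2 * smearing**2) ** ((3 - exponent) / 2)
            / ((3 - exponent) * math.gamma(exponent / 2))
        )

    print(background_prefac(1.0, 0.5))  # Coulomb p=1, sigma=0.5: pi/4, approx. 0.7854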
22 changes: 14 additions & 8 deletions tests/lib/test_math.py
@@ -1,25 +1,31 @@
 import numpy as np
+import scipy.special
 import torch
-from scipy.special import exp1

-from torchpme.lib import torch_exp1
+from torchpme.lib import exp1


 def finite_difference_derivative(func, x, h=1e-5):
     return (func(x + h) - func(x - h)) / (2 * h)


 def test_torch_exp1_consistency_with_scipy():
-    x = torch.rand(1000, dtype=torch.float64)
-    torch_result = torch_exp1(x)
-    scipy_result = exp1(x.numpy())
-    assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
+    seed = 42
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random_tensor = torch.rand(100000) * 1000
+    random_array = random_tensor.numpy()
+    scipy_result = scipy.special.exp1(random_array)
+    torch_result = exp1(random_tensor)
+    assert np.allclose(scipy_result, torch_result.numpy(), atol=1e-15)


 def test_torch_exp1_derivative():
     x = torch.rand(1, dtype=torch.float64, requires_grad=True)
-    torch_result = torch_exp1(x)
+    torch_result = exp1(x)
     torch_result.backward()
     torch_exp1_prime = x.grad
-    finite_diff_result = finite_difference_derivative(exp1, x.detach().numpy())
+    finite_diff_result = finite_difference_derivative(
+        scipy.special.exp1, x.detach().numpy()
+    )
     assert np.allclose(torch_exp1_prime.numpy(), finite_diff_result, atol=1e-6)
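Since the custom Function now ships an analytic backward, `torch.autograd.gradcheck` could exercise it directly against a numerical Jacobian; a possible extra test, not part of this commit:

    import torch
    from torchpme.lib import exp1

    def test_torch_exp1_gradcheck():
        x = torch.rand(10, dtype=torch.float64) + 0.1  # keep x > 0, away from the pole at 0
        x.requires_grad_(True)
        assert torch.autograd.gradcheck(exp1, (x,))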
