Skip to content

Commit

Permalink
Enable optional use of jeig for eigendecomposition (#127)
Browse files Browse the repository at this point in the history
* Use jeig for eigendecomposition

* make jeig dep optional

* make jeig dep optional

* tests with jeig

* make jeig dep optional

* reorder imports

* install jeig for lint and typecheck

---------

Co-authored-by: Martin Schubert <[email protected]>
  • Loading branch information
mfschubert and Martin Schubert authored Sep 4, 2024
1 parent cc5fed0 commit e5cc6cd
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 7 deletions.
20 changes: 19 additions & 1 deletion .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Setup environment
run: |
python -m pip install --upgrade pip
pip install ".[dev]"
pip install ".[dev,jeig]"
- name: Lint Python files
run: |
find . -name "*.py" | xargs black --check
Expand Down Expand Up @@ -55,6 +55,24 @@ jobs:
- name: Test fmmax
run: pytest tests/fmmax

test-fmmax-jeig:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml
- name: Setup environment
run: |
python -m pip install --upgrade pip
pip install ".[tests,dev,jeig]"
- name: Test fmmax
run: pytest tests/fmmax

test-grcwa:
runs-on: ubuntu-latest
steps:
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ dependencies = [
"jaxlib",
"numpy",
]

[project.optional-dependencies]
jeig = [
"jeig",
]
tests = [
"grcwa",
"parameterized",
Expand Down
26 changes: 23 additions & 3 deletions src/fmmax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@
import jax
import jax.numpy as jnp

# The `jeig` package offers several jax-wrapped implementations of eigendecomposition,
# some of which have performance benefits. However, since `jeig` has a dependency on
# pytorch, we make its use optional. If `jeig` is not available, we fall back on a
# pure-jax implementation of the eigendecomposition.
try:
import jeig

_JEIG_AVALABLE = True
except ModuleNotFoundError:
_JEIG_AVALABLE = False


EIG_EPS_RELATIVE = 1e-12
EIG_EPS_MINIMUM = 1e-24

Expand Down Expand Up @@ -120,10 +132,10 @@ def eig(
The eigenvalues and eigenvectors.
"""
del eps_relative
return _eig_host(matrix)
return _eig(matrix)


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
def _eig_host_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""

def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
Expand All @@ -143,12 +155,20 @@ def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
)


def _eig(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Eigendecomposition using `jeig` if available, and `_eig_host_jax` if not."""
if _JEIG_AVALABLE:
return jeig.eig(matrix)
else:
return _eig_host_jax(matrix)


def _eig_fwd(
matrix: jnp.ndarray,
eps_relative: float,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, float]]:
"""Implements the forward calculation for `eig`."""
eigenvalues, eigenvectors = _eig_host(matrix)
eigenvalues, eigenvectors = _eig(matrix)
return (eigenvalues, eigenvectors), (eigenvalues, eigenvectors, eps_relative)


Expand Down
4 changes: 2 additions & 2 deletions tests/fmmax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def test_value_matches_eig_with_nondegenerate_eigenvalues(self):
jax.device_put(matrix, device=jax.devices("cpu")[0])
)
eigval, eigvec = utils.eig(matrix)
onp.testing.assert_array_equal(eigval, expected_eigval)
onp.testing.assert_array_equal(eigvec, expected_eigvec)
onp.testing.assert_allclose(eigval, expected_eigval, rtol=1e-12)
onp.testing.assert_allclose(eigvec, expected_eigvec, rtol=1e-12)

def test_eigvalue_jacobian_matches_expected_real_matrix(self):
matrix = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4)).astype(complex)
Expand Down

0 comments on commit e5cc6cd

Please sign in to comment.